diff --git a/sea-orm-macros/src/derives/active_enum.rs b/sea-orm-macros/src/derives/active_enum.rs new file mode 100644 index 00000000..6f6b7fbc --- /dev/null +++ b/sea-orm-macros/src/derives/active_enum.rs @@ -0,0 +1,219 @@ +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::{punctuated::Punctuated, token::Comma, Lit, Meta}; + +enum Error { + InputNotEnum, + Syn(syn::Error), +} + +struct ActiveEnum { + ident: syn::Ident, + rs_type: TokenStream, + db_type: TokenStream, + variants: syn::punctuated::Punctuated, +} + +impl ActiveEnum { + fn new(input: syn::DeriveInput) -> Result { + let ident = input.ident; + + let mut rs_type = None; + let mut db_type = None; + for attr in input.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "rs_type" { + if let Lit::Str(litstr) = &nv.lit { + rs_type = syn::parse_str::(&litstr.value()).ok(); + } + } else if name == "db_type" { + if let Lit::Str(litstr) = &nv.lit { + db_type = syn::parse_str::(&litstr.value()).ok(); + } + } + } + } + } + } + } + let rs_type = rs_type.expect("Missing rs_type"); + let db_type = db_type.expect("Missing db_type"); + + let variants = match input.data { + syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, + _ => return Err(Error::InputNotEnum), + }; + + Ok(ActiveEnum { + ident, + rs_type, + db_type, + variants, + }) + } + + fn expand(&self) -> syn::Result { + let expanded_impl_active_enum = self.impl_active_enum(); + + Ok(expanded_impl_active_enum) + } + + fn impl_active_enum(&self) -> TokenStream { + let Self { + ident, + rs_type, + db_type, + variants, + } = self; + + let variant_idents: Vec = variants + .iter() + .map(|variant| variant.ident.clone()) + .collect(); + + let mut is_string = false; + + let variant_values: Vec = variants + .iter() + .map(|variant| { + let mut string_value = None; + let mut num_value = None; + for attr in variant.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = + attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "string_value" { + if let Lit::Str(litstr) = &nv.lit { + string_value = Some(litstr.value()); + } + } else if name == "num_value" { + if let Lit::Int(litstr) = &nv.lit { + num_value = litstr.base10_parse::().ok(); + } + } + } + } + } + } + } + + if let Some(string_value) = string_value { + is_string = true; + quote! { #string_value } + } else if let Some(num_value) = num_value { + quote! { #num_value } + } else { + panic!("Either string_value or num_value should be specified") + } + }) + .collect(); + + let val = if is_string { + quote! { v.as_ref() } + } else { + quote! { v } + }; + + quote!( + #[automatically_derived] + impl sea_orm::ActiveEnum for #ident { + type Value = #rs_type; + + fn to_value(&self) -> Self::Value { + match self { + #( Self::#variant_idents => #variant_values, )* + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match #val { + #( #variant_values => Ok(Self::#variant_idents), )* + _ => Err(sea_orm::DbErr::Query(format!( + "unexpected value for {} enum: {}", + stringify!(#ident), + v + ))), + } + } + + fn db_type() -> sea_orm::ColumnDef { + sea_orm::ColumnType::#db_type.def() + } + } + + #[automatically_derived] + impl Into for #ident { + fn into(self) -> sea_query::Value { + ::to_value(&self).into() + } + } + + #[automatically_derived] + impl sea_orm::TryGetable for #ident { + fn try_get(res: &sea_orm::QueryResult, pre: &str, col: &str) -> Result { + let value = <::Value as sea_orm::TryGetable>::try_get(res, pre, col)?; + ::try_from_value(&value).map_err(|e| sea_orm::TryGetError::DbErr(e)) + } + } + + #[automatically_derived] + impl sea_query::ValueType for #ident { + fn try_from(v: sea_query::Value) -> Result { + let value = <::Value as sea_query::ValueType>::try_from(v)?; + ::try_from_value(&value).map_err(|_| sea_query::ValueTypeErr) + } + + fn type_name() -> String { + <::Value as sea_query::ValueType>::type_name() + } + + fn column_type() -> sea_query::ColumnType { + ::db_type() + .get_column_type() + .to_owned() + .into() + } + } + + #[automatically_derived] + impl sea_query::Nullable for #ident { + fn null() -> sea_query::Value { + <::Value as sea_query::Nullable>::null() + } + } + ) + } +} + +pub fn expand_derive_active_enum(input: syn::DeriveInput) -> syn::Result { + let ident_span = input.ident.span(); + + match ActiveEnum::new(input) { + Ok(model) => model.expand(), + Err(Error::InputNotEnum) => Ok(quote_spanned! { + ident_span => compile_error!("you can only derive ActiveEnum on enums"); + }), + Err(Error::Syn(err)) => Err(err), + } +} diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index 3b0dac26..0963c2ca 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -1,7 +1,7 @@ use crate::util::{escape_rust_keyword, trim_starting_raw_identifier}; use convert_case::{Case, Casing}; use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote, quote_spanned}; use syn::{ parse::Error, punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, Fields, Lit, Meta, @@ -192,8 +192,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res primary_keys.push(quote! { #field_name }); } - let field_type = match sql_type { - Some(t) => t, + let col_type = match sql_type { + Some(t) => quote! { sea_orm::prelude::ColumnType::#t.def() }, None => { let field_type = &field.ty; let temp = quote! { #field_type } @@ -205,7 +205,7 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res } else { temp.as_str() }; - match temp { + let col_type = match temp { "char" => quote! { Char(None) }, "String" | "&str" => quote! { String(None) }, "u8" | "i8" => quote! { TinyInteger }, @@ -228,16 +228,24 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res "Decimal" => quote! { Decimal(None) }, "Vec" => quote! { Binary }, _ => { - return Err(Error::new( - field.span(), - format!("unrecognized type {}", temp), - )) + // Assumed it's ActiveEnum if none of the above type matches + quote! {} } + }; + if col_type.is_empty() { + let field_span = field.span(); + let ty = format_ident!("{}", temp); + let def = quote_spanned! { field_span => { + <#ty as ActiveEnum>::db_type() + }}; + quote! { #def } + } else { + quote! { sea_orm::prelude::ColumnType::#col_type.def() } } } }; - let mut match_row = quote! { Self::#field_name => sea_orm::prelude::ColumnType::#field_type.def() }; + let mut match_row = quote! { Self::#field_name => #col_type }; if nullable { match_row = quote! { #match_row.nullable() }; } diff --git a/sea-orm-macros/src/derives/mod.rs b/sea-orm-macros/src/derives/mod.rs index 6ba19a92..36b9f669 100644 --- a/sea-orm-macros/src/derives/mod.rs +++ b/sea-orm-macros/src/derives/mod.rs @@ -1,3 +1,4 @@ +mod active_enum; mod active_model; mod active_model_behavior; mod column; @@ -9,6 +10,7 @@ mod model; mod primary_key; mod relation; +pub use active_enum::*; pub use active_model::*; pub use active_model_behavior::*; pub use column::*; diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index cf8c2f3c..ccba2fab 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -102,6 +102,15 @@ pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(DeriveActiveEnum, attributes(sea_orm))] +pub fn derive_active_enum(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match derives::expand_derive_active_enum(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + #[proc_macro_derive(FromQueryResult)] pub fn derive_from_query_result(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); diff --git a/src/entity/active_enum.rs b/src/entity/active_enum.rs new file mode 100644 index 00000000..68c58c64 --- /dev/null +++ b/src/entity/active_enum.rs @@ -0,0 +1,184 @@ +use crate::{ColumnDef, DbErr, TryGetable}; +use sea_query::{Nullable, Value, ValueType}; +use std::fmt::Debug; + +pub trait ActiveEnum: Sized { + type Value: Sized + Send + Debug + PartialEq + Into + ValueType + Nullable + TryGetable; + + fn to_value(&self) -> Self::Value; + + fn try_from_value(v: &Self::Value) -> Result; + + fn db_type() -> ColumnDef; +} + +#[cfg(test)] +mod tests { + use crate as sea_orm; + use crate::{entity::prelude::*, *}; + use pretty_assertions::assert_eq; + + #[test] + fn active_enum_1() { + #[derive(Debug, PartialEq)] + pub enum Category { + Big, + Small, + } + + impl ActiveEnum for Category { + type Value = String; + + fn to_value(&self) -> Self::Value { + match self { + Self::Big => "B", + Self::Small => "S", + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match v.as_ref() { + "B" => Ok(Self::Big), + "S" => Ok(Self::Small), + _ => Err(DbErr::Query(format!( + "unexpected value for Category enum: {}", + v + ))), + } + } + + fn db_type() -> ColumnDef { + ColumnType::String(Some(1)).def() + } + } + + #[derive(Debug, PartialEq, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "String(Some(1))")] + pub enum DeriveCategory { + #[sea_orm(string_value = "B")] + Big, + #[sea_orm(string_value = "S")] + Small, + } + + assert_eq!(Category::Big.to_value(), "B".to_owned()); + assert_eq!(Category::Small.to_value(), "S".to_owned()); + assert_eq!(DeriveCategory::Big.to_value(), "B".to_owned()); + assert_eq!(DeriveCategory::Small.to_value(), "S".to_owned()); + + assert_eq!( + Category::try_from_value(&"A".to_owned()).err(), + Some(DbErr::Query( + "unexpected value for Category enum: A".to_owned() + )) + ); + assert_eq!( + Category::try_from_value(&"B".to_owned()).ok(), + Some(Category::Big) + ); + assert_eq!( + Category::try_from_value(&"S".to_owned()).ok(), + Some(Category::Small) + ); + assert_eq!( + DeriveCategory::try_from_value(&"A".to_owned()).err(), + Some(DbErr::Query( + "unexpected value for DeriveCategory enum: A".to_owned() + )) + ); + assert_eq!( + DeriveCategory::try_from_value(&"B".to_owned()).ok(), + Some(DeriveCategory::Big) + ); + assert_eq!( + DeriveCategory::try_from_value(&"S".to_owned()).ok(), + Some(DeriveCategory::Small) + ); + + assert_eq!(Category::db_type(), ColumnType::String(Some(1)).def()); + assert_eq!(DeriveCategory::db_type(), ColumnType::String(Some(1)).def()); + } + + #[test] + fn active_enum_2() { + #[derive(Debug, PartialEq)] + pub enum Category { + Big, + Small, + } + + impl ActiveEnum for Category { + type Value = i32; // FIXME: only support i32 for now + + fn to_value(&self) -> Self::Value { + match self { + Self::Big => 1, + Self::Small => 0, + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match v { + 1 => Ok(Self::Big), + 0 => Ok(Self::Small), + _ => Err(DbErr::Query(format!( + "unexpected value for Category enum: {}", + v + ))), + } + } + + fn db_type() -> ColumnDef { + ColumnType::Integer.def() + } + } + + #[derive(Debug, PartialEq, DeriveActiveEnum)] + #[sea_orm(rs_type = "i32", db_type = "Integer")] + pub enum DeriveCategory { + #[sea_orm(num_value = 1)] + Big, + #[sea_orm(num_value = 0)] + Small, + } + + assert_eq!(Category::Big.to_value(), 1); + assert_eq!(Category::Small.to_value(), 0); + assert_eq!(DeriveCategory::Big.to_value(), 1); + assert_eq!(DeriveCategory::Small.to_value(), 0); + + assert_eq!( + Category::try_from_value(&2).err(), + Some(DbErr::Query( + "unexpected value for Category enum: 2".to_owned() + )) + ); + assert_eq!( + Category::try_from_value(&1).ok(), + Some(Category::Big) + ); + assert_eq!( + Category::try_from_value(&0).ok(), + Some(Category::Small) + ); + assert_eq!( + DeriveCategory::try_from_value(&2).err(), + Some(DbErr::Query( + "unexpected value for DeriveCategory enum: 2".to_owned() + )) + ); + assert_eq!( + DeriveCategory::try_from_value(&1).ok(), + Some(DeriveCategory::Big) + ); + assert_eq!( + DeriveCategory::try_from_value(&0).ok(), + Some(DeriveCategory::Small) + ); + + assert_eq!(Category::db_type(), ColumnType::Integer.def()); + assert_eq!(DeriveCategory::db_type(), ColumnType::Integer.def()); + } +} diff --git a/src/entity/column.rs b/src/entity/column.rs index a27215e5..25ed8447 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -262,6 +262,10 @@ impl ColumnDef { self.indexed = true; self } + + pub fn get_column_type(&self) -> &ColumnType { + &self.col_type + } } impl From for sea_query::ColumnType { diff --git a/src/entity/mod.rs b/src/entity/mod.rs index c6d15052..7e8b7830 100644 --- a/src/entity/mod.rs +++ b/src/entity/mod.rs @@ -1,3 +1,4 @@ +mod active_enum; mod active_model; mod base_entity; mod column; @@ -8,6 +9,7 @@ pub mod prelude; mod primary_key; mod relation; +pub use active_enum::*; pub use active_model::*; pub use base_entity::*; pub use column::*; diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index ac4d50fa..98f89c92 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,14 +1,15 @@ pub use crate::{ - error::*, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, - DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, Iden, - IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult, - Related, RelationDef, RelationTrait, Select, Value, + error::*, ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, + ColumnType, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, + Iden, IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, + QueryResult, Related, RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "macros")] pub use crate::{ - DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, DerivePrimaryKey, DeriveRelation, + DeriveActiveEnum, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, + DeriveCustomColumn, DeriveEntity, DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, + DerivePrimaryKey, DeriveRelation, }; #[cfg(feature = "with-json")] diff --git a/src/lib.rs b/src/lib.rs index 745692b9..d40db473 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -288,9 +288,9 @@ pub use schema::*; #[cfg(feature = "macros")] pub use sea_orm_macros::{ - DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, DerivePrimaryKey, DeriveRelation, - FromQueryResult, + DeriveActiveEnum, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, + DeriveCustomColumn, DeriveEntity, DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, + DerivePrimaryKey, DeriveRelation, FromQueryResult, }; pub use sea_query; diff --git a/tests/active_enum_tests.rs b/tests/active_enum_tests.rs new file mode 100644 index 00000000..90472fde --- /dev/null +++ b/tests/active_enum_tests.rs @@ -0,0 +1,39 @@ +pub mod common; + +pub use common::{features::*, setup::*, TestContext}; +use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("active_enum_tests").await; + create_tables(&ctx.db).await?; + insert_active_enum(&ctx.db).await?; + ctx.delete().await; + + Ok(()) +} + +pub async fn insert_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { + active_enum::ActiveModel { + category: Set(active_enum::Category::Big), + ..Default::default() + } + .insert(db) + .await?; + + assert_eq!( + active_enum::Entity::find().one(db).await?.unwrap(), + active_enum::Model { + id: 1, + category: active_enum::Category::Big, + category_opt: None, + } + ); + + Ok(()) +} diff --git a/tests/common/features/active_enum.rs b/tests/common/features/active_enum.rs new file mode 100644 index 00000000..1b68d546 --- /dev/null +++ b/tests/common/features/active_enum.rs @@ -0,0 +1,87 @@ +use sea_orm::{entity::prelude::*, TryGetError, TryGetable}; +use sea_query::{Nullable, ValueType, ValueTypeErr}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "active_enum")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub category: Category, + pub category_opt: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Debug, Clone, PartialEq)] +pub enum Category { + Big, + Small, +} + +impl ActiveEnum for Category { + type Value = String; + + fn to_value(&self) -> Self::Value { + match self { + Self::Big => "B", + Self::Small => "S", + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match v.as_ref() { + "B" => Ok(Self::Big), + "S" => Ok(Self::Small), + _ => Err(DbErr::Query(format!( + "unexpected value for {} enum: {}", + stringify!(Category), + v + ))), + } + } + + fn db_type() -> ColumnDef { + ColumnType::String(Some(1)).def() + } +} + +impl Into for Category { + fn into(self) -> Value { + self.to_value().into() + } +} + +impl TryGetable for Category { + fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { + let value = <::Value as TryGetable>::try_get(res, pre, col)?; + Self::try_from_value(&value).map_err(|e| TryGetError::DbErr(e)) + } +} + +impl ValueType for Category { + fn try_from(v: Value) -> Result { + let value = <::Value as ValueType>::try_from(v)?; + Self::try_from_value(&value).map_err(|_| ValueTypeErr) + } + + fn type_name() -> String { + <::Value as ValueType>::type_name() + } + + fn column_type() -> sea_query::ColumnType { + ::db_type() + .get_column_type() + .to_owned() + .into() + } +} + +impl Nullable for Category { + fn null() -> Value { + <::Value as Nullable>::null() + } +} diff --git a/tests/common/features/mod.rs b/tests/common/features/mod.rs index ff716f01..f0db35b7 100644 --- a/tests/common/features/mod.rs +++ b/tests/common/features/mod.rs @@ -1,8 +1,10 @@ +pub mod active_enum; pub mod applog; pub mod metadata; pub mod repository; pub mod schema; +pub use active_enum::Entity as ActiveEnum; pub use applog::Entity as Applog; pub use metadata::Entity as Metadata; pub use repository::Entity as Repository; diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 44b011c5..5292f4fc 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -9,6 +9,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_log_table(db).await?; create_metadata_table(db).await?; create_repository_table(db).await?; + create_active_enum_table(db).await?; Ok(()) } @@ -75,3 +76,24 @@ pub async fn create_repository_table(db: &DbConn) -> Result { create_table(db, &stmt, Repository).await } + +pub async fn create_active_enum_table(db: &DbConn) -> Result { + let stmt = sea_query::Table::create() + .table(active_enum::Entity) + .col( + ColumnDef::new(active_enum::Column::Id) + .integer() + .not_null() + .primary_key() + .auto_increment(), + ) + .col( + ColumnDef::new(active_enum::Column::Category) + .string_len(1) + .not_null(), + ) + .col(ColumnDef::new(active_enum::Column::CategoryOpt).string_len(1)) + .to_owned(); + + create_table(db, &stmt, ActiveEnum).await +}