diff --git a/sea-orm-macros/src/derives/active_model.rs b/sea-orm-macros/src/derives/active_model.rs index 3a96860e..2227f09b 100644 --- a/sea-orm-macros/src/derives/active_model.rs +++ b/sea-orm-macros/src/derives/active_model.rs @@ -2,7 +2,7 @@ use crate::util::field_not_ignored; use heck::CamelCase; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{Data, DataStruct, Field, Fields, Type}; +use syn::{punctuated::Punctuated, token::Comma, Data, DataStruct, Field, Fields, Lit, Meta, Type}; pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result { let fields = match data { @@ -28,7 +28,36 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result = fields .clone() .into_iter() - .map(|Field { ident, .. }| format_ident!("{}", ident.unwrap().to_string().to_camel_case())) + .map(|field| { + let mut ident = format_ident!( + "{}", + field.ident.as_ref().unwrap().to_string().to_camel_case() + ); + for attr in field.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "enum_name" { + if let Lit::Str(litstr) = &nv.lit { + ident = syn::parse_str(&litstr.value()).unwrap(); + } + } + } + } + } + } + } + ident + }) .collect(); let ty: Vec = fields.into_iter().map(|Field { ty, .. }| ty).collect(); diff --git a/sea-orm-macros/src/derives/column.rs b/sea-orm-macros/src/derives/column.rs index 16f9bb22..5fc471e9 100644 --- a/sea-orm-macros/src/derives/column.rs +++ b/sea-orm-macros/src/derives/column.rs @@ -1,7 +1,7 @@ use heck::{MixedCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{Data, DataEnum, Fields, Variant}; +use syn::{punctuated::Punctuated, token::Comma, Data, DataEnum, Fields, Lit, Meta, Variant}; pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result { let variants = match data { @@ -25,8 +25,31 @@ pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result = variants .iter() .map(|v| { - let ident = v.ident.to_string().to_snake_case(); - quote! { #ident } + let mut column_name = v.ident.to_string().to_snake_case(); + for attr in v.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "column_name" { + if let Lit::Str(litstr) = &nv.lit { + column_name = litstr.value(); + } + } + } + } + } + } + } + quote! { #column_name } }) .collect(); diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index f0772763..e7a6a23a 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -60,9 +60,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res if let Fields::Named(fields) = item_struct.fields { for field in fields.named { if let Some(ident) = &field.ident { - let field_name = + let mut field_name = Ident::new(&ident.to_string().to_case(Case::Pascal), Span::call_site()); - columns_enum.push(quote! { #field_name }); let mut nullable = false; let mut default_value = None; @@ -71,7 +70,10 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res let mut ignore = false; let mut unique = false; let mut sql_type = None; - // search for #[sea_orm(primary_key, auto_increment = false, column_type = "String(Some(255))", default_value = "new user", default_expr = "gen_random_uuid()", nullable, indexed, unique)] + let mut column_name = None; + let mut enum_name = None; + let mut is_primary_key = false; + // search for #[sea_orm(primary_key, auto_increment = false, column_type = "String(Some(255))", default_value = "new user", default_expr = "gen_random_uuid()", column_name = "name", enum_name = "Name", nullable, indexed, unique)] for attr in field.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { @@ -116,6 +118,26 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res default_value = Some(nv.lit.to_owned()); } else if name == "default_expr" { default_expr = Some(nv.lit.to_owned()); + } else if name == "column_name" { + if let Lit::Str(litstr) = &nv.lit { + column_name = Some(litstr.value()); + } else { + return Err(Error::new( + field.span(), + format!("Invalid column_name {:?}", nv.lit), + )); + } + } else if name == "enum_name" { + if let Lit::Str(litstr) = &nv.lit { + let ty: Ident = + syn::parse_str(&litstr.value())?; + enum_name = Some(ty); + } else { + return Err(Error::new( + field.span(), + format!("Invalid enum_name {:?}", nv.lit), + )); + } } } } @@ -125,7 +147,7 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res ignore = true; break; } else if name == "primary_key" { - primary_keys.push(quote! { #field_name }); + is_primary_key = true; primary_key_types.push(field.ty.clone()); } else if name == "nullable" { nullable = true; @@ -142,9 +164,27 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res } } + if let Some(enum_name) = enum_name { + field_name = enum_name; + } + if ignore { - columns_enum.pop(); continue; + } else { + let variant_attrs = match &column_name { + Some(column_name) => quote! { + #[sea_orm(column_name = #column_name)] + }, + None => quote! {}, + }; + columns_enum.push(quote! { + #variant_attrs + #field_name + }); + } + + if is_primary_key { + primary_keys.push(quote! { #field_name }); } let field_type = match sql_type { diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index 9d619991..a43b487f 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -3,7 +3,7 @@ use heck::CamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use std::iter::FromIterator; -use syn::Ident; +use syn::{punctuated::Punctuated, token::Comma, Ident, Lit, Meta}; enum Error { InputNotStruct, @@ -43,10 +43,35 @@ impl DeriveModel { let column_idents = fields .iter() .map(|field| { - format_ident!( + let mut ident = format_ident!( "{}", field.ident.as_ref().unwrap().to_string().to_camel_case() - ) + ); + for attr in field.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = + attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "enum_name" { + if let Lit::Str(litstr) = &nv.lit { + ident = syn::parse_str(&litstr.value()).unwrap(); + } + } + } + } + } + } + } + ident }) .collect(); diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index 629c5c18..c3f8a50f 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -46,7 +46,7 @@ pub fn derive_primary_key(input: TokenStream) -> TokenStream { } } -#[proc_macro_derive(DeriveColumn)] +#[proc_macro_derive(DeriveColumn, attributes(sea_orm))] pub fn derive_column(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input);