diff --git a/sea-orm-macros/src/derives/column.rs b/sea-orm-macros/src/derives/column.rs index bf40f58c..034e966d 100644 --- a/sea-orm-macros/src/derives/column.rs +++ b/sea-orm-macros/src/derives/column.rs @@ -3,7 +3,7 @@ use proc_macro2::{Ident, TokenStream}; use quote::{quote, quote_spanned}; use syn::{Data, DataEnum, Fields, Variant}; -pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result { +pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result { let variants = match data { syn::Data::Enum(DataEnum { variants, .. }) => variants, _ => { @@ -30,13 +30,9 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result &str { + impl #ident { + fn default_as_str(&self) -> &str { match self { #(Self::#variant => #name),* } @@ -45,8 +41,26 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result syn::Result { +pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result { + let impl_iden = expand_derive_custom_column(ident, data)?; + Ok(quote!( + #impl_iden + + impl sea_orm::IdenStatic for #ident { + fn as_str(&self) -> &str { + self.default_as_str() + } + } + )) +} + +pub fn expand_derive_custom_column(ident: &Ident, data: &Data) -> syn::Result { + let impl_default_as_str = impl_default_as_str(ident, data)?; + + Ok(quote!( + #impl_default_as_str + impl sea_orm::Iden for #ident { fn unquoted(&self, s: &mut dyn std::fmt::Write) { write!(s, "{}", self.as_str()).unwrap(); diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index 7e3f4761..2ac0ac75 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -37,9 +37,9 @@ pub fn derive_column(input: TokenStream) -> TokenStream { #[proc_macro_derive(DeriveCustomColumn)] pub fn derive_custom_column(input: TokenStream) -> TokenStream { - let DeriveInput { ident, .. } = parse_macro_input!(input); + let DeriveInput { ident, data, .. } = parse_macro_input!(input); - match derives::expand_derive_custom_column(&ident) { + match derives::expand_derive_custom_column(&ident, &data) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/src/tests_cfg/filling.rs b/src/tests_cfg/filling.rs index 838f96b9..b439af7b 100644 --- a/src/tests_cfg/filling.rs +++ b/src/tests_cfg/filling.rs @@ -27,8 +27,10 @@ pub enum Column { impl IdenStatic for Column { fn as_str(&self) -> &str { match self { + // Override column names Self::Id => "id", - Self::Name => "name", + // Leave all other columns using default snake-case values + _ => self.default_as_str(), } } }