diff --git a/sea-orm-macros/src/attributes.rs b/sea-orm-macros/src/attributes.rs index 39cccaa8..ea9ae9ba 100644 --- a/sea-orm-macros/src/attributes.rs +++ b/sea-orm-macros/src/attributes.rs @@ -18,15 +18,8 @@ pub mod field_attr { #[derive(Default, FromAttributes)] pub struct SeaOrm { - pub auto_increment: Option, pub belongs_to: Option, - pub column_type: Option, - pub column_type_raw: Option, pub from: Option, - pub indexed: Option<()>, - pub null: Option<()>, - pub primary_key: Option<()>, pub to: Option, - pub unique: Option<()>, } } diff --git a/sea-orm-macros/src/derives/mod.rs b/sea-orm-macros/src/derives/mod.rs index 4a62e7de..8dfb7ac1 100644 --- a/sea-orm-macros/src/derives/mod.rs +++ b/sea-orm-macros/src/derives/mod.rs @@ -6,6 +6,7 @@ mod entity_model; mod from_query_result; mod model; mod primary_key; +mod relation; pub use active_model::*; pub use active_model_behavior::*; @@ -15,3 +16,4 @@ pub use entity_model::*; pub use from_query_result::*; pub use model::*; pub use primary_key::*; +pub use relation::*; diff --git a/sea-orm-macros/src/derives/relation.rs b/sea-orm-macros/src/derives/relation.rs new file mode 100644 index 00000000..0f69c1de --- /dev/null +++ b/sea-orm-macros/src/derives/relation.rs @@ -0,0 +1,117 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, quote_spanned}; + +use crate::attributes::{derive_attr, field_attr}; + +enum Error { + InputNotEnum, + Syn(syn::Error), +} + +struct DeriveRelation { + entity_ident: syn::Ident, + ident: syn::Ident, + variants: syn::punctuated::Punctuated, +} + +impl DeriveRelation { + fn new(input: syn::DeriveInput) -> Result { + let variants = match input.data { + syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, + _ => return Err(Error::InputNotEnum), + }; + + let sea_attr = derive_attr::SeaOrm::try_from_attributes(&input.attrs) + .map_err(Error::Syn)? + .unwrap_or_default(); + + let ident = input.ident; + let entity_ident = sea_attr.entity.unwrap_or_else(|| format_ident!("Entity")); + + Ok(DeriveRelation { + entity_ident, + ident, + variants, + }) + } + + fn expand(&self) -> syn::Result { + let expanded_impl_relation_trait = self.impl_relation_trait()?; + + Ok(expanded_impl_relation_trait) + } + + fn impl_relation_trait(&self) -> syn::Result { + let ident = &self.ident; + let entity_ident = &self.entity_ident; + let no_relation_def_msg = format!("No RelationDef for {}", ident); + + let variant_relation_defs: Vec = self + .variants + .iter() + .map(|variant| { + let variant_ident = &variant.ident; + let attr = field_attr::SeaOrm::from_attributes(&variant.attrs)?; + let belongs_to = attr + .belongs_to + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing attribute 'belongs_to'") + })??; + let from = attr + .from + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing attribute 'from'") + })??; + let to = attr + .to + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| syn::Error::new_spanned(variant, "Missing attribute 'to'"))??; + + Result::<_, syn::Error>::Ok(quote!( + Self::#variant_ident => #entity_ident::belongs_to(#belongs_to) + .from(#from) + .to(#to) + .into() + )) + }) + .collect::, _>>()?; + + Ok(quote!( + impl sea_orm::entity::RelationTrait for #ident { + fn def(&self) -> sea_orm::entity::RelationDef { + match self { + #( #variant_relation_defs, )* + _ => panic!(#no_relation_def_msg) + } + } + } + )) + } + + fn parse_lit_string(lit: &syn::Lit) -> syn::Result { + match lit { + syn::Lit::Str(lit_str) => lit_str + .value() + .parse() + .map_err(|_| syn::Error::new_spanned(lit, "attribute not valid")), + _ => Err(syn::Error::new_spanned(lit, "attribute must be a string")), + } + } +} + +pub fn expand_derive_relation(input: syn::DeriveInput) -> syn::Result { + let ident_span = input.ident.span(); + + match DeriveRelation::new(input) { + Ok(model) => model.expand(), + Err(Error::InputNotEnum) => Ok(quote_spanned! { + ident_span => compile_error!("you can only derive DeriveRelation on enums"); + }), + Err(Error::Syn(err)) => Err(err), + } +} diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index 22b66661..f2103647 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -98,6 +98,14 @@ pub fn derive_from_query_result(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(DeriveRelation, attributes(sea_orm))] +pub fn derive_relation(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + derives::expand_derive_relation(input) + .unwrap_or_else(Error::into_compile_error) + .into() +} + #[doc(hidden)] #[proc_macro_attribute] pub fn test(_: TokenStream, input: TokenStream) -> TokenStream {