diff --git a/sea-orm-macros/src/attributes.rs b/sea-orm-macros/src/attributes.rs index ea9ae9ba..bd893d77 100644 --- a/sea-orm-macros/src/attributes.rs +++ b/sea-orm-macros/src/attributes.rs @@ -19,6 +19,8 @@ pub mod field_attr { #[derive(Default, FromAttributes)] pub struct SeaOrm { pub belongs_to: Option, + pub has_one: Option, + pub has_many: Option, pub from: Option, pub to: Option, } diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index d7767133..eba2d33f 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -101,18 +101,19 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res } } else if name == "auto_increment" { if let Lit::Str(litstr) = &nv.lit { - auto_increment = - match litstr.value().as_str() { - "true" => true, - "false" => false, - _ => return Err(Error::new( + auto_increment = match litstr.value().as_str() { + "true" => true, + "false" => false, + _ => { + return Err(Error::new( field.span(), format!( "Invalid auto_increment = {}", litstr.value() ), - )), - }; + )) + } + }; } else { return Err(Error::new( field.span(), diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index 4c3b0a03..669dbe9e 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -1,8 +1,8 @@ -use std::iter::FromIterator; +use crate::attributes::derive_attr; use heck::CamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; -use crate::attributes::derive_attr; +use std::iter::FromIterator; enum Error { InputNotStruct, diff --git a/sea-orm-macros/src/derives/relation.rs b/sea-orm-macros/src/derives/relation.rs index 0f69c1de..290d02d0 100644 --- a/sea-orm-macros/src/derives/relation.rs +++ b/sea-orm-macros/src/derives/relation.rs @@ -52,32 +52,71 @@ impl DeriveRelation { .map(|variant| { let variant_ident = &variant.ident; let attr = field_attr::SeaOrm::from_attributes(&variant.attrs)?; - let belongs_to = attr - .belongs_to - .as_ref() - .map(Self::parse_lit_string) - .ok_or_else(|| { - syn::Error::new_spanned(variant, "Missing attribute 'belongs_to'") - })??; - let from = attr - .from - .as_ref() - .map(Self::parse_lit_string) - .ok_or_else(|| { - syn::Error::new_spanned(variant, "Missing attribute 'from'") - })??; - let to = attr - .to - .as_ref() - .map(Self::parse_lit_string) - .ok_or_else(|| syn::Error::new_spanned(variant, "Missing attribute 'to'"))??; + let mut relation_type = quote! { error }; + let related_to = if attr.belongs_to.is_some() { + relation_type = quote! { belongs_to }; + attr.belongs_to + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'belongs_to'") + }) + } else if attr.has_one.is_some() { + relation_type = quote! { has_one }; + attr.has_one + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'has_one'") + }) + } else if attr.has_many.is_some() { + relation_type = quote! { has_many }; + attr.has_many + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'has_many'") + }) + } else { + Err(syn::Error::new_spanned( + variant, + "Missing one of 'has_one', 'has_many' or 'belongs_to'", + )) + }??; - Result::<_, syn::Error>::Ok(quote!( - Self::#variant_ident => #entity_ident::belongs_to(#belongs_to) - .from(#from) - .to(#to) - .into() - )) + let mut result = quote!( + Self::#variant_ident => #entity_ident::#relation_type(#related_to) + ); + + if attr.from.is_some() { + let from = + attr.from + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'from'") + })??; + result = quote! { #result.from(#from) }; + } else if attr.belongs_to.is_some() { + return Err(syn::Error::new_spanned(variant, "Missing attribute 'from'")); + } + + if attr.to.is_some() { + let to = attr + .to + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'to'") + })??; + result = quote! { #result.to(#to) }; + } else if attr.belongs_to.is_some() { + return Err(syn::Error::new_spanned(variant, "Missing attribute 'to'")); + } + + result = quote! { #result.into() }; + + Result::<_, syn::Error>::Ok(result) }) .collect::, _>>()?; diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 98089970..8d87a4b2 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,9 +1,9 @@ pub use crate::{ error::*, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveModel, DerivePrimaryKey, EntityName, EntityTrait, EnumIter, - ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, - QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, + DeriveEntityModel, DeriveModel, DerivePrimaryKey, DeriveRelation, EntityName, EntityTrait, + EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, + PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "with-json")] diff --git a/src/lib.rs b/src/lib.rs index b2bbb019..d5162935 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -227,7 +227,7 @@ pub use query::*; pub use sea_orm_macros::{ DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveModel, DerivePrimaryKey, FromQueryResult, + DeriveEntityModel, DeriveModel, DerivePrimaryKey, DeriveRelation, FromQueryResult, }; pub use sea_query; diff --git a/src/tests_cfg/cake.rs b/src/tests_cfg/cake.rs index 74b37eaf..63030009 100644 --- a/src/tests_cfg/cake.rs +++ b/src/tests_cfg/cake.rs @@ -1,5 +1,5 @@ use crate as sea_orm; -use sea_orm::entity::prelude::*; +use crate::entity::prelude::*; #[derive(Clone, Debug, PartialEq, DeriveEntityModel, DeriveModel, DeriveActiveModel)] #[sea_orm(table_name = "cake")] @@ -9,19 +9,12 @@ pub struct Model { pub name: String, } -#[derive(Copy, Clone, Debug, EnumIter)] +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { + #[sea_orm(has_many = "super::fruit::Entity")] Fruit, } -impl RelationTrait for Relation { - fn def(&self) -> RelationDef { - match self { - Self::Fruit => Entity::has_many(super::fruit::Entity).into(), - } - } -} - impl Related for Entity { fn to() -> RelationDef { Relation::Fruit.def() diff --git a/src/tests_cfg/fruit.rs b/src/tests_cfg/fruit.rs index 0cbfb16f..102fe10e 100644 --- a/src/tests_cfg/fruit.rs +++ b/src/tests_cfg/fruit.rs @@ -32,8 +32,13 @@ impl PrimaryKeyTrait for PrimaryKey { } } -#[derive(Copy, Clone, Debug, EnumIter)] +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { + #[sea_orm( + belongs_to = "super::cake::Entity", + from = "Column::CakeId", + to = "super::cake::Column::Id" + )] Cake, } @@ -49,17 +54,6 @@ impl ColumnTrait for Column { } } -impl RelationTrait for Relation { - fn def(&self) -> RelationDef { - match self { - Self::Cake => Entity::belongs_to(super::cake::Entity) - .from(Column::CakeId) - .to(super::cake::Column::Id) - .into(), - } - } -} - impl Related for Entity { fn to() -> RelationDef { Relation::Cake.def()