diff --git a/sea-orm-codegen/src/entity/relation.rs b/sea-orm-codegen/src/entity/relation.rs index 67fa77ad..8f3ef217 100644 --- a/sea-orm-codegen/src/entity/relation.rs +++ b/sea-orm-codegen/src/entity/relation.rs @@ -3,6 +3,7 @@ use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use sea_query::{ForeignKeyAction, TableForeignKey}; +use syn::{punctuated::Punctuated, token::Comma}; use crate::util::escape_rust_keyword; @@ -66,17 +67,27 @@ impl Relation { } } RelationType::BelongsTo => { - let column_camel_case = self.get_column_camel_case(); - let ref_column_camel_case = self.get_ref_column_camel_case(); - let to_col = if module_name.is_some() { - quote! { super::#module_name::Column::#ref_column_camel_case } - } else { - quote! { Column::#ref_column_camel_case } + let map_src_column = |src_column: &Ident| { + quote! { Column::#src_column } }; + let map_ref_column = |ref_column: &Ident| { + if module_name.is_some() { + quote! { super::#module_name::Column::#ref_column } + } else { + quote! { Column::#ref_column } + } + }; + let map_punctuated = + |punctuated: Punctuated| match punctuated.len() { + 0..=1 => quote! { #punctuated }, + _ => quote! { (#punctuated) }, + }; + let (from, to) = + self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated); quote! { Entity::#rel_type(#ref_entity) - .from(Column::#column_camel_case) - .to(#to_col) + .from(#from) + .to(#to) .into() } } @@ -98,10 +109,20 @@ impl Relation { } } RelationType::BelongsTo => { - let column_camel_case = self.get_column_camel_case(); - let ref_column_camel_case = self.get_ref_column_camel_case(); - let from = format!("Column::{column_camel_case}"); - let to = format!("{module_name}Column::{ref_column_camel_case}"); + let map_src_column = |src_column: &Ident| format!("Column::{src_column}"); + let map_ref_column = + |ref_column: &Ident| format!("{module_name}Column::{ref_column}"); + let map_punctuated = |punctuated: Vec| { + let len = punctuated.len(); + let punctuated = punctuated.join(", "); + match len { + 0..=1 => punctuated, + _ => format!("({})", punctuated), + } + }; + let (from, to) = + self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated); + let on_update = if let Some(action) = &self.on_update { let action = Self::get_foreign_key_action(action); quote! { @@ -139,12 +160,18 @@ impl Relation { } } - pub fn get_column_camel_case(&self) -> Ident { - format_ident!("{}", self.columns[0].to_upper_camel_case()) + pub fn get_column_camel_case(&self) -> Vec { + self.columns + .iter() + .map(|col| format_ident!("{}", col.to_upper_camel_case())) + .collect() } - pub fn get_ref_column_camel_case(&self) -> Ident { - format_ident!("{}", self.ref_columns[0].to_upper_camel_case()) + pub fn get_ref_column_camel_case(&self) -> Vec { + self.ref_columns + .iter() + .map(|col| format_ident!("{}", col.to_upper_camel_case())) + .collect() } pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String { @@ -157,6 +184,36 @@ impl Relation { } .to_owned() } + + pub fn get_src_ref_columns( + &self, + map_src_column: F1, + map_ref_column: F2, + map_punctuated: F3, + ) -> (T, T) + where + F1: Fn(&Ident) -> T, + F2: Fn(&Ident) -> T, + F3: Fn(I) -> T, + I: Extend + Default, + { + let from: I = + self.get_column_camel_case() + .iter() + .fold(I::default(), |mut acc, src_column| { + acc.extend([map_src_column(src_column)]); + acc + }); + let to: I = + self.get_ref_column_camel_case() + .iter() + .fold(I::default(), |mut acc, ref_column| { + acc.extend([map_ref_column(ref_column)]); + acc + }); + + (map_punctuated(from), map_punctuated(to)) + } } impl From<&TableForeignKey> for Relation { @@ -278,7 +335,7 @@ mod tests { let relations = setup(); let cols = vec!["Id", "FillingId", "FillingId"]; for (rel, col) in relations.into_iter().zip(cols) { - assert_eq!(rel.get_column_camel_case(), col); + assert_eq!(rel.get_column_camel_case(), [col]); } } @@ -287,7 +344,7 @@ mod tests { let relations = setup(); let ref_cols = vec!["CakeId", "Id", "Id"]; for (rel, ref_col) in relations.into_iter().zip(ref_cols) { - assert_eq!(rel.get_ref_column_camel_case(), ref_col); + assert_eq!(rel.get_ref_column_camel_case(), [ref_col]); } } } diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index 637e6eb2..dc60654f 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -908,6 +908,52 @@ mod tests { }, ], }, + Entity { + table_name: "cake_filling_price".to_owned(), + columns: vec![ + Column { + name: "cake_id".to_owned(), + col_type: ColumnType::Integer, + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "filling_id".to_owned(), + col_type: ColumnType::Integer, + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "price".to_owned(), + col_type: ColumnType::Decimal(None), + auto_increment: false, + not_null: true, + unique: false, + }, + ], + relations: vec![Relation { + ref_table: "cake_filling".to_owned(), + columns: vec!["cake_id".to_owned(), "filling_id".to_owned()], + ref_columns: vec!["cake_id".to_owned(), "filling_id".to_owned()], + rel_type: RelationType::BelongsTo, + on_delete: None, + on_update: None, + self_referencing: false, + num_suffix: 0, + impl_related: true, + }], + conjunct_relations: vec![], + primary_keys: vec![ + PrimaryKey { + name: "cake_id".to_owned(), + }, + PrimaryKey { + name: "filling_id".to_owned(), + }, + ], + }, Entity { table_name: "filling".to_owned(), columns: vec![ @@ -1361,9 +1407,10 @@ mod tests { #[test] fn test_gen_expanded_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 10] = [ + const ENTITY_FILES: [&str; 11] = [ include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake_filling.rs"), + include_str!("../../tests/expanded/cake_filling_price.rs"), include_str!("../../tests/expanded/filling.rs"), include_str!("../../tests/expanded/fruit.rs"), include_str!("../../tests/expanded/vendor.rs"), @@ -1373,9 +1420,10 @@ mod tests { include_str!("../../tests/expanded/collection.rs"), include_str!("../../tests/expanded/collection_float.rs"), ]; - const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 10] = [ + const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 11] = [ include_str!("../../tests/expanded_with_schema_name/cake.rs"), include_str!("../../tests/expanded_with_schema_name/cake_filling.rs"), + include_str!("../../tests/expanded_with_schema_name/cake_filling_price.rs"), include_str!("../../tests/expanded_with_schema_name/filling.rs"), include_str!("../../tests/expanded_with_schema_name/fruit.rs"), include_str!("../../tests/expanded_with_schema_name/vendor.rs"), @@ -1460,9 +1508,10 @@ mod tests { #[test] fn test_gen_compact_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 10] = [ + const ENTITY_FILES: [&str; 11] = [ include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake_filling.rs"), + include_str!("../../tests/compact/cake_filling_price.rs"), include_str!("../../tests/compact/filling.rs"), include_str!("../../tests/compact/fruit.rs"), include_str!("../../tests/compact/vendor.rs"), @@ -1472,9 +1521,10 @@ mod tests { include_str!("../../tests/compact/collection.rs"), include_str!("../../tests/compact/collection_float.rs"), ]; - const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 10] = [ + const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 11] = [ include_str!("../../tests/compact_with_schema_name/cake.rs"), include_str!("../../tests/compact_with_schema_name/cake_filling.rs"), + include_str!("../../tests/compact_with_schema_name/cake_filling_price.rs"), include_str!("../../tests/compact_with_schema_name/filling.rs"), include_str!("../../tests/compact_with_schema_name/fruit.rs"), include_str!("../../tests/compact_with_schema_name/vendor.rs"), diff --git a/sea-orm-codegen/tests/compact/cake_filling_price.rs b/sea-orm-codegen/tests/compact/cake_filling_price.rs new file mode 100644 index 00000000..4cbef160 --- /dev/null +++ b/sea-orm-codegen/tests/compact/cake_filling_price.rs @@ -0,0 +1,31 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "cake_filling_price")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub cake_id: i32, + #[sea_orm(primary_key, auto_increment = false)] + pub filling_id: i32, + pub price: Decimal, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::cake_filling::Entity", + from = "(Column::CakeId, Column::FillingId)", + to = "(super::cake_filling::Column::CakeId, super::cake_filling::Column::FillingId)", + )] + CakeFilling, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::CakeFilling.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} \ No newline at end of file diff --git a/sea-orm-codegen/tests/compact_with_schema_name/cake_filling_price.rs b/sea-orm-codegen/tests/compact_with_schema_name/cake_filling_price.rs new file mode 100644 index 00000000..6a75dd3f --- /dev/null +++ b/sea-orm-codegen/tests/compact_with_schema_name/cake_filling_price.rs @@ -0,0 +1,31 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(schema_name = "schema_name", table_name = "cake_filling_price")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub cake_id: i32, + #[sea_orm(primary_key, auto_increment = false)] + pub filling_id: i32, + pub price: Decimal, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::cake_filling::Entity", + from = "(Column::CakeId, Column::FillingId)", + to = "(super::cake_filling::Column::CakeId, super::cake_filling::Column::FillingId)", + )] + CakeFilling, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::CakeFilling.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} \ No newline at end of file diff --git a/sea-orm-codegen/tests/expanded/cake_filling_price.rs b/sea-orm-codegen/tests/expanded/cake_filling_price.rs new file mode 100644 index 00000000..c680c881 --- /dev/null +++ b/sea-orm-codegen/tests/expanded/cake_filling_price.rs @@ -0,0 +1,79 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "cake_filling_price" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel, Eq)] +pub struct Model { + pub cake_id: i32, + pub filling_id: i32, + pub price: Decimal, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + CakeId, + FillingId, + Price, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + CakeId, + FillingId, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + + fn auto_increment() -> bool { + false + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation { + CakeFilling, +} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::CakeId => ColumnType::Integer.def(), + Self::FillingId => ColumnType::Integer.def(), + Self::Price => ColumnType::Decimal(None).def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + Self::CakeFilling => Entity::belongs_to(super::cake_filling::Entity) + .from((Column::CakeId, Column::FillingId)) + .to(( + super::cake_filling::Column::CakeId, + super::cake_filling::Column::FillingId + )) + .into(), + } + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::CakeFilling.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/expanded_with_schema_name/cake_filling_price.rs b/sea-orm-codegen/tests/expanded_with_schema_name/cake_filling_price.rs new file mode 100644 index 00000000..8f613c9c --- /dev/null +++ b/sea-orm-codegen/tests/expanded_with_schema_name/cake_filling_price.rs @@ -0,0 +1,83 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn schema_name(&self) -> Option< &str > { + Some("schema_name") + } + + fn table_name(&self) -> &str { + "cake_filling_price" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel, Eq)] +pub struct Model { + pub cake_id: i32, + pub filling_id: i32, + pub price: Decimal, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + CakeId, + FillingId, + Price, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + CakeId, + FillingId, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + + fn auto_increment() -> bool { + false + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation { + CakeFilling, +} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::CakeId => ColumnType::Integer.def(), + Self::FillingId => ColumnType::Integer.def(), + Self::Price => ColumnType::Decimal(None).def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + Self::CakeFilling => Entity::belongs_to(super::cake_filling::Entity) + .from((Column::CakeId, Column::FillingId)) + .to(( + super::cake_filling::Column::CakeId, + super::cake_filling::Column::FillingId + )) + .into(), + } + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::CakeFilling.def() + } +} + +impl ActiveModelBehavior for ActiveModel {}