Fix: generate relation for related entity of composite foreign key (#1693)

* Fix: generate relation for related entity of composite foreign key

* clippy
This commit is contained in:
Billy Chan 2023-09-22 02:28:49 +08:00 committed by GitHub
parent 59754f5ff9
commit 9d033d01a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 353 additions and 22 deletions

View File

@ -3,6 +3,7 @@ use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Ident, TokenStream}; use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use sea_query::{ForeignKeyAction, TableForeignKey}; use sea_query::{ForeignKeyAction, TableForeignKey};
use syn::{punctuated::Punctuated, token::Comma};
use crate::util::escape_rust_keyword; use crate::util::escape_rust_keyword;
@ -66,17 +67,27 @@ impl Relation {
} }
} }
RelationType::BelongsTo => { RelationType::BelongsTo => {
let column_camel_case = self.get_column_camel_case(); let map_src_column = |src_column: &Ident| {
let ref_column_camel_case = self.get_ref_column_camel_case(); quote! { Column::#src_column }
let to_col = if module_name.is_some() {
quote! { super::#module_name::Column::#ref_column_camel_case }
} else {
quote! { Column::#ref_column_camel_case }
}; };
let map_ref_column = |ref_column: &Ident| {
if module_name.is_some() {
quote! { super::#module_name::Column::#ref_column }
} else {
quote! { Column::#ref_column }
}
};
let map_punctuated =
|punctuated: Punctuated<TokenStream, Comma>| match punctuated.len() {
0..=1 => quote! { #punctuated },
_ => quote! { (#punctuated) },
};
let (from, to) =
self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
quote! { quote! {
Entity::#rel_type(#ref_entity) Entity::#rel_type(#ref_entity)
.from(Column::#column_camel_case) .from(#from)
.to(#to_col) .to(#to)
.into() .into()
} }
} }
@ -98,10 +109,20 @@ impl Relation {
} }
} }
RelationType::BelongsTo => { RelationType::BelongsTo => {
let column_camel_case = self.get_column_camel_case(); let map_src_column = |src_column: &Ident| format!("Column::{src_column}");
let ref_column_camel_case = self.get_ref_column_camel_case(); let map_ref_column =
let from = format!("Column::{column_camel_case}"); |ref_column: &Ident| format!("{module_name}Column::{ref_column}");
let to = format!("{module_name}Column::{ref_column_camel_case}"); let map_punctuated = |punctuated: Vec<String>| {
let len = punctuated.len();
let punctuated = punctuated.join(", ");
match len {
0..=1 => punctuated,
_ => format!("({})", punctuated),
}
};
let (from, to) =
self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
let on_update = if let Some(action) = &self.on_update { let on_update = if let Some(action) = &self.on_update {
let action = Self::get_foreign_key_action(action); let action = Self::get_foreign_key_action(action);
quote! { quote! {
@ -139,12 +160,18 @@ impl Relation {
} }
} }
pub fn get_column_camel_case(&self) -> Ident { pub fn get_column_camel_case(&self) -> Vec<Ident> {
format_ident!("{}", self.columns[0].to_upper_camel_case()) self.columns
.iter()
.map(|col| format_ident!("{}", col.to_upper_camel_case()))
.collect()
} }
pub fn get_ref_column_camel_case(&self) -> Ident { pub fn get_ref_column_camel_case(&self) -> Vec<Ident> {
format_ident!("{}", self.ref_columns[0].to_upper_camel_case()) self.ref_columns
.iter()
.map(|col| format_ident!("{}", col.to_upper_camel_case()))
.collect()
} }
pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String { pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String {
@ -157,6 +184,36 @@ impl Relation {
} }
.to_owned() .to_owned()
} }
pub fn get_src_ref_columns<F1, F2, F3, T, I>(
&self,
map_src_column: F1,
map_ref_column: F2,
map_punctuated: F3,
) -> (T, T)
where
F1: Fn(&Ident) -> T,
F2: Fn(&Ident) -> T,
F3: Fn(I) -> T,
I: Extend<T> + Default,
{
let from: I =
self.get_column_camel_case()
.iter()
.fold(I::default(), |mut acc, src_column| {
acc.extend([map_src_column(src_column)]);
acc
});
let to: I =
self.get_ref_column_camel_case()
.iter()
.fold(I::default(), |mut acc, ref_column| {
acc.extend([map_ref_column(ref_column)]);
acc
});
(map_punctuated(from), map_punctuated(to))
}
} }
impl From<&TableForeignKey> for Relation { impl From<&TableForeignKey> for Relation {
@ -278,7 +335,7 @@ mod tests {
let relations = setup(); let relations = setup();
let cols = vec!["Id", "FillingId", "FillingId"]; let cols = vec!["Id", "FillingId", "FillingId"];
for (rel, col) in relations.into_iter().zip(cols) { for (rel, col) in relations.into_iter().zip(cols) {
assert_eq!(rel.get_column_camel_case(), col); assert_eq!(rel.get_column_camel_case(), [col]);
} }
} }
@ -287,7 +344,7 @@ mod tests {
let relations = setup(); let relations = setup();
let ref_cols = vec!["CakeId", "Id", "Id"]; let ref_cols = vec!["CakeId", "Id", "Id"];
for (rel, ref_col) in relations.into_iter().zip(ref_cols) { for (rel, ref_col) in relations.into_iter().zip(ref_cols) {
assert_eq!(rel.get_ref_column_camel_case(), ref_col); assert_eq!(rel.get_ref_column_camel_case(), [ref_col]);
} }
} }
} }

View File

@ -908,6 +908,52 @@ mod tests {
}, },
], ],
}, },
Entity {
table_name: "cake_filling_price".to_owned(),
columns: vec![
Column {
name: "cake_id".to_owned(),
col_type: ColumnType::Integer,
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "filling_id".to_owned(),
col_type: ColumnType::Integer,
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "price".to_owned(),
col_type: ColumnType::Decimal(None),
auto_increment: false,
not_null: true,
unique: false,
},
],
relations: vec![Relation {
ref_table: "cake_filling".to_owned(),
columns: vec!["cake_id".to_owned(), "filling_id".to_owned()],
ref_columns: vec!["cake_id".to_owned(), "filling_id".to_owned()],
rel_type: RelationType::BelongsTo,
on_delete: None,
on_update: None,
self_referencing: false,
num_suffix: 0,
impl_related: true,
}],
conjunct_relations: vec![],
primary_keys: vec![
PrimaryKey {
name: "cake_id".to_owned(),
},
PrimaryKey {
name: "filling_id".to_owned(),
},
],
},
Entity { Entity {
table_name: "filling".to_owned(), table_name: "filling".to_owned(),
columns: vec![ columns: vec![
@ -1361,9 +1407,10 @@ mod tests {
#[test] #[test]
fn test_gen_expanded_code_blocks() -> io::Result<()> { fn test_gen_expanded_code_blocks() -> io::Result<()> {
let entities = setup(); let entities = setup();
const ENTITY_FILES: [&str; 10] = [ const ENTITY_FILES: [&str; 11] = [
include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake.rs"),
include_str!("../../tests/expanded/cake_filling.rs"), include_str!("../../tests/expanded/cake_filling.rs"),
include_str!("../../tests/expanded/cake_filling_price.rs"),
include_str!("../../tests/expanded/filling.rs"), include_str!("../../tests/expanded/filling.rs"),
include_str!("../../tests/expanded/fruit.rs"), include_str!("../../tests/expanded/fruit.rs"),
include_str!("../../tests/expanded/vendor.rs"), include_str!("../../tests/expanded/vendor.rs"),
@ -1373,9 +1420,10 @@ mod tests {
include_str!("../../tests/expanded/collection.rs"), include_str!("../../tests/expanded/collection.rs"),
include_str!("../../tests/expanded/collection_float.rs"), include_str!("../../tests/expanded/collection_float.rs"),
]; ];
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 10] = [ const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 11] = [
include_str!("../../tests/expanded_with_schema_name/cake.rs"), include_str!("../../tests/expanded_with_schema_name/cake.rs"),
include_str!("../../tests/expanded_with_schema_name/cake_filling.rs"), include_str!("../../tests/expanded_with_schema_name/cake_filling.rs"),
include_str!("../../tests/expanded_with_schema_name/cake_filling_price.rs"),
include_str!("../../tests/expanded_with_schema_name/filling.rs"), include_str!("../../tests/expanded_with_schema_name/filling.rs"),
include_str!("../../tests/expanded_with_schema_name/fruit.rs"), include_str!("../../tests/expanded_with_schema_name/fruit.rs"),
include_str!("../../tests/expanded_with_schema_name/vendor.rs"), include_str!("../../tests/expanded_with_schema_name/vendor.rs"),
@ -1460,9 +1508,10 @@ mod tests {
#[test] #[test]
fn test_gen_compact_code_blocks() -> io::Result<()> { fn test_gen_compact_code_blocks() -> io::Result<()> {
let entities = setup(); let entities = setup();
const ENTITY_FILES: [&str; 10] = [ const ENTITY_FILES: [&str; 11] = [
include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake.rs"),
include_str!("../../tests/compact/cake_filling.rs"), include_str!("../../tests/compact/cake_filling.rs"),
include_str!("../../tests/compact/cake_filling_price.rs"),
include_str!("../../tests/compact/filling.rs"), include_str!("../../tests/compact/filling.rs"),
include_str!("../../tests/compact/fruit.rs"), include_str!("../../tests/compact/fruit.rs"),
include_str!("../../tests/compact/vendor.rs"), include_str!("../../tests/compact/vendor.rs"),
@ -1472,9 +1521,10 @@ mod tests {
include_str!("../../tests/compact/collection.rs"), include_str!("../../tests/compact/collection.rs"),
include_str!("../../tests/compact/collection_float.rs"), include_str!("../../tests/compact/collection_float.rs"),
]; ];
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 10] = [ const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 11] = [
include_str!("../../tests/compact_with_schema_name/cake.rs"), include_str!("../../tests/compact_with_schema_name/cake.rs"),
include_str!("../../tests/compact_with_schema_name/cake_filling.rs"), include_str!("../../tests/compact_with_schema_name/cake_filling.rs"),
include_str!("../../tests/compact_with_schema_name/cake_filling_price.rs"),
include_str!("../../tests/compact_with_schema_name/filling.rs"), include_str!("../../tests/compact_with_schema_name/filling.rs"),
include_str!("../../tests/compact_with_schema_name/fruit.rs"), include_str!("../../tests/compact_with_schema_name/fruit.rs"),
include_str!("../../tests/compact_with_schema_name/vendor.rs"), include_str!("../../tests/compact_with_schema_name/vendor.rs"),

View File

@ -0,0 +1,31 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
#[sea_orm(table_name = "cake_filling_price")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub cake_id: i32,
#[sea_orm(primary_key, auto_increment = false)]
pub filling_id: i32,
pub price: Decimal,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::cake_filling::Entity",
from = "(Column::CakeId, Column::FillingId)",
to = "(super::cake_filling::Column::CakeId, super::cake_filling::Column::FillingId)",
)]
CakeFilling,
}
impl Related<super::cake_filling::Entity> for Entity {
fn to() -> RelationDef {
Relation::CakeFilling.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,31 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
#[sea_orm(schema_name = "schema_name", table_name = "cake_filling_price")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub cake_id: i32,
#[sea_orm(primary_key, auto_increment = false)]
pub filling_id: i32,
pub price: Decimal,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::cake_filling::Entity",
from = "(Column::CakeId, Column::FillingId)",
to = "(super::cake_filling::Column::CakeId, super::cake_filling::Column::FillingId)",
)]
CakeFilling,
}
impl Related<super::cake_filling::Entity> for Entity {
fn to() -> RelationDef {
Relation::CakeFilling.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,79 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
impl EntityName for Entity {
fn table_name(&self) -> &str {
"cake_filling_price"
}
}
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel, Eq)]
pub struct Model {
pub cake_id: i32,
pub filling_id: i32,
pub price: Decimal,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
pub enum Column {
CakeId,
FillingId,
Price,
}
#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]
pub enum PrimaryKey {
CakeId,
FillingId,
}
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = (i32, i32);
fn auto_increment() -> bool {
false
}
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {
CakeFilling,
}
impl ColumnTrait for Column {
type EntityName = Entity;
fn def(&self) -> ColumnDef {
match self {
Self::CakeId => ColumnType::Integer.def(),
Self::FillingId => ColumnType::Integer.def(),
Self::Price => ColumnType::Decimal(None).def(),
}
}
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::CakeFilling => Entity::belongs_to(super::cake_filling::Entity)
.from((Column::CakeId, Column::FillingId))
.to((
super::cake_filling::Column::CakeId,
super::cake_filling::Column::FillingId
))
.into(),
}
}
}
impl Related<super::cake_filling::Entity> for Entity {
fn to() -> RelationDef {
Relation::CakeFilling.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,83 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
impl EntityName for Entity {
fn schema_name(&self) -> Option< &str > {
Some("schema_name")
}
fn table_name(&self) -> &str {
"cake_filling_price"
}
}
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel, Eq)]
pub struct Model {
pub cake_id: i32,
pub filling_id: i32,
pub price: Decimal,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
pub enum Column {
CakeId,
FillingId,
Price,
}
#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]
pub enum PrimaryKey {
CakeId,
FillingId,
}
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = (i32, i32);
fn auto_increment() -> bool {
false
}
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {
CakeFilling,
}
impl ColumnTrait for Column {
type EntityName = Entity;
fn def(&self) -> ColumnDef {
match self {
Self::CakeId => ColumnType::Integer.def(),
Self::FillingId => ColumnType::Integer.def(),
Self::Price => ColumnType::Decimal(None).def(),
}
}
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::CakeFilling => Entity::belongs_to(super::cake_filling::Entity)
.from((Column::CakeId, Column::FillingId))
.to((
super::cake_filling::Column::CakeId,
super::cake_filling::Column::FillingId
))
.into(),
}
}
}
impl Related<super::cake_filling::Entity> for Entity {
fn to() -> RelationDef {
Relation::CakeFilling.def()
}
}
impl ActiveModelBehavior for ActiveModel {}