From 75cb269ceba69064ac68d66faf2b7edc34c8918f Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 29 Jun 2022 00:27:55 +0800 Subject: [PATCH] Custom join on conditions (#793) * Custom join on conditions * Try lambda approach * Implement debug for relation * Add example without `rev` * Add more examples * Docs --- sea-orm-macros/src/attributes.rs | 1 + sea-orm-macros/src/derives/relation.rs | 11 ++ src/entity/link.rs | 22 ++-- src/entity/relation.rs | 116 ++++++++++++++++- src/query/helper.rs | 14 ++- src/query/join.rs | 167 +++++++++++++++++++++++-- src/tests_cfg/cake.rs | 5 + src/tests_cfg/entity_linked.rs | 48 +++++++ 8 files changed, 364 insertions(+), 20 deletions(-) diff --git a/sea-orm-macros/src/attributes.rs b/sea-orm-macros/src/attributes.rs index 3b7ed47b..6b545f68 100644 --- a/sea-orm-macros/src/attributes.rs +++ b/sea-orm-macros/src/attributes.rs @@ -26,6 +26,7 @@ pub mod field_attr { pub has_many: Option, pub on_update: Option, pub on_delete: Option, + pub on_condition: Option, pub from: Option, pub to: Option, pub fk_name: Option, diff --git a/sea-orm-macros/src/derives/relation.rs b/sea-orm-macros/src/derives/relation.rs index 4b6d7d61..e86455e8 100644 --- a/sea-orm-macros/src/derives/relation.rs +++ b/sea-orm-macros/src/derives/relation.rs @@ -136,6 +136,17 @@ impl DeriveRelation { result = quote! { #result.on_delete(sea_orm::prelude::ForeignKeyAction::#on_delete) }; } + if attr.on_condition.is_some() { + let on_condition = attr + .on_condition + .as_ref() + .map(Self::parse_lit_string) + .ok_or_else(|| { + syn::Error::new_spanned(variant, "Missing value for 'on_condition'") + })??; + result = quote! { #result.on_condition(|_, _| sea_orm::sea_query::IntoCondition::into_condition(#on_condition)) }; + } + if attr.fk_name.is_some() { let fk_name = attr .fk_name diff --git a/src/entity/link.rs b/src/entity/link.rs index 4d91b8a0..0ab60dec 100644 --- a/src/entity/link.rs +++ b/src/entity/link.rs @@ -1,7 +1,7 @@ use crate::{ join_tbl_on_condition, unpack_table_ref, EntityTrait, QuerySelect, RelationDef, Select, }; -use sea_query::{Alias, IntoIden, JoinType, SeaRc}; +use sea_query::{Alias, Condition, IntoIden, JoinType, SeaRc}; /// Same as [RelationDef] pub type LinkDef = RelationDef; @@ -20,20 +20,28 @@ pub trait Linked { /// Find all the Entities that are linked to the Entity fn find_linked(&self) -> Select { let mut select = Select::new(); - for (i, rel) in self.link().into_iter().rev().enumerate() { + for (i, mut rel) in self.link().into_iter().rev().enumerate() { let from_tbl = Alias::new(&format!("r{}", i)).into_iden(); let to_tbl = if i > 0 { Alias::new(&format!("r{}", i - 1)).into_iden() } else { unpack_table_ref(&rel.to_tbl) }; + let table_ref = rel.from_tbl; - select.query().join_as( - JoinType::InnerJoin, - rel.from_tbl, + let mut condition = Condition::all().add(join_tbl_on_condition( SeaRc::clone(&from_tbl), - join_tbl_on_condition(from_tbl, to_tbl, rel.from_col, rel.to_col), - ); + SeaRc::clone(&to_tbl), + rel.from_col, + rel.to_col, + )); + if let Some(f) = rel.on_condition.take() { + condition = condition.add(f(SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl))); + } + + select + .query() + .join_as(JoinType::InnerJoin, table_ref, from_tbl, condition); } select } diff --git a/src/entity/relation.rs b/src/entity/relation.rs index d43c856e..f4723352 100644 --- a/src/entity/relation.rs +++ b/src/entity/relation.rs @@ -1,6 +1,6 @@ use crate::{EntityTrait, Identity, IdentityOf, Iterable, QuerySelect, Select}; use core::marker::PhantomData; -use sea_query::{JoinType, TableRef}; +use sea_query::{Alias, Condition, DynIden, JoinType, SeaRc, TableRef}; use std::fmt::Debug; /// Defines the type of relationship @@ -42,7 +42,6 @@ where } /// Defines a relationship -#[derive(Debug)] pub struct RelationDef { /// The type of relationship defined in [RelationType] pub rel_type: RelationType, @@ -62,12 +61,49 @@ pub struct RelationDef { /// Defines an operation to be performed on a Foreign Key when a /// `UPDATE` Operation is performed pub on_update: Option, + /// Custom join ON condition + pub on_condition: Option Condition>>, /// The name of foreign key constraint pub fk_name: Option, } +impl std::fmt::Debug for RelationDef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("RelationDef"); + d.field("rel_type", &self.rel_type) + .field("from_tbl", &self.from_tbl) + .field("to_tbl", &self.to_tbl) + .field("from_col", &self.from_col) + .field("to_col", &self.to_col) + .field("is_owner", &self.is_owner) + .field("on_delete", &self.on_delete) + .field("on_update", &self.on_update); + debug_on_condition(&mut d, &self.on_condition); + d.field("fk_name", &self.fk_name).finish() + } +} + +fn debug_on_condition( + d: &mut core::fmt::DebugStruct<'_, '_>, + on_condition: &Option Condition>>, +) { + match on_condition { + Some(func) => { + d.field( + "on_condition", + &func( + SeaRc::new(Alias::new("left")), + SeaRc::new(Alias::new("right")), + ), + ); + } + None => { + d.field("on_condition", &Option::::None); + } + } +} + /// Defines a helper to build a relation -#[derive(Debug)] pub struct RelationBuilder where E: EntityTrait, @@ -82,9 +118,31 @@ where is_owner: bool, on_delete: Option, on_update: Option, + on_condition: Option Condition>>, fk_name: Option, } +impl std::fmt::Debug for RelationBuilder +where + E: EntityTrait, + R: EntityTrait, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("RelationBuilder"); + d.field("entities", &self.entities) + .field("rel_type", &self.rel_type) + .field("from_tbl", &self.from_tbl) + .field("to_tbl", &self.to_tbl) + .field("from_col", &self.from_col) + .field("to_col", &self.to_col) + .field("is_owner", &self.is_owner) + .field("on_delete", &self.on_delete) + .field("on_update", &self.on_update); + debug_on_condition(&mut d, &self.on_condition); + d.field("fk_name", &self.fk_name).finish() + } +} + impl RelationDef { /// Reverse this relation (swap from and to) pub fn rev(self) -> Self { @@ -97,9 +155,46 @@ impl RelationDef { is_owner: !self.is_owner, on_delete: self.on_delete, on_update: self.on_update, + on_condition: self.on_condition, fk_name: None, } } + + /// Set custom join ON condition. + /// + /// This method takes a closure with parameters + /// denoting the left-hand side and right-hand side table in the join expression. + /// + /// # Examples + /// + /// assert_eq!( + /// cake::Entity::find() + /// .join( + /// JoinType::LeftJoin, + /// cake_filling::Relation::Cake + /// .def() + /// .rev() + /// .on_condition(|_left, right| { + /// Expr::tbl(right, cake_filling::Column::CakeId) + /// .gt(10) + /// .into_condition() + /// }) + /// ) + /// .build(DbBackend::MySql) + /// .to_string(), + /// [ + /// "SELECT `cake`.`id`, `cake`.`name` FROM `cake`", + /// "LEFT JOIN `cake_filling` ON `cake`.`id` = `cake_filling`.`cake_id` AND `cake_filling`.`cake_id` > 10", + /// ] + /// .join(" ") + /// ); + pub fn on_condition(mut self, f: F) -> Self + where + F: Fn(DynIden, DynIden) -> Condition + 'static, + { + self.on_condition = Some(Box::new(f)); + self + } } impl RelationBuilder @@ -118,6 +213,7 @@ where is_owner, on_delete: None, on_update: None, + on_condition: None, fk_name: None, } } @@ -133,6 +229,7 @@ where is_owner, on_delete: None, on_update: None, + on_condition: None, fk_name: None, } } @@ -167,6 +264,18 @@ where self } + /// Set custom join ON condition. + /// + /// This method takes a closure with parameters + /// denoting the left-hand side and right-hand side table in the join expression. + pub fn on_condition(mut self, f: F) -> Self + where + F: Fn(DynIden, DynIden) -> Condition + 'static, + { + self.on_condition = Some(Box::new(f)); + self + } + /// Set the name of foreign key constraint pub fn fk_name(mut self, fk_name: &str) -> Self { self.fk_name = Some(fk_name.to_owned()); @@ -189,6 +298,7 @@ where is_owner: b.is_owner, on_delete: b.on_delete, on_update: b.on_update, + on_condition: b.on_condition, fk_name: b.fk_name, } } diff --git a/src/query/helper.rs b/src/query/helper.rs index 02a6d36e..e239f926 100644 --- a/src/query/helper.rs +++ b/src/query/helper.rs @@ -430,13 +430,23 @@ pub trait QueryFilter: Sized { } } -pub(crate) fn join_condition(rel: RelationDef) -> SimpleExpr { +pub(crate) fn join_condition(mut rel: RelationDef) -> Condition { let from_tbl = unpack_table_ref(&rel.from_tbl); let to_tbl = unpack_table_ref(&rel.to_tbl); let owner_keys = rel.from_col; let foreign_keys = rel.to_col; - join_tbl_on_condition(from_tbl, to_tbl, owner_keys, foreign_keys) + let mut condition = Condition::all().add(join_tbl_on_condition( + SeaRc::clone(&from_tbl), + SeaRc::clone(&to_tbl), + owner_keys, + foreign_keys, + )); + if let Some(f) = rel.on_condition.take() { + condition = condition.add(f(from_tbl, to_tbl)); + } + + condition } pub(crate) fn join_tbl_on_condition( diff --git a/src/query/join.rs b/src/query/join.rs index 1e6ac9de..f936e546 100644 --- a/src/query/join.rs +++ b/src/query/join.rs @@ -3,7 +3,7 @@ use crate::{ Linked, QuerySelect, Related, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, }; pub use sea_query::JoinType; -use sea_query::{Alias, DynIden, Expr, IntoIden, SeaRc, SelectExpr}; +use sea_query::{Alias, Condition, DynIden, Expr, IntoIden, SeaRc, SelectExpr}; impl Select where @@ -69,20 +69,27 @@ where T: EntityTrait, { let mut slf = self; - for (i, rel) in l.link().into_iter().enumerate() { + for (i, mut rel) in l.link().into_iter().enumerate() { let to_tbl = Alias::new(&format!("r{}", i)).into_iden(); let from_tbl = if i > 0 { Alias::new(&format!("r{}", i - 1)).into_iden() } else { unpack_table_ref(&rel.from_tbl) }; + let table_ref = rel.to_tbl; - slf.query().join_as( - JoinType::LeftJoin, - rel.to_tbl, + let mut condition = Condition::all().add(join_tbl_on_condition( + SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl), - join_tbl_on_condition(from_tbl, to_tbl, rel.from_col, rel.to_col), - ); + rel.from_col, + rel.to_col, + )); + if let Some(f) = rel.on_condition.take() { + condition = condition.add(f(SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl))); + } + + slf.query() + .join_as(JoinType::LeftJoin, table_ref, to_tbl, condition); } slf = slf.apply_alias(SelectA.as_str()); let text_type = SeaRc::new(Alias::new("text")) as DynIden; @@ -112,8 +119,12 @@ where #[cfg(test)] mod tests { use crate::tests_cfg::{cake, cake_filling, cake_filling_price, entity_linked, filling, fruit}; - use crate::{ColumnTrait, DbBackend, EntityTrait, ModelTrait, QueryFilter, QueryTrait}; + use crate::{ + ColumnTrait, DbBackend, EntityTrait, ModelTrait, QueryFilter, QuerySelect, QueryTrait, + RelationTrait, + }; use pretty_assertions::assert_eq; + use sea_query::{Expr, IntoCondition, JoinType}; #[test] fn join_1() { @@ -355,4 +366,144 @@ mod tests { .join(" ") ); } + + #[test] + fn join_14() { + assert_eq!( + cake::Entity::find() + .join(JoinType::LeftJoin, cake::Relation::TropicalFruit.def()) + .build(DbBackend::MySql) + .to_string(), + [ + "SELECT `cake`.`id`, `cake`.`name` FROM `cake`", + "LEFT JOIN `fruit` ON `cake`.`id` = `fruit`.`cake_id` AND `fruit`.`name` LIKE '%tropical%'", + ] + .join(" ") + ); + } + + #[test] + fn join_15() { + let cake_model = cake::Model { + id: 18, + name: "".to_owned(), + }; + + assert_eq!( + cake_model + .find_linked(entity_linked::CheeseCakeToFillingVendor) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `vendor`.`id`, `vendor`.`name`"#, + r#"FROM `vendor`"#, + r#"INNER JOIN `filling` AS `r0` ON `r0`.`vendor_id` = `vendor`.`id`"#, + r#"INNER JOIN `cake_filling` AS `r1` ON `r1`.`filling_id` = `r0`.`id`"#, + r#"INNER JOIN `cake` AS `r2` ON `r2`.`id` = `r1`.`cake_id` AND `r2`.`name` LIKE '%cheese%'"#, + r#"WHERE `r2`.`id` = 18"#, + ] + .join(" ") + ); + } + + #[test] + fn join_16() { + let cake_model = cake::Model { + id: 18, + name: "".to_owned(), + }; + assert_eq!( + cake_model + .find_linked(entity_linked::JoinWithoutReverse) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `vendor`.`id`, `vendor`.`name`"#, + r#"FROM `vendor`"#, + r#"INNER JOIN `filling` AS `r0` ON `r0`.`vendor_id` = `vendor`.`id`"#, + r#"INNER JOIN `cake_filling` AS `r1` ON `r1`.`filling_id` = `r0`.`id`"#, + r#"INNER JOIN `cake_filling` AS `r2` ON `r2`.`cake_id` = `r1`.`id` AND `r2`.`name` LIKE '%cheese%'"#, + r#"WHERE `r2`.`id` = 18"#, + ] + .join(" ") + ); + } + + #[test] + fn join_17() { + assert_eq!( + cake::Entity::find() + .find_also_linked(entity_linked::CheeseCakeToFillingVendor) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `cake`.`id` AS `A_id`, `cake`.`name` AS `A_name`,"#, + r#"`r2`.`id` AS `B_id`, `r2`.`name` AS `B_name`"#, + r#"FROM `cake`"#, + r#"LEFT JOIN `cake_filling` AS `r0` ON `cake`.`id` = `r0`.`cake_id` AND `cake`.`name` LIKE '%cheese%'"#, + r#"LEFT JOIN `filling` AS `r1` ON `r0`.`filling_id` = `r1`.`id`"#, + r#"LEFT JOIN `vendor` AS `r2` ON `r1`.`vendor_id` = `r2`.`id`"#, + ] + .join(" ") + ); + } + + #[test] + fn join_18() { + assert_eq!( + cake::Entity::find() + .find_also_linked(entity_linked::JoinWithoutReverse) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `cake`.`id` AS `A_id`, `cake`.`name` AS `A_name`,"#, + r#"`r2`.`id` AS `B_id`, `r2`.`name` AS `B_name`"#, + r#"FROM `cake`"#, + r#"LEFT JOIN `cake` AS `r0` ON `cake_filling`.`cake_id` = `r0`.`id` AND `cake_filling`.`name` LIKE '%cheese%'"#, + r#"LEFT JOIN `filling` AS `r1` ON `r0`.`filling_id` = `r1`.`id`"#, + r#"LEFT JOIN `vendor` AS `r2` ON `r1`.`vendor_id` = `r2`.`id`"#, + ] + .join(" ") + ); + } + + #[test] + fn join_19() { + assert_eq!( + cake::Entity::find() + .join(JoinType::LeftJoin, cake::Relation::TropicalFruit.def()) + .join( + JoinType::LeftJoin, + cake_filling::Relation::Cake + .def() + .rev() + .on_condition(|_left, right| { + Expr::tbl(right, cake_filling::Column::CakeId) + .gt(10) + .into_condition() + }) + ) + .join( + JoinType::LeftJoin, + cake_filling::Relation::Filling + .def() + .on_condition(|_left, right| { + Expr::tbl(right, filling::Column::Name) + .like("%lemon%") + .into_condition() + }) + ) + .join(JoinType::LeftJoin, filling::Relation::Vendor.def()) + .build(DbBackend::MySql) + .to_string(), + [ + "SELECT `cake`.`id`, `cake`.`name` FROM `cake`", + "LEFT JOIN `fruit` ON `cake`.`id` = `fruit`.`cake_id` AND `fruit`.`name` LIKE '%tropical%'", + "LEFT JOIN `cake_filling` ON `cake`.`id` = `cake_filling`.`cake_id` AND `cake_filling`.`cake_id` > 10", + "LEFT JOIN `filling` ON `cake_filling`.`filling_id` = `filling`.`id` AND `filling`.`name` LIKE '%lemon%'", + "LEFT JOIN `vendor` ON `filling`.`vendor_id` = `vendor`.`id`", + ] + .join(" ") + ); + } } diff --git a/src/tests_cfg/cake.rs b/src/tests_cfg/cake.rs index eb01ed5c..875bd506 100644 --- a/src/tests_cfg/cake.rs +++ b/src/tests_cfg/cake.rs @@ -18,6 +18,11 @@ pub struct Model { pub enum Relation { #[sea_orm(has_many = "super::fruit::Entity")] Fruit, + #[sea_orm( + has_many = "super::fruit::Entity", + on_condition = r#"super::fruit::Column::Name.like("%tropical%")"# + )] + TropicalFruit, } impl Related for Entity { diff --git a/src/tests_cfg/entity_linked.rs b/src/tests_cfg/entity_linked.rs index a4057a6c..92da57c1 100644 --- a/src/tests_cfg/entity_linked.rs +++ b/src/tests_cfg/entity_linked.rs @@ -1,4 +1,5 @@ use crate::entity::prelude::*; +use sea_query::{Expr, IntoCondition}; #[derive(Debug)] pub struct CakeToFilling; @@ -32,3 +33,50 @@ impl Linked for CakeToFillingVendor { ] } } + +#[derive(Debug)] +pub struct CheeseCakeToFillingVendor; + +impl Linked for CheeseCakeToFillingVendor { + type FromEntity = super::cake::Entity; + + type ToEntity = super::vendor::Entity; + + fn link(&self) -> Vec { + vec![ + super::cake_filling::Relation::Cake + .def() + .on_condition(|left, _right| { + Expr::tbl(left, super::cake::Column::Name) + .like("%cheese%") + .into_condition() + }) + .rev(), + super::cake_filling::Relation::Filling.def(), + super::filling::Relation::Vendor.def(), + ] + } +} + +#[derive(Debug)] +pub struct JoinWithoutReverse; + +impl Linked for JoinWithoutReverse { + type FromEntity = super::cake::Entity; + + type ToEntity = super::vendor::Entity; + + fn link(&self) -> Vec { + vec![ + super::cake_filling::Relation::Cake + .def() + .on_condition(|left, _right| { + Expr::tbl(left, super::cake::Column::Name) + .like("%cheese%") + .into_condition() + }), + super::cake_filling::Relation::Filling.def(), + super::filling::Relation::Vendor.def(), + ] + } +}