diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 0c9813f7..79caaf0b 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -293,7 +293,7 @@ pub trait EntityTrait: EntityName { /// assert_eq!( /// db.into_transaction_log(), /// vec![Transaction::from_sql_and_values( - /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1)"#, vec!["Apple Pie".into()] + /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1) RETURNING "id""#, vec!["Apple Pie".into()] /// )]); /// ``` fn insert(model: A) -> Insert @@ -346,7 +346,7 @@ pub trait EntityTrait: EntityName { /// assert_eq!( /// db.into_transaction_log(), /// vec![Transaction::from_sql_and_values( - /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1), ($2)"#, + /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1), ($2) RETURNING "id""#, /// vec!["Apple Pie".into(), "Orange Scone".into()] /// )]); /// ``` diff --git a/src/entity/link.rs b/src/entity/link.rs index 97c9af1b..d0c262d7 100644 --- a/src/entity/link.rs +++ b/src/entity/link.rs @@ -1,5 +1,7 @@ -use crate::{EntityTrait, QuerySelect, RelationDef, Select}; -use sea_query::JoinType; +use crate::{ + join_tbl_on_condition, unpack_table_ref, EntityTrait, QuerySelect, RelationDef, Select, +}; +use sea_query::{Alias, IntoIden, JoinType, SeaRc}; pub type LinkDef = RelationDef; @@ -12,8 +14,20 @@ pub trait Linked { fn find_linked(&self) -> Select { let mut select = Select::new(); - for rel in self.link().into_iter().rev() { - select = select.join_rev(JoinType::InnerJoin, rel); + for (i, rel) in self.link().into_iter().rev().enumerate() { + let from_tbl = Alias::new(&format!("r{}", i)).into_iden(); + let to_tbl = if i > 0 { + Alias::new(&format!("r{}", i - 1)).into_iden() + } else { + unpack_table_ref(&rel.to_tbl) + }; + + select.query().join_as( + JoinType::InnerJoin, + unpack_table_ref(&rel.from_tbl), + SeaRc::clone(&from_tbl), + join_tbl_on_condition(from_tbl, to_tbl, rel.from_col, rel.to_col), + ); } select } diff --git a/src/entity/model.rs b/src/entity/model.rs index c6129ad7..ce4a3a4c 100644 --- a/src/entity/model.rs +++ b/src/entity/model.rs @@ -24,7 +24,8 @@ pub trait ModelTrait: Clone + Send + Debug { where L: Linked, { - l.find_linked().belongs_to(self) + let tbl_alias = &format!("r{}", l.link().len() - 1); + l.find_linked().belongs_to_tbl_alias(self, tbl_alias) } } diff --git a/src/executor/insert.rs b/src/executor/insert.rs index aa8905be..630ee6a7 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, PrimaryKeyTrait, + error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, }; use sea_query::InsertStatement; @@ -38,8 +38,7 @@ where // TODO: extract primary key's value from query // so that self is dropped before entering await let mut query = self.query; - #[cfg(feature = "sqlx-postgres")] - if let DatabaseConnection::SqlxPostgresPoolConnection(_) = db { + if db.get_database_backend() == DbBackend::Postgres { use crate::{sea_query::Query, Iterable}; if ::PrimaryKey::iter().count() > 0 { query.returning( diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 24822111..608d9dc1 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{error::*, DatabaseConnection, SelectorTrait}; +use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -63,11 +63,8 @@ where Some(res) => res, None => return Ok(0), }; - let num_items = match self.db { - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => { - result.try_get::("", "num_items")? as usize - } + let num_items = match builder { + DbBackend::Postgres => result.try_get::("", "num_items")? as usize, _ => result.try_get::("", "num_items")? as usize, }; Ok(num_items) @@ -192,7 +189,7 @@ mod tests { (db, vec![page1, page2, page3]) } - fn setup_num_items() -> (DatabaseConnection, i32) { + fn setup_num_items() -> (DatabaseConnection, i64) { let num_items = 3; let db = MockDatabase::new(DbBackend::Postgres) .append_query_results(vec![vec![maplit::btreemap! { diff --git a/src/query/helper.rs b/src/query/helper.rs index d93595f1..e6df3637 100644 --- a/src/query/helper.rs +++ b/src/query/helper.rs @@ -2,10 +2,11 @@ use crate::{ ColumnTrait, EntityTrait, Identity, IntoIdentity, IntoSimpleExpr, Iterable, ModelTrait, PrimaryKeyToColumn, RelationDef, }; -pub use sea_query::{Condition, ConditionalStatement, DynIden, JoinType, Order, OrderedStatement}; use sea_query::{ - Expr, IntoCondition, LockType, SeaRc, SelectExpr, SelectStatement, SimpleExpr, TableRef, + Alias, Expr, Iden, IntoCondition, LockType, SeaRc, SelectExpr, SelectStatement, SimpleExpr, + TableRef, }; +pub use sea_query::{Condition, ConditionalStatement, DynIden, JoinType, Order, OrderedStatement}; // LINT: when the column does not appear in tables selected from // LINT: when there is a group by clause, but some columns don't have aggregate functions @@ -287,14 +288,35 @@ pub trait QueryFilter: Sized { } self } + + fn belongs_to_tbl_alias(mut self, model: &M, tbl_alias: &str) -> Self + where + M: ModelTrait, + { + for key in ::PrimaryKey::iter() { + let col = key.into_column(); + let expr = Expr::tbl(Alias::new(tbl_alias), col).eq(model.get(col)); + self = self.filter(expr); + } + self + } } -fn join_condition(rel: RelationDef) -> SimpleExpr { +pub(crate) fn join_condition(rel: RelationDef) -> SimpleExpr { let from_tbl = unpack_table_ref(&rel.from_tbl); let to_tbl = unpack_table_ref(&rel.to_tbl); let owner_keys = rel.from_col; let foreign_keys = rel.to_col; + join_tbl_on_condition(from_tbl, to_tbl, owner_keys, foreign_keys) +} + +pub(crate) fn join_tbl_on_condition( + from_tbl: SeaRc, + to_tbl: SeaRc, + owner_keys: Identity, + foreign_keys: Identity, +) -> SimpleExpr { match (owner_keys, foreign_keys) { (Identity::Unary(o1), Identity::Unary(f1)) => { Expr::tbl(SeaRc::clone(&from_tbl), o1).equals(SeaRc::clone(&to_tbl), f1) diff --git a/src/query/join.rs b/src/query/join.rs index 19dc3160..50592933 100644 --- a/src/query/join.rs +++ b/src/query/join.rs @@ -74,8 +74,9 @@ where #[cfg(test)] mod tests { - use crate::tests_cfg::{cake, cake_filling, cake_filling_price, filling, fruit}; + use crate::tests_cfg::{cake, cake_filling, cake_filling_price, entity_linked, filling, fruit}; use crate::{ColumnTrait, DbBackend, EntityTrait, ModelTrait, QueryFilter, QueryTrait}; + use pretty_assertions::assert_eq; #[test] fn join_1() { @@ -188,7 +189,7 @@ mod tests { assert_eq!( find_filling.build(DbBackend::MySql).to_string(), [ - "SELECT `filling`.`id`, `filling`.`name` FROM `filling`", + "SELECT `filling`.`id`, `filling`.`name`, `filling`.`vendor_id` FROM `filling`", "INNER JOIN `cake_filling` ON `cake_filling`.`filling_id` = `filling`.`id`", "INNER JOIN `cake` ON `cake`.`id` = `cake_filling`.`cake_id`", ] @@ -243,15 +244,15 @@ mod tests { assert_eq!( cake_model - .find_linked(cake::CakeToFilling) + .find_linked(entity_linked::CakeToFilling) .build(DbBackend::MySql) .to_string(), [ - r#"SELECT `filling`.`id`, `filling`.`name`"#, + r#"SELECT `filling`.`id`, `filling`.`name`, `filling`.`vendor_id`"#, r#"FROM `filling`"#, - r#"INNER JOIN `cake_filling` ON `cake_filling`.`filling_id` = `filling`.`id`"#, - r#"INNER JOIN `cake` ON `cake`.`id` = `cake_filling`.`cake_id`"#, - r#"WHERE `cake`.`id` = 12"#, + r#"INNER JOIN `cake_filling` AS `r0` ON `r0`.`filling_id` = `filling`.`id`"#, + r#"INNER JOIN `cake` AS `r1` ON `r1`.`id` = `r0`.`cake_id`"#, + r#"WHERE `r1`.`id` = 12"#, ] .join(" ") ); @@ -259,14 +260,38 @@ mod tests { #[test] fn join_11() { + let cake_model = cake::Model { + id: 18, + name: "".to_owned(), + }; + + assert_eq!( + cake_model + .find_linked(entity_linked::CakeToFillingVendor) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `vendor`.`id`, `vendor`.`name`"#, + r#"FROM `vendor`"#, + r#"INNER JOIN `filling` AS `r0` ON `r0`.`vendor_id` = `vendor`.`id`"#, + r#"INNER JOIN `cake_filling` AS `r1` ON `r1`.`filling_id` = `r0`.`id`"#, + r#"INNER JOIN `cake` AS `r2` ON `r2`.`id` = `r1`.`cake_id`"#, + r#"WHERE `r2`.`id` = 18"#, + ] + .join(" ") + ); + } + + #[test] + fn join_12() { assert_eq!( cake::Entity::find() - .find_also_linked(cake::CakeToFilling) + .find_also_linked(entity_linked::CakeToFilling) .build(DbBackend::MySql) .to_string(), [ r#"SELECT `cake`.`id` AS `A_id`, `cake`.`name` AS `A_name`,"#, - r#"`filling`.`id` AS `B_id`, `filling`.`name` AS `B_name`"#, + r#"`filling`.`id` AS `B_id`, `filling`.`name` AS `B_name`, `filling`.`vendor_id` AS `B_vendor_id`"#, r#"FROM `cake`"#, r#"LEFT JOIN `cake_filling` ON `cake`.`id` = `cake_filling`.`cake_id`"#, r#"LEFT JOIN `filling` ON `cake_filling`.`filling_id` = `filling`.`id`"#, @@ -274,4 +299,23 @@ mod tests { .join(" ") ); } + + #[test] + fn join_13() { + assert_eq!( + cake::Entity::find() + .find_also_linked(entity_linked::CakeToFillingVendor) + .build(DbBackend::MySql) + .to_string(), + [ + r#"SELECT `cake`.`id` AS `A_id`, `cake`.`name` AS `A_name`,"#, + r#"`vendor`.`id` AS `B_id`, `vendor`.`name` AS `B_name`"#, + r#"FROM `cake`"#, + r#"LEFT JOIN `cake_filling` ON `cake`.`id` = `cake_filling`.`cake_id`"#, + r#"LEFT JOIN `filling` ON `cake_filling`.`filling_id` = `filling`.`id`"#, + r#"LEFT JOIN `vendor` ON `filling`.`vendor_id` = `vendor`.`id`"#, + ] + .join(" ") + ); + } } diff --git a/src/tests_cfg/cake.rs b/src/tests_cfg/cake.rs index 920e1fea..8628492b 100644 --- a/src/tests_cfg/cake.rs +++ b/src/tests_cfg/cake.rs @@ -31,20 +31,4 @@ impl Related for Entity { } } -#[derive(Debug)] -pub struct CakeToFilling; - -impl Linked for CakeToFilling { - type FromEntity = Entity; - - type ToEntity = super::filling::Entity; - - fn link(&self) -> Vec { - vec![ - super::cake_filling::Relation::Cake.def().rev(), - super::cake_filling::Relation::Filling.def(), - ] - } -} - impl ActiveModelBehavior for ActiveModel {} diff --git a/src/tests_cfg/cake_expanded.rs b/src/tests_cfg/cake_expanded.rs index 0eeb0738..b4306313 100644 --- a/src/tests_cfg/cake_expanded.rs +++ b/src/tests_cfg/cake_expanded.rs @@ -75,20 +75,4 @@ impl Related for Entity { } } -#[derive(Debug)] -pub struct CakeToFilling; - -impl Linked for CakeToFilling { - type FromEntity = Entity; - - type ToEntity = super::filling::Entity; - - fn link(&self) -> Vec { - vec![ - super::cake_filling::Relation::Cake.def().rev(), - super::cake_filling::Relation::Filling.def(), - ] - } -} - impl ActiveModelBehavior for ActiveModel {} diff --git a/src/tests_cfg/entity_linked.rs b/src/tests_cfg/entity_linked.rs new file mode 100644 index 00000000..a4057a6c --- /dev/null +++ b/src/tests_cfg/entity_linked.rs @@ -0,0 +1,34 @@ +use crate::entity::prelude::*; + +#[derive(Debug)] +pub struct CakeToFilling; + +impl Linked for CakeToFilling { + type FromEntity = super::cake::Entity; + + type ToEntity = super::filling::Entity; + + fn link(&self) -> Vec { + vec![ + super::cake_filling::Relation::Cake.def().rev(), + super::cake_filling::Relation::Filling.def(), + ] + } +} + +#[derive(Debug)] +pub struct CakeToFillingVendor; + +impl Linked for CakeToFillingVendor { + type FromEntity = super::cake::Entity; + + type ToEntity = super::vendor::Entity; + + fn link(&self) -> Vec { + vec![ + super::cake_filling::Relation::Cake.def().rev(), + super::cake_filling::Relation::Filling.def(), + super::filling::Relation::Vendor.def(), + ] + } +} diff --git a/src/tests_cfg/filling.rs b/src/tests_cfg/filling.rs index 14f7a849..4de591c5 100644 --- a/src/tests_cfg/filling.rs +++ b/src/tests_cfg/filling.rs @@ -9,6 +9,7 @@ pub struct Entity; pub struct Model { pub id: i32, pub name: String, + pub vendor_id: Option, #[sea_orm(ignore)] pub ignored_attr: i32, } @@ -18,6 +19,7 @@ pub struct Model { pub enum Column { Id, Name, + VendorId, } // Then, customize each column names here. @@ -46,7 +48,9 @@ impl PrimaryKeyTrait for PrimaryKey { } #[derive(Copy, Clone, Debug, EnumIter)] -pub enum Relation {} +pub enum Relation { + Vendor, +} impl ColumnTrait for Column { type EntityName = Entity; @@ -55,13 +59,19 @@ impl ColumnTrait for Column { match self { Self::Id => ColumnType::Integer.def(), Self::Name => ColumnType::String(None).def(), + Self::VendorId => ColumnType::Integer.def().nullable(), } } } impl RelationTrait for Relation { fn def(&self) -> RelationDef { - panic!() + match self { + Self::Vendor => Entity::belongs_to(super::vendor::Entity) + .from(Column::VendorId) + .to(super::vendor::Column::Id) + .into(), + } } } diff --git a/src/tests_cfg/mod.rs b/src/tests_cfg/mod.rs index 81d553ac..6bc86aed 100644 --- a/src/tests_cfg/mod.rs +++ b/src/tests_cfg/mod.rs @@ -4,8 +4,10 @@ pub mod cake; pub mod cake_expanded; pub mod cake_filling; pub mod cake_filling_price; +pub mod entity_linked; pub mod filling; pub mod fruit; +pub mod vendor; pub use cake::Entity as Cake; pub use cake_expanded::Entity as CakeExpanded; @@ -13,3 +15,4 @@ pub use cake_filling::Entity as CakeFilling; pub use cake_filling_price::Entity as CakeFillingPrice; pub use filling::Entity as Filling; pub use fruit::Entity as Fruit; +pub use vendor::Entity as Vendor; diff --git a/src/tests_cfg/vendor.rs b/src/tests_cfg/vendor.rs new file mode 100644 index 00000000..12b6affb --- /dev/null +++ b/src/tests_cfg/vendor.rs @@ -0,0 +1,27 @@ +use crate as sea_orm; +use crate::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "vendor")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + panic!() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::filling::Relation::Vendor.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/relational_tests.rs b/tests/relational_tests.rs index caef9c3c..e3138bb3 100644 --- a/tests/relational_tests.rs +++ b/tests/relational_tests.rs @@ -708,6 +708,26 @@ pub async fn linked() -> Result<(), DbErr> { ] ); + let baker_bob = Baker::find() + .filter(baker::Column::Id.eq(1)) + .one(&ctx.db) + .await? + .unwrap(); + + let baker_bob_customers = baker_bob + .find_linked(baker::BakedForCustomer) + .all(&ctx.db) + .await?; + + assert_eq!( + baker_bob_customers, + vec![customer::Model { + id: 2, + name: "Kara".to_owned(), + notes: Some("Loves all cakes".to_owned()), + }] + ); + ctx.delete().await; Ok(())