diff --git a/Cargo.toml b/Cargo.toml index fc5dffa0..a10d74ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ tracing = { version = "0.1", default-features = false, features = ["attributes", rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true } sea-orm-macros = { version = "0.12.0-rc.3", path = "sea-orm-macros", default-features = false, features = ["strum"] } -sea-query = { version = "0.29.0-rc.2", features = ["thread-safe"] } +sea-query = { version = "0.29.0-rc.2", features = ["thread-safe", "hashable-value"] } sea-query-binder = { version = "0.4.0-rc.2", default-features = false, optional = true } strum = { version = "0.24", default-features = false } serde = { version = "1.0", default-features = false } diff --git a/src/executor/select.rs b/src/executor/select.rs index 73815f1e..32f633f8 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -4,9 +4,9 @@ use crate::{ SelectB, SelectTwo, SelectTwoMany, Statement, StreamTrait, TryGetableMany, }; use futures::{Stream, TryStreamExt}; -use sea_query::SelectStatement; -use std::marker::PhantomData; -use std::pin::Pin; +use sea_query::{SelectStatement, Value}; +use std::collections::HashMap; +use std::{hash::Hash, marker::PhantomData, pin::Pin}; #[cfg(feature = "with-json")] use crate::JsonValue; @@ -993,6 +993,86 @@ where fn consolidate_query_result( rows: Vec<(L::Model, Option)>, ) -> Vec<(L::Model, Vec)> +where + L: EntityTrait, + R: EntityTrait, +{ + // This is a strong point to consider adding a trait associated constant + // to PrimaryKeyTrait to indicate the arity + let pkcol: Vec<_> = ::iter() + .map(|pk| pk.into_column()) + .collect(); + if pkcol.len() == 1 { + consolidate_query_result_of::>(rows, UnitPk(pkcol[0])) + } else { + consolidate_query_result_of::>(rows, TuplePk(pkcol)) + } +} + +trait ModelKey { + type Type: Hash + PartialEq + Eq; + fn get(&self, model: &E::Model) -> Self::Type; +} + +// This could have been an array of [E::Column; ::ARITY] +struct UnitPk(E::Column); +struct TuplePk(Vec); + +impl ModelKey for UnitPk { + type Type = Value; + fn get(&self, model: &E::Model) -> Self::Type { + model.get(self.0) + } +} + +impl ModelKey for TuplePk { + type Type = Vec; + fn get(&self, model: &E::Model) -> Self::Type { + let mut key = Vec::new(); + for col in self.0.iter() { + key.push(model.get(*col)); + } + key + } +} + +fn consolidate_query_result_of>( + mut rows: Vec<(L::Model, Option)>, + model_key: KEY, +) -> Vec<(L::Model, Vec)> +where + L: EntityTrait, + R: EntityTrait, +{ + let mut hashmap: HashMap> = + rows.iter_mut().fold(HashMap::new(), |mut acc, row| { + let key = model_key.get(&row.0); + if let Some(value) = row.1.take() { + let vec: Option<&mut Vec> = acc.get_mut(&key); + if let Some(vec) = vec { + vec.push(value) + } else { + acc.insert(key, vec![value]); + } + } + + acc + }); + + rows.into_iter() + .filter_map(|(l_model, _)| { + let l_pk = model_key.get(&l_model); + let r_models = hashmap.remove(&l_pk); + r_models.map(|r_models| (l_model, r_models)) + }) + .collect() +} + +/// This is the legacy consolidate algorithm. Kept for reference +#[allow(dead_code)] +fn consolidate_query_result_of_ordered_rows( + rows: Vec<(L::Model, Option)>, +) -> Vec<(L::Model, Vec)> where L: EntityTrait, R: EntityTrait, diff --git a/src/query/combine.rs b/src/query/combine.rs index 9ffb8e6e..61d957f7 100644 --- a/src/query/combine.rs +++ b/src/query/combine.rs @@ -119,12 +119,16 @@ where F: EntityTrait, { pub(crate) fn new(query: SelectStatement) -> Self { + Self::new_without_prepare(query) + .prepare_select() + .prepare_order_by() + } + + pub(crate) fn new_without_prepare(query: SelectStatement) -> Self { Self { query, entity: PhantomData, } - .prepare_select() - .prepare_order_by() } fn prepare_select(mut self) -> Self { diff --git a/src/query/join.rs b/src/query/join.rs index 09f2e385..6ed56de0 100644 --- a/src/query/join.rs +++ b/src/query/join.rs @@ -107,6 +107,52 @@ where } select_two } + + /// Left Join with a Linked Entity and select Entity as a `Vec`. + pub fn find_with_linked(self, l: L) -> SelectTwoMany + where + L: Linked, + T: EntityTrait, + { + let mut slf = self; + for (i, mut rel) in l.link().into_iter().enumerate() { + let to_tbl = Alias::new(format!("r{i}")).into_iden(); + let from_tbl = if i > 0 { + Alias::new(format!("r{}", i - 1)).into_iden() + } else { + unpack_table_ref(&rel.from_tbl) + }; + let table_ref = rel.to_tbl; + + let mut condition = Condition::all().add(join_tbl_on_condition( + SeaRc::clone(&from_tbl), + SeaRc::clone(&to_tbl), + rel.from_col, + rel.to_col, + )); + if let Some(f) = rel.on_condition.take() { + condition = condition.add(f(SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl))); + } + + slf.query() + .join_as(JoinType::LeftJoin, table_ref, to_tbl, condition); + } + slf = slf.apply_alias(SelectA.as_str()); + let mut select_two_many = SelectTwoMany::new_without_prepare(slf.query); + for col in ::iter() { + let alias = format!("{}{}", SelectB.as_str(), col.as_str()); + let expr = Expr::col(( + Alias::new(format!("r{}", l.link().len() - 1)).into_iden(), + col.into_iden(), + )); + select_two_many.query().expr(SelectExpr { + expr: col.select_as(expr), + alias: Some(SeaRc::new(Alias::new(alias))), + window: None, + }); + } + select_two_many + } } #[cfg(test)] diff --git a/tests/relational_tests.rs b/tests/relational_tests.rs index 959f0557..6b0c6950 100644 --- a/tests/relational_tests.rs +++ b/tests/relational_tests.rs @@ -2,6 +2,7 @@ pub mod common; pub use chrono::offset::Utc; pub use common::{bakery_chain::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; pub use rust_decimal::prelude::*; pub use rust_decimal_macros::dec; use sea_orm::{entity::*, query::*, DbErr, DerivePartialModel, FromQueryResult}; @@ -747,6 +748,85 @@ pub async fn linked() -> Result<(), DbErr> { }] ); + let select_baker_with_customer = Baker::find() + .find_with_linked(baker::BakedForCustomer) + .order_by_asc(baker::Column::Id) + .order_by_asc(Expr::col((Alias::new("r4"), customer::Column::Id))); + + assert_eq!( + select_baker_with_customer + .build(sea_orm::DatabaseBackend::MySql) + .to_string(), + [ + // FIXME: This might be faulty! + "SELECT `baker`.`id` AS `A_id`,", + "`baker`.`name` AS `A_name`,", + "`baker`.`contact_details` AS `A_contact_details`,", + "`baker`.`bakery_id` AS `A_bakery_id`,", + "`r4`.`id` AS `B_id`,", + "`r4`.`name` AS `B_name`,", + "`r4`.`notes` AS `B_notes`", + "FROM `baker`", + "LEFT JOIN `cakes_bakers` AS `r0` ON `baker`.`id` = `r0`.`baker_id`", + "LEFT JOIN `cake` AS `r1` ON `r0`.`cake_id` = `r1`.`id`", + "LEFT JOIN `lineitem` AS `r2` ON `r1`.`id` = `r2`.`cake_id`", + "LEFT JOIN `order` AS `r3` ON `r2`.`order_id` = `r3`.`id`", + "LEFT JOIN `customer` AS `r4` ON `r3`.`customer_id` = `r4`.`id`", + "ORDER BY `baker`.`id` ASC, `r4`.`id` ASC" + ] + .join(" ") + ); + + assert_eq!( + select_baker_with_customer.all(&ctx.db).await?, + [ + ( + baker::Model { + id: 1, + name: "Baker Bob".into(), + contact_details: serde_json::json!({ + "mobile": "+61424000000", + "home": "0395555555", + "address": "12 Test St, Testville, Vic, Australia", + }), + bakery_id: Some(1), + }, + vec![customer::Model { + id: 2, + name: "Kara".into(), + notes: Some("Loves all cakes".into()), + }] + ), + ( + baker::Model { + id: 2, + name: "Baker Bobby".into(), + contact_details: serde_json::json!({ + "mobile": "+85212345678", + }), + bakery_id: Some(1), + }, + vec![ + customer::Model { + id: 1, + name: "Kate".into(), + notes: Some("Loves cheese cake".into()), + }, + customer::Model { + id: 1, + name: "Kate".into(), + notes: Some("Loves cheese cake".into()), + }, + customer::Model { + id: 2, + name: "Kara".into(), + notes: Some("Loves all cakes".into()), + }, + ] + ), + ] + ); + ctx.delete().await; Ok(())