diff --git a/src/query/loader.rs b/src/query/loader.rs index 5beee927..56f263de 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -20,13 +20,28 @@ pub trait LoaderTrait { R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; - /// Used to eager load has_many relations + /// Used to eager load has_many relations async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; + + /// Used to eager load many_to_many relations + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related; } #[async_trait] @@ -55,6 +70,23 @@ where { self.as_slice().load_many(stmt, db).await } + + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + self.as_slice().load_many_to_many(stmt, via, db).await + } } #[async_trait] @@ -73,11 +105,11 @@ where { // we verify that is HasOne relation if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { - return Err(type_err("Relation is ManytoMany instead of HasOne")); + return Err(query_err("Relation is ManytoMany instead of HasOne")); } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); if rel_def.rel_type == RelationType::HasMany { - return Err(type_err("Relation is HasMany instead of HasOne")); + return Err(query_err("Relation is HasMany instead of HasOne")); } let keys: Vec = self @@ -123,11 +155,11 @@ where // we verify that is HasMany relation if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { - return Err(type_err("Relation is ManyToMany instead of HasMany")); + return Err(query_err("Relation is ManyToMany instead of HasMany")); } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); if rel_def.rel_type == RelationType::HasOne { - return Err(type_err("Relation is HasOne instead of HasMany")); + return Err(query_err("Relation is HasOne instead of HasMany")); } let keys: Vec = self @@ -172,6 +204,106 @@ where Ok(result) } + + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + if let Some(via_rel) = + <<::Model as ModelTrait>::Entity as Related>::via() + { + let rel_def = + <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type != RelationType::HasOne { + return Err(query_err("Relation to is not HasOne")); + } + + if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) { + return Err(query_err(format!( + "The given via Entity is incorrect: expected: {:?}, given: {:?}", + via_rel.to_tbl, + via.table_ref() + ))); + } + + let pkeys: Vec = self + .iter() + .map(|model: &M| extract_key(&via_rel.from_col, model)) + .collect(); + + // Map of M::PK -> Vec + let mut keymap: HashMap> = Default::default(); + + let keys: Vec = { + let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys); + let stmt = V::find().filter(condition); + let data = stmt.all(db).await?; + data.into_iter().for_each(|model| { + let pk = format!("{:?}", extract_key(&via_rel.to_col, &model)); + let entry = keymap.entry(pk).or_default(); + + let fk = extract_key(&rel_def.from_col, &model); + entry.push(fk); + }); + + keymap.values().flatten().cloned().collect() + }; + + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt, condition); + + let data = stmt.all(db).await?; + // Map of R::PK -> R::Model + let data: HashMap::Model> = data + .into_iter() + .map(|model| { + let key = format!("{:?}", extract_key(&rel_def.to_col, &model)); + (key, model) + }) + .collect(); + + let result: Vec> = pkeys + .into_iter() + .map(|pkey| { + let fkeys = keymap + .get(&format!("{pkey:?}")) + .cloned() + .unwrap_or_default(); + + let models: Vec<_> = fkeys + .into_iter() + .map(|fkey| { + data.get(&format!("{fkey:?}")) + .cloned() + .expect("Failed at finding key on hashmap") + }) + .collect(); + + models + }) + .collect(); + + Ok(result) + } else { + return Err(query_err("Relation is not ManyToMany")); + } + } +} + +fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool { + // not ideal; but + format!("{left:?}") == format!("{right:?}") } fn extract_key(target_col: &Identity, model: &Model) -> ValueTuple diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index ba22d618..bc0b5467 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -13,12 +13,12 @@ async fn loader_load_one() -> Result<(), DbErr> { let ctx = TestContext::new("loader_test_load_one").await; create_tables(&ctx.db).await?; - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + let bakery_0 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - - let baker_2 = baker::ActiveModel { - name: Set("Baker 2".to_owned()), + let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery_0.id).await?; + let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery_0.id).await?; + let baker_3 = baker::ActiveModel { + name: Set("Baker 3".to_owned()), contact_details: Set(serde_json::json!({})), bakery_id: Set(None), ..Default::default() @@ -29,34 +29,8 @@ async fn loader_load_one() -> Result<(), DbErr> { let bakers = baker::Entity::find().all(&ctx.db).await?; let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery), None]); - - Ok(()) -} - -#[sea_orm_macros::test] -#[cfg(any( - feature = "sqlx-mysql", - feature = "sqlx-sqlite", - feature = "sqlx-postgres" -))] -async fn loader_load_one_complex() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_one_complex").await; - create_tables(&ctx.db).await?; - - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?; - - let bakers = baker::Entity::find().all(&ctx.db).await?; - let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; - - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery.clone()), Some(bakery.clone())]); + assert_eq!(bakers, [baker_1, baker_2, baker_3]); + assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]); Ok(()) } @@ -81,6 +55,18 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; let bakeries = bakery::Entity::find().all(&ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; + + assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); + assert_eq!( + bakers, + [ + [baker_1.clone(), baker_2.clone()], + [baker_3.clone(), baker_4.clone()] + ] + ); + + // load bakers again but with additional condition let bakers = bakeries .load_many( @@ -89,11 +75,6 @@ async fn loader_load_many() -> Result<(), DbErr> { ) .await?; - println!("A: {bakers:?}"); - println!("B: {bakeries:?}"); - - assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); - assert_eq!( bakers, [ @@ -102,16 +83,6 @@ async fn loader_load_many() -> Result<(), DbErr> { ] ); - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - - assert_eq!( - bakers, - [ - [baker_1.clone(), baker_2.clone()], - [baker_3.clone(), baker_4.clone()] - ] - ); - // now, start from baker let bakers = baker::Entity::find().all(&ctx.db).await?; @@ -149,18 +120,15 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { let baker_2 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; let baker_3 = insert_baker(&ctx.db, "Peter", bakery_2.id).await?; - let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_2.id).await?; - let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_2.id).await?; + let cake_1 = insert_cake(&ctx.db, "Cheesecake", Some(bakery_1.id)).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", Some(bakery_2.id)).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", Some(bakery_2.id)).await?; + let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie let bakeries = bakery::Entity::find().all(&ctx.db).await?; let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?; - println!("{bakers:?}"); - println!("{bakeries:?}"); - println!("{cakes:?}"); - assert_eq!(bakeries, [bakery_1, bakery_2]); assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]); assert_eq!(cakes, [vec![cake_1], vec![cake_2, cake_3]]); @@ -168,7 +136,6 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { Ok(()) } -#[ignore] #[sea_orm_macros::test] #[cfg(any( feature = "sqlx-mysql", @@ -184,9 +151,10 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; - let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_1.id).await?; - let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_1.id).await?; + let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?; + let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?; insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?; @@ -194,13 +162,36 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?; let bakers = baker::Entity::find().all(&ctx.db).await?; - let cakes = bakers.load_many(cake::Entity::find(), &ctx.db).await?; + let cakes = bakers + .load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db) + .await?; - println!("{bakers:?}"); - println!("{cakes:?}"); + assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]); + assert_eq!( + cakes, + [ + vec![cake_1.clone(), cake_2.clone()], + vec![cake_2.clone(), cake_3.clone()] + ] + ); - assert_eq!(bakers, [baker_1, baker_2]); - assert_eq!(cakes, [vec![cake_1, cake_2.clone()], vec![cake_2, cake_3]]); + // now, start again from cakes + + let cakes = cake::Entity::find().all(&ctx.db).await?; + let bakers = cakes + .load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db) + .await?; + + assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]); + assert_eq!( + bakers, + [ + vec![baker_1.clone()], + vec![baker_1.clone(), baker_2.clone()], + vec![baker_2.clone()], + vec![] + ] + ); Ok(()) } @@ -226,12 +217,16 @@ pub async fn insert_baker(db: &DbConn, name: &str, bakery_id: i32) -> Result Result { +pub async fn insert_cake( + db: &DbConn, + name: &str, + bakery_id: Option, +) -> Result { cake::ActiveModel { name: Set(name.to_owned()), price: Set(rust_decimal::Decimal::ONE), gluten_free: Set(false), - bakery_id: Set(Some(bakery_id)), + bakery_id: Set(bakery_id), ..Default::default() } .insert(db)