From 56e4b4337b4298f718d93043f203acceae574392 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 2 Feb 2023 11:21:00 +0800 Subject: [PATCH] Improve API & Example --- examples/basic/src/select.rs | 45 +++++++++++++++++++++-- src/query/loader.rs | 71 +++++++++++++++++++++++++----------- tests/loader_tests.rs | 25 +++++++++---- 3 files changed, 110 insertions(+), 31 deletions(-) diff --git a/examples/basic/src/select.rs b/examples/basic/src/select.rs index 220920ce..611226a2 100644 --- a/examples/basic/src/select.rs +++ b/examples/basic/src/select.rs @@ -10,6 +10,10 @@ pub async fn all_about_select(db: &DbConn) -> Result<(), DbErr> { println!("===== =====\n"); + find_many(db).await?; + + println!("===== =====\n"); + find_one(db).await?; println!("===== =====\n"); @@ -77,6 +81,30 @@ async fn find_together(db: &DbConn) -> Result<(), DbErr> { Ok(()) } +async fn find_many(db: &DbConn) -> Result<(), DbErr> { + print!("find cakes with fruits: "); + + let cakes_with_fruits: Vec<(cake::Model, Vec)> = Cake::find() + .find_with_related(fruit::Entity) + .all(db) + .await?; + + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fruits: Vec> = cakes.load_many(fruit::Entity, db).await?; + + println!(); + for (left, right) in cakes_with_fruits + .into_iter() + .zip(cakes.into_iter().zip(fruits.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); + } + + Ok(()) +} + impl Cake { fn find_by_name(name: &str) -> Select { Self::find().filter(cake::Column::Name.contains(name)) @@ -142,13 +170,24 @@ async fn count_fruits_by_cake(db: &DbConn) -> Result<(), DbErr> { async fn find_many_to_many(db: &DbConn) -> Result<(), DbErr> { print!("find cakes and fillings: "); - let both: Vec<(cake::Model, Vec)> = + let cakes_with_fillings: Vec<(cake::Model, Vec)> = Cake::find().find_with_related(Filling).all(db).await?; + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fillings: Vec> = cakes + .load_many_to_many(filling::Entity, cake_filling::Entity, db) + .await?; + println!(); - for bb in both.iter() { - println!("{bb:?}\n"); + for (left, right) in cakes_with_fillings + .into_iter() + .zip(cakes.into_iter().zip(fillings.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); } + println!(); print!("find fillings for cheese cake: "); diff --git a/src/query/loader.rs b/src/query/loader.rs index 56f263de..90e5e35b 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -6,32 +6,40 @@ use async_trait::async_trait; use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple}; use std::{collections::HashMap, str::FromStr}; -/// A trait for basic Dataloader +/// Entity, or a Select; to be used as parameters in [`LoaderTrait`] +pub trait EntityOrSelect: Send { + /// If self is Entity, use Entity::find() + fn select(self) -> Select; +} + +/// This trait implements the Data Loader API #[async_trait] pub trait LoaderTrait { /// Source model type Model: ModelTrait; /// Used to eager load has_one relations - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related; /// Used to eager load has_many relations - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related; /// Used to eager load many_to_many relations - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -39,11 +47,30 @@ pub trait LoaderTrait { C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; } +impl EntityOrSelect for E +where + E: EntityTrait, +{ + fn select(self) -> Select { + E::find() + } +} + +impl EntityOrSelect for Select +where + E: EntityTrait, +{ + fn select(self) -> Select { + self + } +} + #[async_trait] impl LoaderTrait for Vec where @@ -51,29 +78,31 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_one(stmt, db).await } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_many(stmt, db).await } - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -81,6 +110,7 @@ where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, @@ -96,11 +126,12 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { // we verify that is HasOne relation @@ -119,7 +150,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -145,11 +176,12 @@ where Ok(result) } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { // we verify that is HasMany relation @@ -169,7 +201,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -205,9 +237,9 @@ where Ok(result) } - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -215,6 +247,7 @@ where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, @@ -261,7 +294,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; // Map of R::PK -> R::Model @@ -283,11 +316,7 @@ where let models: Vec<_> = fkeys .into_iter() - .map(|fkey| { - data.get(&format!("{fkey:?}")) - .cloned() - .expect("Failed at finding key on hashmap") - }) + .filter_map(|fkey| data.get(&format!("{fkey:?}")).cloned()) .collect(); models diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index bc0b5467..2db148a2 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -27,7 +27,7 @@ async fn loader_load_one() -> Result<(), DbErr> { .await?; let bakers = baker::Entity::find().all(&ctx.db).await?; - let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity, &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2, baker_3]); assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]); @@ -55,7 +55,7 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; let bakeries = bakery::Entity::find().all(&ctx.db).await?; - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); assert_eq!( @@ -126,8 +126,8 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie let bakeries = bakery::Entity::find().all(&ctx.db).await?; - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; + let cakes = bakeries.load_many(cake::Entity, &ctx.db).await?; assert_eq!(bakeries, [bakery_1, bakery_2]); assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]); @@ -152,7 +152,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?; + let cake_2 = insert_cake(&ctx.db, "Coffee", None).await?; let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?; let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie @@ -163,7 +163,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let bakers = baker::Entity::find().all(&ctx.db).await?; let cakes = bakers - .load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db) + .load_many_to_many(cake::Entity, cakes_bakers::Entity, &ctx.db) .await?; assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]); @@ -175,11 +175,22 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { ] ); + // same, but apply restrictions on cakes + + let cakes = bakers + .load_many_to_many( + cake::Entity::find().filter(cake::Column::Name.like("Ch%")), + cakes_bakers::Entity, + &ctx.db, + ) + .await?; + assert_eq!(cakes, [vec![cake_1.clone()], vec![cake_3.clone()]]); + // now, start again from cakes let cakes = cake::Entity::find().all(&ctx.db).await?; let bakers = cakes - .load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db) + .load_many_to_many(baker::Entity, cakes_bakers::Entity, &ctx.db) .await?; assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]);