Improve API & Example

This commit is contained in:
Chris Tsang 2023-02-02 11:21:00 +08:00
parent 83c0732395
commit 56e4b4337b
3 changed files with 110 additions and 31 deletions

View File

@ -10,6 +10,10 @@ pub async fn all_about_select(db: &DbConn) -> Result<(), DbErr> {
println!("===== =====\n");
find_many(db).await?;
println!("===== =====\n");
find_one(db).await?;
println!("===== =====\n");
@ -77,6 +81,30 @@ async fn find_together(db: &DbConn) -> Result<(), DbErr> {
Ok(())
}
async fn find_many(db: &DbConn) -> Result<(), DbErr> {
print!("find cakes with fruits: ");
let cakes_with_fruits: Vec<(cake::Model, Vec<fruit::Model>)> = Cake::find()
.find_with_related(fruit::Entity)
.all(db)
.await?;
// equivalent; but with a different API
let cakes: Vec<cake::Model> = Cake::find().all(db).await?;
let fruits: Vec<Vec<fruit::Model>> = cakes.load_many(fruit::Entity, db).await?;
println!();
for (left, right) in cakes_with_fruits
.into_iter()
.zip(cakes.into_iter().zip(fruits.into_iter()))
{
println!("{left:?}\n");
assert_eq!(left, right);
}
Ok(())
}
impl Cake {
fn find_by_name(name: &str) -> Select<Self> {
Self::find().filter(cake::Column::Name.contains(name))
@ -142,13 +170,24 @@ async fn count_fruits_by_cake(db: &DbConn) -> Result<(), DbErr> {
async fn find_many_to_many(db: &DbConn) -> Result<(), DbErr> {
print!("find cakes and fillings: ");
let both: Vec<(cake::Model, Vec<filling::Model>)> =
let cakes_with_fillings: Vec<(cake::Model, Vec<filling::Model>)> =
Cake::find().find_with_related(Filling).all(db).await?;
// equivalent; but with a different API
let cakes: Vec<cake::Model> = Cake::find().all(db).await?;
let fillings: Vec<Vec<filling::Model>> = cakes
.load_many_to_many(filling::Entity, cake_filling::Entity, db)
.await?;
println!();
for bb in both.iter() {
println!("{bb:?}\n");
for (left, right) in cakes_with_fillings
.into_iter()
.zip(cakes.into_iter().zip(fillings.into_iter()))
{
println!("{left:?}\n");
assert_eq!(left, right);
}
println!();
print!("find fillings for cheese cake: ");

View File

@ -6,32 +6,40 @@ use async_trait::async_trait;
use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple};
use std::{collections::HashMap, str::FromStr};
/// A trait for basic Dataloader
/// Entity, or a Select<Entity>; to be used as parameters in [`LoaderTrait`]
pub trait EntityOrSelect<E: EntityTrait>: Send {
/// If self is Entity, use Entity::find()
fn select(self) -> Select<E>;
}
/// This trait implements the Data Loader API
#[async_trait]
pub trait LoaderTrait {
/// Source model
type Model: ModelTrait;
/// Used to eager load has_one relations
async fn load_one<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
/// Used to eager load has_many relations
async fn load_many<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
/// Used to eager load many_to_many relations
async fn load_many_to_many<R, V, C>(
async fn load_many_to_many<R, S, V, C>(
&self,
stmt: Select<R>,
stmt: S,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
@ -39,11 +47,30 @@ pub trait LoaderTrait {
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
}
impl<E> EntityOrSelect<E> for E
where
E: EntityTrait,
{
fn select(self) -> Select<E> {
E::find()
}
}
impl<E> EntityOrSelect<E> for Select<E>
where
E: EntityTrait,
{
fn select(self) -> Select<E> {
self
}
}
#[async_trait]
impl<M> LoaderTrait for Vec<M>
where
@ -51,29 +78,31 @@ where
{
type Model = M;
async fn load_one<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
self.as_slice().load_one(stmt, db).await
}
async fn load_many<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
self.as_slice().load_many(stmt, db).await
}
async fn load_many_to_many<R, V, C>(
async fn load_many_to_many<R, S, V, C>(
&self,
stmt: Select<R>,
stmt: S,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
@ -81,6 +110,7 @@ where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
@ -96,11 +126,12 @@ where
{
type Model = M;
async fn load_one<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
// we verify that is HasOne relation
@ -119,7 +150,7 @@ where
let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
let stmt = <Select<R> as QueryFilter>::filter(stmt, condition);
let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
let data = stmt.all(db).await?;
@ -145,11 +176,12 @@ where
Ok(result)
}
async fn load_many<R, C>(&self, stmt: Select<R>, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
// we verify that is HasMany relation
@ -169,7 +201,7 @@ where
let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
let stmt = <Select<R> as QueryFilter>::filter(stmt, condition);
let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
let data = stmt.all(db).await?;
@ -205,9 +237,9 @@ where
Ok(result)
}
async fn load_many_to_many<R, V, C>(
async fn load_many_to_many<R, S, V, C>(
&self,
stmt: Select<R>,
stmt: S,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
@ -215,6 +247,7 @@ where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
S: EntityOrSelect<R>,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
@ -261,7 +294,7 @@ where
let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
let stmt = <Select<R> as QueryFilter>::filter(stmt, condition);
let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
let data = stmt.all(db).await?;
// Map of R::PK -> R::Model
@ -283,11 +316,7 @@ where
let models: Vec<_> = fkeys
.into_iter()
.map(|fkey| {
data.get(&format!("{fkey:?}"))
.cloned()
.expect("Failed at finding key on hashmap")
})
.filter_map(|fkey| data.get(&format!("{fkey:?}")).cloned())
.collect();
models

View File

@ -27,7 +27,7 @@ async fn loader_load_one() -> Result<(), DbErr> {
.await?;
let bakers = baker::Entity::find().all(&ctx.db).await?;
let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?;
let bakeries = bakers.load_one(bakery::Entity, &ctx.db).await?;
assert_eq!(bakers, [baker_1, baker_2, baker_3]);
assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]);
@ -55,7 +55,7 @@ async fn loader_load_many() -> Result<(), DbErr> {
let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?;
let bakeries = bakery::Entity::find().all(&ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?;
assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]);
assert_eq!(
@ -126,8 +126,8 @@ async fn loader_load_many_multi() -> Result<(), DbErr> {
let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie
let bakeries = bakery::Entity::find().all(&ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?;
let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?;
let cakes = bakeries.load_many(cake::Entity, &ctx.db).await?;
assert_eq!(bakeries, [bakery_1, bakery_2]);
assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]);
@ -152,7 +152,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> {
let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?;
let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?;
let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?;
let cake_2 = insert_cake(&ctx.db, "Coffee", None).await?;
let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?;
let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie
@ -163,7 +163,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> {
let bakers = baker::Entity::find().all(&ctx.db).await?;
let cakes = bakers
.load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db)
.load_many_to_many(cake::Entity, cakes_bakers::Entity, &ctx.db)
.await?;
assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]);
@ -175,11 +175,22 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> {
]
);
// same, but apply restrictions on cakes
let cakes = bakers
.load_many_to_many(
cake::Entity::find().filter(cake::Column::Name.like("Ch%")),
cakes_bakers::Entity,
&ctx.db,
)
.await?;
assert_eq!(cakes, [vec![cake_1.clone()], vec![cake_3.clone()]]);
// now, start again from cakes
let cakes = cake::Entity::find().all(&ctx.db).await?;
let bakers = cakes
.load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db)
.load_many_to_many(baker::Entity, cakes_bakers::Entity, &ctx.db)
.await?;
assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]);