load_many_to_many

This commit is contained in:
Chris Tsang 2023-02-02 09:38:30 +08:00
parent 0670827148
commit 83c0732395
2 changed files with 198 additions and 71 deletions

View File

@ -27,6 +27,21 @@ pub trait LoaderTrait {
R: EntityTrait,
R::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
/// Used to eager load many_to_many relations
async fn load_many_to_many<R, V, C>(
&self,
stmt: Select<R>,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
}
#[async_trait]
@ -55,6 +70,23 @@ where
{
self.as_slice().load_many(stmt, db).await
}
async fn load_many_to_many<R, V, C>(
&self,
stmt: Select<R>,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
self.as_slice().load_many_to_many(stmt, via, db).await
}
}
#[async_trait]
@ -73,11 +105,11 @@ where
{
// we verify that is HasOne relation
if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
return Err(type_err("Relation is ManytoMany instead of HasOne"));
return Err(query_err("Relation is ManytoMany instead of HasOne"));
}
let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
if rel_def.rel_type == RelationType::HasMany {
return Err(type_err("Relation is HasMany instead of HasOne"));
return Err(query_err("Relation is HasMany instead of HasOne"));
}
let keys: Vec<ValueTuple> = self
@ -123,11 +155,11 @@ where
// we verify that is HasMany relation
if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
return Err(type_err("Relation is ManyToMany instead of HasMany"));
return Err(query_err("Relation is ManyToMany instead of HasMany"));
}
let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
if rel_def.rel_type == RelationType::HasOne {
return Err(type_err("Relation is HasOne instead of HasMany"));
return Err(query_err("Relation is HasOne instead of HasMany"));
}
let keys: Vec<ValueTuple> = self
@ -172,6 +204,106 @@ where
Ok(result)
}
async fn load_many_to_many<R, V, C>(
&self,
stmt: Select<R>,
via: V,
db: &C,
) -> Result<Vec<Vec<R::Model>>, DbErr>
where
C: ConnectionTrait,
R: EntityTrait,
R::Model: Send + Sync,
V: EntityTrait,
V::Model: Send + Sync,
<<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
{
if let Some(via_rel) =
<<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via()
{
let rel_def =
<<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
if rel_def.rel_type != RelationType::HasOne {
return Err(query_err("Relation to is not HasOne"));
}
if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
return Err(query_err(format!(
"The given via Entity is incorrect: expected: {:?}, given: {:?}",
via_rel.to_tbl,
via.table_ref()
)));
}
let pkeys: Vec<ValueTuple> = self
.iter()
.map(|model: &M| extract_key(&via_rel.from_col, model))
.collect();
// Map of M::PK -> Vec<R::PK>
let mut keymap: HashMap<String, Vec<ValueTuple>> = Default::default();
let keys: Vec<ValueTuple> = {
let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys);
let stmt = V::find().filter(condition);
let data = stmt.all(db).await?;
data.into_iter().for_each(|model| {
let pk = format!("{:?}", extract_key(&via_rel.to_col, &model));
let entry = keymap.entry(pk).or_default();
let fk = extract_key(&rel_def.from_col, &model);
entry.push(fk);
});
keymap.values().flatten().cloned().collect()
};
let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
let stmt = <Select<R> as QueryFilter>::filter(stmt, condition);
let data = stmt.all(db).await?;
// Map of R::PK -> R::Model
let data: HashMap<String, <R as EntityTrait>::Model> = data
.into_iter()
.map(|model| {
let key = format!("{:?}", extract_key(&rel_def.to_col, &model));
(key, model)
})
.collect();
let result: Vec<Vec<R::Model>> = pkeys
.into_iter()
.map(|pkey| {
let fkeys = keymap
.get(&format!("{pkey:?}"))
.cloned()
.unwrap_or_default();
let models: Vec<_> = fkeys
.into_iter()
.map(|fkey| {
data.get(&format!("{fkey:?}"))
.cloned()
.expect("Failed at finding key on hashmap")
})
.collect();
models
})
.collect();
Ok(result)
} else {
return Err(query_err("Relation is not ManyToMany"));
}
}
}
fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
// not ideal; but
format!("{left:?}") == format!("{right:?}")
}
fn extract_key<Model>(target_col: &Identity, model: &Model) -> ValueTuple

View File

@ -13,12 +13,12 @@ async fn loader_load_one() -> Result<(), DbErr> {
let ctx = TestContext::new("loader_test_load_one").await;
create_tables(&ctx.db).await?;
let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?;
let bakery_0 = insert_bakery(&ctx.db, "SeaSide Bakery").await?;
let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?;
let baker_2 = baker::ActiveModel {
name: Set("Baker 2".to_owned()),
let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery_0.id).await?;
let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery_0.id).await?;
let baker_3 = baker::ActiveModel {
name: Set("Baker 3".to_owned()),
contact_details: Set(serde_json::json!({})),
bakery_id: Set(None),
..Default::default()
@ -29,34 +29,8 @@ async fn loader_load_one() -> Result<(), DbErr> {
let bakers = baker::Entity::find().all(&ctx.db).await?;
let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?;
assert_eq!(bakers, [baker_1, baker_2]);
assert_eq!(bakeries, [Some(bakery), None]);
Ok(())
}
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
async fn loader_load_one_complex() -> Result<(), DbErr> {
let ctx = TestContext::new("loader_test_load_one_complex").await;
create_tables(&ctx.db).await?;
let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?;
let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?;
let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?;
let bakers = baker::Entity::find().all(&ctx.db).await?;
let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?;
assert_eq!(bakers, [baker_1, baker_2]);
assert_eq!(bakeries, [Some(bakery.clone()), Some(bakery.clone())]);
assert_eq!(bakers, [baker_1, baker_2, baker_3]);
assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]);
Ok(())
}
@ -81,6 +55,18 @@ async fn loader_load_many() -> Result<(), DbErr> {
let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?;
let bakeries = bakery::Entity::find().all(&ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?;
assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]);
assert_eq!(
bakers,
[
[baker_1.clone(), baker_2.clone()],
[baker_3.clone(), baker_4.clone()]
]
);
// load bakers again but with additional condition
let bakers = bakeries
.load_many(
@ -89,11 +75,6 @@ async fn loader_load_many() -> Result<(), DbErr> {
)
.await?;
println!("A: {bakers:?}");
println!("B: {bakeries:?}");
assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]);
assert_eq!(
bakers,
[
@ -102,16 +83,6 @@ async fn loader_load_many() -> Result<(), DbErr> {
]
);
let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?;
assert_eq!(
bakers,
[
[baker_1.clone(), baker_2.clone()],
[baker_3.clone(), baker_4.clone()]
]
);
// now, start from baker
let bakers = baker::Entity::find().all(&ctx.db).await?;
@ -149,18 +120,15 @@ async fn loader_load_many_multi() -> Result<(), DbErr> {
let baker_2 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?;
let baker_3 = insert_baker(&ctx.db, "Peter", bakery_2.id).await?;
let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?;
let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_2.id).await?;
let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_2.id).await?;
let cake_1 = insert_cake(&ctx.db, "Cheesecake", Some(bakery_1.id)).await?;
let cake_2 = insert_cake(&ctx.db, "Chocolate", Some(bakery_2.id)).await?;
let cake_3 = insert_cake(&ctx.db, "Chiffon", Some(bakery_2.id)).await?;
let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie
let bakeries = bakery::Entity::find().all(&ctx.db).await?;
let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?;
let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?;
println!("{bakers:?}");
println!("{bakeries:?}");
println!("{cakes:?}");
assert_eq!(bakeries, [bakery_1, bakery_2]);
assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]);
assert_eq!(cakes, [vec![cake_1], vec![cake_2, cake_3]]);
@ -168,7 +136,6 @@ async fn loader_load_many_multi() -> Result<(), DbErr> {
Ok(())
}
#[ignore]
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
@ -184,9 +151,10 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> {
let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?;
let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?;
let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?;
let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_1.id).await?;
let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_1.id).await?;
let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?;
let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?;
let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?;
let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie
insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?;
insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?;
@ -194,13 +162,36 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> {
insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?;
let bakers = baker::Entity::find().all(&ctx.db).await?;
let cakes = bakers.load_many(cake::Entity::find(), &ctx.db).await?;
let cakes = bakers
.load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db)
.await?;
println!("{bakers:?}");
println!("{cakes:?}");
assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]);
assert_eq!(
cakes,
[
vec![cake_1.clone(), cake_2.clone()],
vec![cake_2.clone(), cake_3.clone()]
]
);
assert_eq!(bakers, [baker_1, baker_2]);
assert_eq!(cakes, [vec![cake_1, cake_2.clone()], vec![cake_2, cake_3]]);
// now, start again from cakes
let cakes = cake::Entity::find().all(&ctx.db).await?;
let bakers = cakes
.load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db)
.await?;
assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]);
assert_eq!(
bakers,
[
vec![baker_1.clone()],
vec![baker_1.clone(), baker_2.clone()],
vec![baker_2.clone()],
vec![]
]
);
Ok(())
}
@ -226,12 +217,16 @@ pub async fn insert_baker(db: &DbConn, name: &str, bakery_id: i32) -> Result<bak
.await
}
pub async fn insert_cake(db: &DbConn, name: &str, bakery_id: i32) -> Result<cake::Model, DbErr> {
pub async fn insert_cake(
db: &DbConn,
name: &str,
bakery_id: Option<i32>,
) -> Result<cake::Model, DbErr> {
cake::ActiveModel {
name: Set(name.to_owned()),
price: Set(rust_decimal::Decimal::ONE),
gluten_free: Set(false),
bakery_id: Set(Some(bakery_id)),
bakery_id: Set(bakery_id),
..Default::default()
}
.insert(db)