From 06708271489d67a8de362413e5b89f1110063df6 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 2 Feb 2023 07:46:38 +0800 Subject: [PATCH] Improve test cases --- src/entity/relation.rs | 2 +- src/query/loader.rs | 104 +++++++++++++++-------------------- tests/loader_tests.rs | 119 +++++++++++++++++++++++++++++------------ 3 files changed, 130 insertions(+), 95 deletions(-) diff --git a/src/entity/relation.rs b/src/entity/relation.rs index ad291f0f..cae51788 100644 --- a/src/entity/relation.rs +++ b/src/entity/relation.rs @@ -7,7 +7,7 @@ use sea_query::{ use std::fmt::Debug; /// Defines the type of relationship -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum RelationType { /// An Entity has one relationship HasOne, diff --git a/src/query/loader.rs b/src/query/loader.rs index 5d9dd37d..5beee927 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -1,9 +1,9 @@ use crate::{ - error::*, ColumnTrait, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, - QueryFilter, Related, RelationType, Select, + error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter, + Related, RelationType, Select, }; use async_trait::async_trait; -use sea_query::{Expr, IntoColumnRef, SimpleExpr, ValueTuple}; +use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple}; use std::{collections::HashMap, str::FromStr}; /// A trait for basic Dataloader @@ -71,12 +71,13 @@ where R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, { + // we verify that is HasOne relation + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(type_err("Relation is ManytoMany instead of HasOne")); + } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); - - // we verify that is has_one relation - match (rel_def).rel_type { - RelationType::HasOne => (), - RelationType::HasMany => return Err(type_err("Relation is HasMany instead of HasOne")), + if rel_def.rel_type == RelationType::HasMany { + return Err(type_err("Relation is HasMany instead of HasOne")); } let keys: Vec = self @@ -84,7 +85,7 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); let stmt = as QueryFilter>::filter(stmt, condition); @@ -119,12 +120,14 @@ where R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, { - let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + // we verify that is HasMany relation - // we verify that is has_many relation - match (rel_def).rel_type { - RelationType::HasMany => (), - RelationType::HasOne => return Err(type_err("Relation is HasOne instead of HasMany")), + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(type_err("Relation is ManyToMany instead of HasMany")); + } + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type == RelationType::HasOne { + return Err(type_err("Relation is HasOne instead of HasMany")); } let keys: Vec = self @@ -132,7 +135,7 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); let stmt = as QueryFilter>::filter(stmt, condition); @@ -222,54 +225,35 @@ where } } -fn prepare_condition(col: &Identity, keys: &[ValueTuple]) -> Condition -where - M: ModelTrait, -{ +fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition { match col { Identity::Unary(column_a) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:1")); - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().cloned().flatten(), - )) - } - Identity::Binary(column_a, column_b) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:2")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:2")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) - } - Identity::Ternary(column_a, column_b, column_c) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:3")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:3")); - let column_c: ::Column = - <::Column as FromStr>::from_str(&column_c.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *C:3")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - SimpleExpr::Column(column_c.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) + let column_a = table_column(table, column_a); + Condition::all().add(Expr::col(column_a).is_in(keys.iter().cloned().flatten())) } + Identity::Binary(column_a, column_b) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + ]) + .in_tuples(keys.iter().cloned()), + ), + Identity::Ternary(column_a, column_b, column_c) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + SimpleExpr::Column(table_column(table, column_c)), + ]) + .in_tuples(keys.iter().cloned()), + ), + } +} + +fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef { + match tbl.to_owned() { + TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(), + TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(), + val => unimplemented!("Unsupported TableRef {val:?}"), } } diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index 0026fc85..ba22d618 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -24,18 +24,10 @@ async fn loader_load_one() -> Result<(), DbErr> { ..Default::default() } .insert(&ctx.db) - .await - .expect("could not insert baker"); + .await?; - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); - - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2]); @@ -59,15 +51,8 @@ async fn loader_load_one_complex() -> Result<(), DbErr> { let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?; - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); - - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2]); @@ -95,23 +80,19 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_3 = insert_baker(&ctx.db, "John", bakery_2.id).await?; let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; - let bakeries = bakery::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakeries"); + let bakeries = bakery::Entity::find().all(&ctx.db).await?; let bakers = bakeries .load_many( baker::Entity::find().filter(baker::Column::Name.like("Baker%")), &ctx.db, ) - .await - .expect("Should load bakers"); + .await?; println!("A: {bakers:?}"); println!("B: {bakeries:?}"); - assert_eq!(bakeries, [bakery_1, bakery_2]); + assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); assert_eq!( bakers, @@ -121,12 +102,32 @@ async fn loader_load_many() -> Result<(), DbErr> { ] ); - let bakers = bakeries - .load_many(baker::Entity::find(), &ctx.db) - .await - .expect("Should load bakers"); + let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - assert_eq!(bakers, [[baker_1, baker_2], [baker_3, baker_4]]); + assert_eq!( + bakers, + [ + [baker_1.clone(), baker_2.clone()], + [baker_3.clone(), baker_4.clone()] + ] + ); + + // now, start from baker + + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; + + // note that two bakers share the same bakery + assert_eq!(bakers, [baker_1, baker_2, baker_3, baker_4]); + assert_eq!( + bakeries, + [ + Some(bakery_1.clone()), + Some(bakery_1), + Some(bakery_2.clone()), + Some(bakery_2) + ] + ); Ok(()) } @@ -137,8 +138,8 @@ async fn loader_load_many() -> Result<(), DbErr> { feature = "sqlx-sqlite", feature = "sqlx-postgres" ))] -async fn loader_load_many_many() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_many_many").await; +async fn loader_load_many_multi() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_multi").await; create_tables(&ctx.db).await?; let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; @@ -167,6 +168,43 @@ async fn loader_load_many_many() -> Result<(), DbErr> { Ok(()) } +#[ignore] +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_many_to_many() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_to_many").await; + create_tables(&ctx.db).await?; + + let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + + let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; + let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; + + let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_1.id).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_1.id).await?; + + insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?; + insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?; + + let bakers = baker::Entity::find().all(&ctx.db).await?; + let cakes = bakers.load_many(cake::Entity::find(), &ctx.db).await?; + + println!("{bakers:?}"); + println!("{cakes:?}"); + + assert_eq!(bakers, [baker_1, baker_2]); + assert_eq!(cakes, [vec![cake_1, cake_2.clone()], vec![cake_2, cake_3]]); + + Ok(()) +} + pub async fn insert_bakery(db: &DbConn, name: &str) -> Result { bakery::ActiveModel { name: Set(name.to_owned()), @@ -199,3 +237,16 @@ pub async fn insert_cake(db: &DbConn, name: &str, bakery_id: i32) -> Result Result { + cakes_bakers::ActiveModel { + cake_id: Set(cake_id), + baker_id: Set(baker_id), + } + .insert(db) + .await +}