diff --git a/src/entity/loader.rs b/src/entity/loader.rs index e9af2cf9..b1c11084 100644 --- a/src/entity/loader.rs +++ b/src/entity/loader.rs @@ -1,11 +1,23 @@ -use crate::{DbErr, EntityTrait, ModelTrait, QueryFilter, Select, Related, RelationType, Identity, Condition, Value, ColumnTrait, ConnectionTrait}; -use std::{fmt::Debug, str::FromStr, collections::BTreeMap}; +use crate::{ + ColumnTrait, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter, + Related, RelationType, Select, Value, +}; +use async_trait::async_trait; +use sea_query::{Expr, IntoColumnRef, SimpleExpr, ValueTuple}; +use std::{collections::BTreeMap, fmt::Debug, str::FromStr}; -#[async_trait::async_trait] +/// A trait for basic Dataloader +#[async_trait] pub trait LoaderTrait { + /// Source model type Model: ModelTrait; - async fn load_one(&self, db: &C) -> Result>, DbErr> + /// Used to eager load has_one relations + /// + /// + /// + /// + async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, @@ -14,7 +26,36 @@ pub trait LoaderTrait { <::Model as ModelTrait>::Entity: Related, <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::Err: Debug; - async fn load_many(&self, db: &C) -> Result>, DbErr> + /// Used to eager load has_many relations + /// + /// # Example + /// + /// ``` + /// use sea_orm::{tests_cfg::*, entity::loader::*}; + /// + /// let db = MockDatabase::new(DbBackend::Postgres) + /// .append_query_results(vec![ + /// vec![cake::Model { + /// id: 1, + /// name: "New York Cheese".to_owned(), + /// } + /// .into_mock_row()], + /// vec![fruit::Model { + /// id: 1, + /// name: "Apple".to_owned(), + /// cake_id: Some(1), + /// } + /// .into_mock_row()], + /// ]) + /// .into_connection(); + /// + /// let cakes = vec![cake::Model { id: 1, name: "New York Cheese".to_owned(), }]; + /// + /// let fruits = cakes.load_many(fruit::Entity::find(), &db); + /// + /// assert_eq!(fruits, vec![fruit::Model { id: 1, name: "Apple".to_owned(), cake_id: Some(1), }]); + /// ``` + async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, @@ -32,7 +73,7 @@ where { type Model = M; - async fn load_one(&self, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, @@ -41,8 +82,7 @@ where <::Model as ModelTrait>::Entity: Related, <<::Entity as EntityTrait>::Column as FromStr>::Err: Debug, { - let rel_def = - <<::Model as ModelTrait>::Entity as Related>::to(); + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); // we verify that is has_one relation match (&rel_def).rel_type { @@ -52,79 +92,21 @@ where } } - fn extract_key(target_col: &Identity, model: &Model) -> Vec - where - Model: ModelTrait, - <<::Entity as EntityTrait>::Column as FromStr>::Err: Debug, - { - match target_col { - Identity::Unary(a) => { - let column_a = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - vec![model.get(column_a)] - }, - Identity::Binary(a, b) => { - let column_a = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - let column_b = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&b.to_string()).unwrap(); - vec![model.get(column_a), model.get(column_b)] - }, - Identity::Ternary(a, b, c) => { - let column_a = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - let column_b = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&b.to_string()).unwrap(); - let column_c = <<::Entity as EntityTrait>::Column as FromStr>::from_str(&c.to_string()).unwrap(); - vec![model.get(column_a), model.get(column_b), model.get(column_c)] - }, - } - } - let keys: Vec> = self .iter() - .map(|model: &M| { - extract_key(&rel_def.from_col, model) - }) + .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = match &rel_def.to_col { - Identity::Unary(a) => { - let column_a: ::Column = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().map(|key| key[0].clone()).collect::>(), - )) - } - Identity::Binary(a, b) => { - let column_a: <<::Model as ModelTrait>::Entity as EntityTrait>::Column = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - let column_b: <<::Model as ModelTrait>::Entity as EntityTrait>::Column = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&b.to_string()).unwrap(); - // TODO - // Condition::all().add( - // sea_query::Expr::tuple([column_a.to_string(), column_b]).is_in(keys.iter().map(|key| (key[0].clone(), key[1].clone())).collect::>()) - // ) - // TODO - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().map(|key| key[0].clone()).collect::>(), - )) - } - Identity::Ternary(a, b, c) => { - let column_a = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a.to_string()).unwrap(); - let column_b = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&b.to_string()).unwrap(); - let column_c = <<<::Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&c.to_string()).unwrap(); - // TODO - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().map(|key| key[0].clone()).collect::>(), - )) - } - }; - - let stmt = ::find(); + let condition = prepare_condition::(&rel_def.to_col, &keys); let stmt = as QueryFilter>::filter(stmt, condition); let data = stmt.all(db).await?; - let mut hashmap: BTreeMap::::Model> = data - .into_iter() - .fold(BTreeMap::::Model>::new(), |mut acc: BTreeMap::::Model>, value: ::Model| { + let mut hashmap: BTreeMap::Model> = data.into_iter().fold( + BTreeMap::::Model>::new(), + |mut acc: BTreeMap::Model>, + value: ::Model| { { let key = extract_key(&rel_def.to_col, &value); @@ -132,7 +114,8 @@ where } acc - }); + }, + ); let result: Vec::Model>> = keys .iter() @@ -146,7 +129,7 @@ where Ok(result) } - async fn load_many(&self, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, @@ -155,7 +138,165 @@ where <::Model as ModelTrait>::Entity: Related, <<::Entity as EntityTrait>::Column as FromStr>::Err: Debug, { - // we should verify this is a has_many relation - Ok(vec![]) + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + + // we verify that is has_many relation + match (&rel_def).rel_type { + RelationType::HasMany => (), + RelationType::HasOne => { + return Err(DbErr::Type("Relation is HasOne instead of HasMany".into())) + } + } + + let keys: Vec> = self + .iter() + .map(|model: &M| extract_key(&rel_def.from_col, model)) + .collect(); + + let condition = prepare_condition::(&rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt, condition); + + let data = stmt.all(db).await?; + + let mut hashmap: BTreeMap::Model>> = + keys.iter() + .fold(BTreeMap::new(), |mut acc, key: &Vec| { + acc.insert(format!("{:?}", key), Vec::new()); + + acc + }); + + data.into_iter() + .for_each(|value: ::Model| { + let key = extract_key(&rel_def.to_col, &value); + + let vec = hashmap.get_mut(&format!("{:?}", key)).unwrap(); + + vec.push(value); + }); + + let result: Vec> = keys + .iter() + .map(|key: &Vec| hashmap.remove(&format!("{:?}", key)).to_owned().unwrap()) + .collect(); + + Ok(result) + } +} + +fn extract_key(target_col: &Identity, model: &Model) -> Vec +where + Model: ModelTrait, + <<::Entity as EntityTrait>::Column as FromStr>::Err: Debug, +{ + match target_col { + Identity::Unary(a) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap(); + vec![model.get(column_a)] + } + Identity::Binary(a, b) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap(); + let column_b = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &b.to_string(), + ) + .unwrap(); + vec![model.get(column_a), model.get(column_b)] + } + Identity::Ternary(a, b, c) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap(); + let column_b = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &b.to_string(), + ) + .unwrap(); + let column_c = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &c.to_string(), + ) + .unwrap(); + vec![ + model.get(column_a), + model.get(column_b), + model.get(column_c), + ] + } + } +} + +fn prepare_condition(col: &Identity, keys: &Vec>) -> Condition +where + M: ModelTrait, + <<::Entity as EntityTrait>::Column as FromStr>::Err: Debug, +{ + match col { + Identity::Unary(column_a) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap(); + Condition::all().add(ColumnTrait::is_in( + &column_a, + keys.iter() + .map(|key| key[0].clone()) + .collect::>(), + )) + } + Identity::Binary(column_a, column_b) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap(); + let column_b: ::Column = + <::Column as FromStr>::from_str(&column_b.to_string()) + .unwrap(); + Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(column_a.into_column_ref()), + SimpleExpr::Column(column_b.into_column_ref()), + ]) + .in_tuples( + keys.iter() + .map(|key| ValueTuple::Two(key[0].clone(), key[1].clone())) + .collect::>(), + ), + ) + } + Identity::Ternary(column_a, column_b, column_c) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap(); + let column_b: ::Column = + <::Column as FromStr>::from_str(&column_b.to_string()) + .unwrap(); + let column_c: ::Column = + <::Column as FromStr>::from_str(&column_c.to_string()) + .unwrap(); + Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(column_a.into_column_ref()), + SimpleExpr::Column(column_b.into_column_ref()), + SimpleExpr::Column(column_c.into_column_ref()), + ]) + .in_tuples( + keys.iter() + .map(|key| { + ValueTuple::Three(key[0].clone(), key[1].clone(), key[2].clone()) + }) + .collect::>(), + ), + ) + } } }