diff --git a/sea-orm-macros/src/derives/partial_model.rs b/sea-orm-macros/src/derives/partial_model.rs index 7bd033cd..3d370358 100644 --- a/sea-orm-macros/src/derives/partial_model.rs +++ b/sea-orm-macros/src/derives/partial_model.rs @@ -43,7 +43,7 @@ impl DerivePartialModel { return Err(Error::NotSupportGeneric(input.generics.params.span())); } - let syn::Data::Struct(syn::DataStruct{fields:syn::Fields::Named(syn::FieldsNamed{named:fields,..}),..},..)= input.data else{ + let syn::Data::Struct(syn::DataStruct{fields:syn::Fields::Named(syn::FieldsNamed{named:fields,..}),..},..) = input.data else{ return Err(Error::InputNotStruct); }; diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 993b4c00..85cd3742 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -2,8 +2,8 @@ pub use crate::{ error::*, sea_query::BlobSize, ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, ColumnTypeTrait, ConnectionTrait, CursorTrait, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, - LoaderTrait, ModelTrait, PaginatorTrait, PartialModelTrait, PrimaryKeyToColumn, - PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, + LoaderTrait, ModelTrait, PaginatorTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, + QueryResult, Related, RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "macros")] diff --git a/src/executor/cursor.rs b/src/executor/cursor.rs index b1c3bc26..c9101de6 100644 --- a/src/executor/cursor.rs +++ b/src/executor/cursor.rs @@ -1,6 +1,6 @@ use crate::{ - ConnectionTrait, DbErr, EntityTrait, FromQueryResult, Identity, IntoIdentity, QueryOrder, - Select, SelectModel, SelectorTrait, + ConnectionTrait, DbErr, EntityTrait, FromQueryResult, Identity, IntoIdentity, + PartialModelTrait, QueryOrder, QuerySelect, Select, SelectModel, SelectorTrait, }; use sea_query::{ Condition, DynIden, Expr, IntoValueTuple, Order, SeaRc, SelectStatement, SimpleExpr, Value, @@ -234,6 +234,14 @@ where } } + /// Return a [Selector] from `Self` that wraps a [SelectModel] with a [PartialModel](PartialModelTrait) + pub fn into_partial_model(self) -> Cursor> + where + M: PartialModelTrait, + { + M::select_cols(QuerySelect::select_only(self)).into_model::() + } + /// Construct a [Cursor] that fetch JSON value #[cfg(feature = "with-json")] pub fn into_json(self) -> Cursor> { @@ -247,6 +255,17 @@ where } } +impl QuerySelect for Cursor +where + S: SelectorTrait, +{ + type QueryStatement = SelectStatement; + + fn query(&mut self) -> &mut SelectStatement { + &mut self.query + } +} + impl QueryOrder for Cursor where S: SelectorTrait, diff --git a/src/executor/select.rs b/src/executor/select.rs index 462f1601..2dc6168c 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -159,7 +159,7 @@ where where M: PartialModelTrait, { - M::select_cols(crate::QuerySelect::select_only(self)).into_model::() + M::select_cols(QuerySelect::select_only(self)).into_model::() } /// Get a selectable Model as a [JsonValue] for SQL JSON operations @@ -405,6 +405,18 @@ where { self.into_model().stream(db).await } + + /// Stream the result of the operation with PartialModel + pub async fn stream_partial_model<'a: 'b, 'b, C, M>( + self, + db: &'a C, + ) -> Result> + 'b + Send, DbErr> + where + C: ConnectionTrait + StreamTrait + Send, + M: PartialModelTrait + Send + 'b, + { + self.into_partial_model().stream(db).await + } } impl SelectTwo @@ -430,7 +442,7 @@ where M: PartialModelTrait, N: PartialModelTrait, { - let select = crate::QuerySelect::select_only(self); + let select = QuerySelect::select_only(self); let select = M::select_cols(select); let select = N::select_cols(select); select.into_model::() @@ -471,6 +483,19 @@ where { self.into_model().stream(db).await } + + /// Stream the result of the operation with PartialModel + pub async fn stream_partial_model<'a: 'b, 'b, C, M, N>( + self, + db: &'a C, + ) -> Result), DbErr>> + 'b + Send, DbErr> + where + C: ConnectionTrait + StreamTrait + Send, + M: PartialModelTrait + Send + 'b, + N: PartialModelTrait + Send + 'b, + { + self.into_partial_model().stream(db).await + } } impl SelectTwoMany @@ -489,6 +514,7 @@ where selector: SelectTwoModel { model: PhantomData }, } } + /// Performs a conversion to [Selector] with partial model fn into_partial_model(self) -> Selector> where diff --git a/tests/collection_tests.rs b/tests/collection_tests.rs index 84811437..2447bee9 100644 --- a/tests/collection_tests.rs +++ b/tests/collection_tests.rs @@ -2,7 +2,9 @@ pub mod common; pub use common::{features::*, setup::*, TestContext}; use pretty_assertions::assert_eq; -use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; +use sea_orm::{ + entity::prelude::*, entity::*, DatabaseConnection, DerivePartialModel, FromQueryResult, +}; #[sea_orm_macros::test] #[cfg(all(feature = "sqlx-postgres", feature = "postgres-array"))] @@ -11,6 +13,7 @@ async fn main() -> Result<(), DbErr> { create_tables(&ctx.db).await?; insert_collection(&ctx.db).await?; update_collection(&ctx.db).await?; + select_collection(&ctx.db).await?; ctx.delete().await; Ok(()) @@ -149,3 +152,27 @@ pub async fn update_collection(db: &DatabaseConnection) -> Result<(), DbErr> { Ok(()) } + +pub async fn select_collection(db: &DatabaseConnection) -> Result<(), DbErr> { + use collection::*; + + #[derive(DerivePartialModel, FromQueryResult, Debug, PartialEq)] + #[sea_orm(entity = "Entity")] + struct PartialSelectResult { + name: String, + } + + let result = Entity::find_by_id(1) + .into_partial_model::() + .one(db) + .await?; + + assert_eq!( + result, + Some(PartialSelectResult { + name: "Collection 1".into(), + }) + ); + + Ok(()) +} diff --git a/tests/cursor_tests.rs b/tests/cursor_tests.rs index e52f8a4b..4d9ab91e 100644 --- a/tests/cursor_tests.rs +++ b/tests/cursor_tests.rs @@ -2,7 +2,7 @@ pub mod common; pub use common::{features::*, setup::*, TestContext}; use pretty_assertions::assert_eq; -use sea_orm::{entity::prelude::*, FromQueryResult}; +use sea_orm::{entity::prelude::*, DerivePartialModel, FromQueryResult}; use serde_json::json; #[sea_orm_macros::test] @@ -202,7 +202,7 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { // Fetch custom struct - #[derive(FromQueryResult, Debug, PartialEq)] + #[derive(FromQueryResult, Debug, PartialEq, Clone)] struct Row { id: i32, } @@ -233,5 +233,44 @@ pub async fn cursor_pagination(db: &DatabaseConnection) -> Result<(), DbErr> { [json!({ "id": 6 }), json!({ "id": 7 })] ); + #[derive(DerivePartialModel, FromQueryResult, Debug, PartialEq, Clone)] + #[sea_orm(entity = "Entity")] + struct PartialRow { + #[sea_orm(from_col = "id")] + id: i32, + #[sea_orm(from_expr = "sea_query::Expr::col(Column::Id).add(1000)")] + id_shifted: i32, + } + + let mut cursor = cursor.into_partial_model::(); + + assert_eq!( + cursor.first(2).all(db).await?, + [ + PartialRow { + id: 6, + id_shifted: 1006, + }, + PartialRow { + id: 7, + id_shifted: 1007, + } + ] + ); + + assert_eq!( + cursor.first(3).all(db).await?, + [ + PartialRow { + id: 6, + id_shifted: 1006, + }, + PartialRow { + id: 7, + id_shifted: 1007, + } + ] + ); + Ok(()) } diff --git a/tests/derive_tests.rs b/tests/derive_tests.rs index 9fca8a4f..20e49d06 100644 --- a/tests/derive_tests.rs +++ b/tests/derive_tests.rs @@ -1,100 +1,58 @@ -mod from_query_result { - use sea_orm::{FromQueryResult, TryGetable}; +use sea_orm::{FromQueryResult, TryGetable}; - #[derive(FromQueryResult)] - struct SimpleTest { - _foo: i32, - _bar: String, - } - - #[derive(FromQueryResult)] - struct GenericTest { - _foo: i32, - _bar: T, - } - - #[derive(FromQueryResult)] - struct DoubleGenericTest { - _foo: T, - _bar: F, - } - - #[derive(FromQueryResult)] - struct BoundsGenericTest { - _foo: T, - } - - #[derive(FromQueryResult)] - struct WhereGenericTest - where - T: Copy + Clone + 'static, - { - _foo: T, - } - - #[derive(FromQueryResult)] - struct AlreadySpecifiedBoundsGenericTest { - _foo: T, - } - - #[derive(FromQueryResult)] - struct MixedGenericTest - where - F: Copy + Clone + 'static, - { - _foo: T, - _bar: F, - } +#[derive(FromQueryResult)] +struct SimpleTest { + _foo: i32, + _bar: String, } -mod partial_model { - use entity::{Column, Entity}; - use sea_orm::{ColumnTrait, DerivePartialModel, FromQueryResult}; - use sea_query::Expr; - mod entity { - use sea_orm::{ - ActiveModelBehavior, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait, - EnumIter, PrimaryKeyTrait, - }; - - #[derive(Debug, Clone, DeriveEntityModel)] - #[sea_orm(table_name = "foo_table")] - pub struct Model { - #[sea_orm(primary_key)] - id: i32, - foo: i32, - bar: String, - foo2: bool, - bar2: f64, - } - - #[derive(Debug, DeriveRelation, EnumIter)] - pub enum Relation {} - - impl ActiveModelBehavior for ActiveModel {} - } - - #[derive(FromQueryResult, DerivePartialModel)] - #[sea_orm(entity = "Entity")] - struct SimpleTest { - _foo: i32, - _bar: String, - } - - #[derive(FromQueryResult, DerivePartialModel)] - #[sea_orm(entity = "Entity")] - struct FieldFromDiffNameColumnTest { - #[sea_orm(from_col = "foo2")] - _foo: i32, - #[sea_orm(from_col = "bar2")] - _bar: String, - } - - #[derive(FromQueryResult, DerivePartialModel)] - struct FieldFromExpr { - #[sea_orm(from_expr = "Column::Bar2.sum()")] - _foo: f64, - #[sea_orm(from_expr = "Expr::col(Column::Id).equals(Column::Foo)")] - _bar: bool, - } +#[derive(FromQueryResult)] +struct GenericTest { + _foo: i32, + _bar: T, +} + +#[derive(FromQueryResult)] +struct DoubleGenericTest { + _foo: T, + _bar: F, +} + +#[derive(FromQueryResult)] +struct BoundsGenericTest { + _foo: T, +} + +#[derive(FromQueryResult)] +struct WhereGenericTest +where + T: TryGetable + Copy + Clone + 'static, +{ + _foo: T, +} + +#[derive(FromQueryResult)] +struct AlreadySpecifiedBoundsGenericTest { + _foo: T, +} + +#[derive(FromQueryResult)] +struct MixedGenericTest +where + F: TryGetable + Copy + Clone + 'static, +{ + _foo: T, + _bar: F, +} + +trait MyTrait { + type Item: TryGetable; +} + +#[derive(FromQueryResult)] +struct TraitAssociateTypeTest +where + T: MyTrait, +{ + _foo: T::Item, } diff --git a/tests/partial_model_tests.rs b/tests/partial_model_tests.rs new file mode 100644 index 00000000..a9350045 --- /dev/null +++ b/tests/partial_model_tests.rs @@ -0,0 +1,47 @@ +use entity::{Column, Entity}; +use sea_orm::{ColumnTrait, DerivePartialModel, FromQueryResult}; +use sea_query::Expr; + +mod entity { + use sea_orm::prelude::*; + + #[derive(Debug, Clone, DeriveEntityModel)] + #[sea_orm(table_name = "foo_table")] + pub struct Model { + #[sea_orm(primary_key)] + id: i32, + foo: i32, + bar: String, + foo2: bool, + bar2: f64, + } + + #[derive(Debug, DeriveRelation, EnumIter)] + pub enum Relation {} + + impl ActiveModelBehavior for ActiveModel {} +} + +#[derive(FromQueryResult, DerivePartialModel)] +#[sea_orm(entity = "Entity")] +struct SimpleTest { + _foo: i32, + _bar: String, +} + +#[derive(FromQueryResult, DerivePartialModel)] +#[sea_orm(entity = "Entity")] +struct FieldFromDiffNameColumnTest { + #[sea_orm(from_col = "foo2")] + _foo: i32, + #[sea_orm(from_col = "bar2")] + _bar: String, +} + +#[derive(FromQueryResult, DerivePartialModel)] +struct FieldFromExpr { + #[sea_orm(from_expr = "Column::Bar2.sum()")] + _foo: f64, + #[sea_orm(from_expr = "Expr::col(Column::Id).equals(Column::Foo)")] + _bar: bool, +} diff --git a/tests/relational_tests.rs b/tests/relational_tests.rs index 29a1556c..49927534 100644 --- a/tests/relational_tests.rs +++ b/tests/relational_tests.rs @@ -4,7 +4,8 @@ pub use chrono::offset::Utc; pub use common::{bakery_chain::*, setup::*, TestContext}; pub use rust_decimal::prelude::*; pub use rust_decimal_macros::dec; -pub use sea_orm::{entity::*, query::*, DbErr, FromQueryResult}; +pub use sea_orm::{entity::*, query::*, DbErr, DerivePartialModel, FromQueryResult}; +pub use sea_query::{Alias, Expr, Func, SimpleExpr}; pub use uuid::Uuid; // Run the test locally: @@ -66,6 +67,7 @@ pub async fn left_join() { .filter(baker::Column::Name.contains("Baker 1")); let result = select + .clone() .into_model::() .one(&ctx.db) .await @@ -74,6 +76,28 @@ pub async fn left_join() { assert_eq!(result.name.as_str(), "Baker 1"); assert_eq!(result.bakery_name, Some("SeaSide Bakery".to_string())); + #[derive(DerivePartialModel, FromQueryResult, Debug, PartialEq)] + #[sea_orm(entity = "Baker")] + struct PartialSelectResult { + name: String, + #[sea_orm(from_expr = "Expr::col((bakery::Entity, bakery::Column::Name))")] + bakery_name: Option, + #[sea_orm( + from_expr = r#"SimpleExpr::FunctionCall(Func::upper(Expr::col((bakery::Entity, bakery::Column::Name))))"# + )] + bakery_name_upper: Option, + } + + let result = select + .into_partial_model::() + .one(&ctx.db) + .await + .unwrap() + .unwrap(); + assert_eq!(result.name.as_str(), "Baker 1"); + assert_eq!(result.bakery_name, Some("SeaSide Bakery".to_string())); + assert_eq!(result.bakery_name_upper, Some("SEASIDE BAKERY".to_string())); + let select = baker::Entity::find() .left_join(bakery::Entity) .select_only()