From 23498892b06fe9c3e0023a8d9ec0291a766b771f Mon Sep 17 00:00:00 2001 From: jasper Date: Tue, 9 Nov 2021 21:21:42 +0800 Subject: [PATCH] Add PaginatorTrait and CountTrait --- src/executor/paginator.rs | 86 +++++++++++++++++++++++++++++++++++++-- src/executor/select.rs | 58 +------------------------- tests/crud/updates.rs | 2 +- 3 files changed, 86 insertions(+), 60 deletions(-) diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index fd8e822c..727f9076 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{error::*, ConnectionTrait, DbBackend, SelectorTrait}; +use crate::{ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel, SelectTwo, SelectTwoModel, Selector, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -95,7 +95,7 @@ where /// /// ```rust /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, MockDatabase, DbBackend}; + /// # use sea_orm::{error::*, MockDatabase, DbBackend, PaginatorTrait}; /// # let owned_db = MockDatabase::new(DbBackend::Postgres).into_connection(); /// # let db = &owned_db; /// # let _: Result<(), DbErr> = smol::block_on(async { @@ -123,7 +123,7 @@ where /// /// ```rust /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, MockDatabase, DbBackend}; + /// # use sea_orm::{error::*, MockDatabase, DbBackend, PaginatorTrait}; /// # let owned_db = MockDatabase::new(DbBackend::Postgres).into_connection(); /// # let db = &owned_db; /// # let _: Result<(), DbErr> = smol::block_on(async { @@ -155,12 +155,92 @@ where } } +/// Used to enforce constraints on any type that wants to paginate results +pub trait PaginatorTrait<'db, C> +where + C: ConnectionTrait<'db>, +{ + /// Select operation + type Selector: SelectorTrait + 'db; + + /// Paginate the result of a select operation. + fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, Self::Selector>; +} + +impl<'db, C, S> PaginatorTrait<'db, C> for Selector +where + C: ConnectionTrait<'db>, + S: SelectorTrait + 'db, +{ + type Selector = S; + + fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, S> { + Paginator { + query: self.query, + page: 0, + page_size, + db, + selector: PhantomData, + } + } +} + +impl<'db, C, M, E> PaginatorTrait<'db, C> for Select +where + C: ConnectionTrait<'db>, + E: EntityTrait, + M: FromQueryResult + Sized + 'db, +{ + type Selector = SelectModel; + + fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, Self::Selector> { + self.into_model().paginate(db, page_size) + } +} + +impl<'db, C, M, N, E, F> PaginatorTrait<'db, C> for SelectTwo +where + C: ConnectionTrait<'db>, + E: EntityTrait, + F: EntityTrait, + M: FromQueryResult + Sized + 'db, + N: FromQueryResult + Sized + 'db, +{ + type Selector = SelectTwoModel; + + fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, Self::Selector> { + self.into_model().paginate(db, page_size) + } +} + +/// Used to enforce constraints on any type that wants to count results using pagination. +#[async_trait::async_trait] +pub trait CountTrait<'db, C>: PaginatorTrait<'db, C> +where + C: ConnectionTrait<'db>, +{ + /// Perform a count on the paginated results + async fn count(self, db: &'db C) -> Result; +} + +#[async_trait::async_trait] +impl<'db, C, P, S> CountTrait<'db, C> for P +where + C: ConnectionTrait<'db>, + P: PaginatorTrait<'db, C, Selector = S> + Send, + S: SelectorTrait + Send + Sync + 'db +{ + async fn count(self, db:&'db C) -> Result { + self.paginate(db, 1).num_items().await + } +} #[cfg(test)] #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; use crate::{tests_cfg::*, ConnectionTrait}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; + use super::*; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/select.rs b/src/executor/select.rs index f9bbd756..2dd9e3c3 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,6 +1,6 @@ use crate::{ error::*, ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, ModelTrait, - Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, + PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, TryGetableMany, }; use futures::{Stream, TryStreamExt}; @@ -17,7 +17,7 @@ pub struct Selector where S: SelectorTrait, { - query: SelectStatement, + pub(crate) query: SelectStatement, selector: S, } @@ -276,26 +276,6 @@ where { self.into_model().stream(db).await } - - /// Paginate the results of a SELECT operation on a Model - pub fn paginate<'a, C>( - self, - db: &'a C, - page_size: usize, - ) -> Paginator<'a, C, SelectModel> - where - C: ConnectionTrait<'a>, - { - self.into_model().paginate(db, page_size) - } - - /// Perform a `COUNT` operation on a items on a Model using pagination - pub async fn count<'a, C>(self, db: &'a C) -> Result - where - C: ConnectionTrait<'a>, - { - self.paginate(db, 1).num_items().await - } } impl SelectTwo @@ -350,26 +330,6 @@ where { self.into_model().stream(db).await } - - /// Paginate the results of a select operation on two models - pub fn paginate<'a, C>( - self, - db: &'a C, - page_size: usize, - ) -> Paginator<'a, C, SelectTwoModel> - where - C: ConnectionTrait<'a>, - { - self.into_model().paginate(db, page_size) - } - - /// Perform a count on the paginated results - pub async fn count<'a, C>(self, db: &'a C) -> Result - where - C: ConnectionTrait<'a>, - { - self.paginate(db, 1).num_items().await - } } impl SelectTwoMany @@ -499,20 +459,6 @@ where futures::future::ready(S::from_raw_query_result(row)) }))) } - - /// Paginate the result of a select operation on a Model - pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> - where - C: ConnectionTrait<'a>, - { - Paginator { - query: self.query, - page: 0, - page_size, - db, - selector: PhantomData, - } - } } impl SelectorRaw diff --git a/tests/crud/updates.rs b/tests/crud/updates.rs index 262031ef..04aea292 100644 --- a/tests/crud/updates.rs +++ b/tests/crud/updates.rs @@ -1,6 +1,6 @@ pub use super::*; use rust_decimal_macros::dec; -use sea_orm::DbErr; +use sea_orm::{DbErr, CountTrait}; use uuid::Uuid; pub async fn test_update_cake(db: &DbConn) {