diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index c514588d..874b50c2 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -18,6 +18,8 @@ where pub(crate) selector: PhantomData, } +// LINT: warn if paginator is used without an order by clause + impl<'db, S> Paginator<'db, S> where S: SelectorTrait + 'db, @@ -46,12 +48,12 @@ where self.fetch_page(self.page).await } - /// Get the total number of pages - pub async fn num_pages(&self) -> Result { + /// Get the total number of items + pub async fn num_items(&self) -> Result { let builder = self.db.get_database_backend(); let stmt = builder.build( SelectStatement::new() - .expr(Expr::cust("COUNT(*) AS num_rows")) + .expr(Expr::cust("COUNT(*) AS num_items")) .from_subquery( self.query.clone().reset_limit().reset_offset().to_owned(), Alias::new("sub_query"), @@ -61,8 +63,14 @@ where Some(res) => res, None => return Ok(0), }; - let num_rows = result.try_get::("", "num_rows")? as usize; - let num_pages = (num_rows / self.page_size) + (num_rows % self.page_size > 0) as usize; + let num_items = result.try_get::("", "num_items")? as usize; + Ok(num_items) + } + + /// Get the total number of pages + pub async fn num_pages(&self) -> Result { + let num_items = self.num_items().await?; + let num_pages = (num_items / self.page_size) + (num_items % self.page_size > 0) as usize; Ok(num_pages) } @@ -136,15 +144,15 @@ mod tests { (db, vec![page1, page2, page3]) } - fn setup_num_rows() -> (DatabaseConnection, i32) { - let num_rows = 3; + fn setup_num_items() -> (DatabaseConnection, i32) { + let num_items = 3; let db = MockDatabase::new(DatabaseBackend::Postgres) .append_query_results(vec![vec![maplit::btreemap! { - "num_rows" => Into::::into(num_rows), + "num_items" => Into::::into(num_items), }]]) .into_connection(); - (db, num_rows) + (db, num_items) } #[async_std::test] @@ -213,11 +221,11 @@ mod tests { #[async_std::test] async fn num_pages() -> Result<(), DbErr> { - let (db, num_rows) = setup_num_rows(); + let (db, num_items) = setup_num_items(); - let num_rows = num_rows as usize; + let num_items = num_items as usize; let page_size = 2_usize; - let num_pages = (num_rows / page_size) + (num_rows % page_size > 0) as usize; + let num_pages = (num_items / page_size) + (num_items % page_size > 0) as usize; let paginator = fruit::Entity::find().paginate(&db, page_size); assert_eq!(paginator.num_pages().await?, num_pages); @@ -232,7 +240,7 @@ mod tests { .to_owned(); let select = SelectStatement::new() - .expr(Expr::cust("COUNT(*) AS num_rows")) + .expr(Expr::cust("COUNT(*) AS num_items")) .from_subquery(sub_query, Alias::new("sub_query")) .to_owned(); diff --git a/src/executor/select.rs b/src/executor/select.rs index 9d8ad4a6..77acd31c 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -116,6 +116,10 @@ where ) -> Paginator<'_, SelectModel> { self.into_model().paginate(db, page_size) } + + pub async fn count(self, db: &DatabaseConnection) -> Result { + self.paginate(db, 1).num_items().await + } } impl SelectTwo @@ -155,6 +159,18 @@ where ) -> Result)>, DbErr> { self.into_model().all(db).await } + + pub fn paginate( + self, + db: &DatabaseConnection, + page_size: usize, + ) -> Paginator<'_, SelectTwoModel> { + self.into_model().paginate(db, page_size) + } + + pub async fn count(self, db: &DatabaseConnection) -> Result { + self.paginate(db, 1).num_items().await + } } impl SelectTwoMany @@ -195,6 +211,15 @@ where let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } + + // pub fn paginate() + // we could not implement paginate easily, if the number of children for a + // parent is larger than one page, then we will end up splitting it in two pages + // so the correct way is actually perform query in two stages + // paginate the parent model and then populate the children + + // pub fn count() + // we should only count the number of items of the parent model } impl Selector diff --git a/tests/basic.rs b/tests/basic.rs index 77c46ebc..3faaa346 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -59,6 +59,12 @@ async fn crud_cake(db: &DbConn) -> Result<(), DbErr> { println!(); println!("Updated: {:?}", apple); + let count = cake::Entity::find().count(db).await?; + + println!(); + println!("Count: {:?}", count); + assert_eq!(count, 1); + let apple = cake::Entity::find_by_id(1).one(db).await?; assert_eq!( @@ -80,5 +86,11 @@ async fn crud_cake(db: &DbConn) -> Result<(), DbErr> { assert_eq!(None, apple); + let count = cake::Entity::find().count(db).await?; + + println!(); + println!("Count: {:?}", count); + assert_eq!(count, 0); + Ok(()) }