From 937af050037add74abca9694adfbf1af82a7805d Mon Sep 17 00:00:00 2001 From: Billy Chan <30400950+billy1624@users.noreply.github.com> Date: Sun, 6 Jun 2021 22:30:54 +0800 Subject: [PATCH] Paginator API (#10) * Paginator * Remove unnecessary dependency, rename to num_pages * Hotfix - nullable json value * num_pages implemented on sub_query --- Cargo.toml | 2 + examples/sqlx-mysql/Cargo.toml | 5 +- examples/sqlx-mysql/src/select.rs | 75 ++++++++++++++++++++++++++ src/connector/mod.rs | 2 + src/connector/paginator.rs | 87 +++++++++++++++++++++++++++++++ src/connector/select.rs | 19 +++++-- src/query/json.rs | 2 +- 7 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 src/connector/paginator.rs diff --git a/Cargo.toml b/Cargo.toml index a333ddf5..b8da7967 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,8 @@ serde = { version = "^1.0", features = [ "derive" ] } sqlx = { version = "^0.5", optional = true } strum = { version = "^0.20", features = [ "derive" ] } serde_json = { version = "^1", optional = true } +async-stream = { version = "^0.3" } +futures-util = { version = "^0.3" } [features] debug-print = [] diff --git a/examples/sqlx-mysql/Cargo.toml b/examples/sqlx-mysql/Cargo.toml index 3bb96883..44f398b1 100644 --- a/examples/sqlx-mysql/Cargo.toml +++ b/examples/sqlx-mysql/Cargo.toml @@ -9,4 +9,7 @@ async-std = { version = "^1.9", features = [ "attributes" ] } sea-orm = { path = "../../", features = [ "sqlx-mysql", "runtime-async-std-native-tls", "debug-print" ] } # sea-query = { path = "../../../sea-query" } strum = { version = "^0.20", features = [ "derive" ] } -serde_json = { version = "^1" } \ No newline at end of file +serde_json = { version = "^1" } +futures = { version = "^0.3" } +async-stream = { version = "^0.3" } +futures-util = { version = "^0.3" } \ No newline at end of file diff --git a/examples/sqlx-mysql/src/select.rs b/examples/sqlx-mysql/src/select.rs index 627b266f..ab3c87c8 100644 --- a/examples/sqlx-mysql/src/select.rs +++ b/examples/sqlx-mysql/src/select.rs @@ -26,6 +26,18 @@ pub async fn all_about_select(db: &Database) -> Result<(), QueryErr> { all_about_select_json(db).await?; } + println!("===== =====\n"); + + find_all_stream(&db).await.unwrap(); + + println!("===== =====\n"); + + find_first_page(&db).await.unwrap(); + + println!("===== =====\n"); + + find_num_pages(&db).await.unwrap(); + Ok(()) } @@ -232,3 +244,66 @@ async fn count_fruits_by_cake_json(db: &Database) -> Result<(), QueryErr> { Ok(()) } + +async fn find_all_stream(db: &Database) -> Result<(), QueryErr> { + use futures::TryStreamExt; + use std::time::Duration; + use async_std::task::sleep; + + println!("find all cakes: "); + let mut cake_paginator = cake::Entity::find().paginate(db, 2); + while let Some(cake_res) = cake_paginator.fetch_and_next().await? { + for cake in cake_res { + println!("{:?}", cake); + } + } + + println!(); + println!("find all fruits: "); + let mut fruit_paginator = fruit::Entity::find().paginate(db, 2); + while let Some(fruit_res) = fruit_paginator.fetch_and_next().await? { + for fruit in fruit_res { + println!("{:?}", fruit); + } + } + + println!(); + println!("find all fruits with stream: "); + let mut fruit_stream = fruit::Entity::find().paginate(db, 2).into_stream(); + while let Some(fruits) = fruit_stream.try_next().await? { + for fruit in fruits { + println!("{:?}", fruit); + } + sleep(Duration::from_millis(250)).await; + } + + println!(); + println!("find all fruits in json with stream: "); + let mut json_stream = fruit::Entity::find().into_json().paginate(db, 2).into_stream(); + while let Some(jsons) = json_stream.try_next().await? { + for json in jsons { + println!("{:?}", json); + } + sleep(Duration::from_millis(250)).await; + } + + Ok(()) +} + +async fn find_first_page(db: &Database) -> Result<(), QueryErr> { + println!("fruits first page: "); + let page = fruit::Entity::find().paginate(db, 2).fetch_page(0).await?; + for fruit in page { + println!("{:?}", fruit); + } + + Ok(()) +} + +async fn find_num_pages(db: &Database) -> Result<(), QueryErr> { + println!("fruits number of page: "); + let num_pages = fruit::Entity::find().paginate(db, 2).num_pages().await?; + println!("{:?}", num_pages); + + Ok(()) +} diff --git a/src/connector/mod.rs b/src/connector/mod.rs index 906f724a..8ae914c2 100644 --- a/src/connector/mod.rs +++ b/src/connector/mod.rs @@ -1,10 +1,12 @@ mod executor; mod insert; +mod paginator; mod select; mod update; pub use executor::*; pub use insert::*; +pub use paginator::*; pub use select::*; pub use update::*; diff --git a/src/connector/paginator.rs b/src/connector/paginator.rs new file mode 100644 index 00000000..0e85ebab --- /dev/null +++ b/src/connector/paginator.rs @@ -0,0 +1,87 @@ +use crate::{Connection, Database, QueryErr, SelectorTrait}; +use futures::Stream; +use async_stream::stream; +use std::{marker::PhantomData, pin::Pin}; +use sea_query::{Alias, Expr, SelectStatement}; + +pub type PinBoxStream<'db, Item> = Pin + 'db>>; + +#[derive(Clone, Debug)] +pub struct Paginator<'db, S> +where + S: SelectorTrait + 'db, +{ + pub(crate) query: SelectStatement, + pub(crate) page: usize, + pub(crate) page_size: usize, + pub(crate) db: &'db Database, + pub(crate) selector: PhantomData, +} + +impl<'db, S> Paginator<'db, S> +where + S: SelectorTrait + 'db, +{ + pub async fn fetch_page(&mut self, page: usize) -> Result, QueryErr> { + self.query.limit(self.page_size as u64).offset((self.page_size * page) as u64); + let builder = self.db.get_query_builder_backend(); + let stmt = self.query.build(builder).into(); + let rows = self.db.get_connection().query_all(stmt).await?; + let mut buffer = Vec::with_capacity(rows.len()); + for row in rows.into_iter() { + // TODO: Error handling + buffer.push(S::from_raw_query_result(row).map_err(|_e| QueryErr)?); + } + Ok(buffer) + } + + pub async fn fetch(&mut self) -> Result, QueryErr> { + self.fetch_page(self.page).await + } + + pub async fn num_pages(&mut self) -> Result { + let builder = self.db.get_query_builder_backend(); + let stmt = SelectStatement::new() + .expr(Expr::cust("COUNT(*) AS num_rows")) + .from_subquery( + self.query.clone().reset_limit().reset_offset().to_owned(), + Alias::new("sub_query") + ) + .build(builder) + .into(); + let result = match self.db.get_connection().query_one(stmt).await? { + Some(res) => res, + None => return Ok(0), + }; + let num_rows = result.try_get::("", "num_rows").map_err(|_e| QueryErr)? as usize; + let num_pages = (num_rows / self.page_size) + (num_rows % self.page_size > 0) as usize; + Ok(num_pages) + } + + pub fn next(&mut self) { + self.page += 1; + } + + pub async fn fetch_and_next(&mut self) -> Result>, QueryErr> { + let vec = self.fetch().await?; + self.next(); + let opt = if !vec.is_empty() { + Some(vec) + } else { + None + }; + Ok(opt) + } + + pub fn into_stream(mut self) -> PinBoxStream<'db, Result, QueryErr>> { + Box::pin(stream! { + loop { + if let Some(vec) = self.fetch_and_next().await? { + yield Ok(vec); + } else { + break + } + } + }) + } +} diff --git a/src/connector/select.rs b/src/connector/select.rs index c3831c41..c272af7b 100644 --- a/src/connector/select.rs +++ b/src/connector/select.rs @@ -1,7 +1,4 @@ -use crate::{ - query::combine, Connection, Database, EntityTrait, FromQueryResult, JsonValue, QueryErr, - QueryResult, Select, SelectTwo, Statement, TypeErr, -}; +use crate::{Connection, Database, EntityTrait, FromQueryResult, JsonValue, Paginator, QueryErr, QueryResult, Select, SelectTwo, Statement, TypeErr, query::combine}; use sea_query::{QueryBuilder, SelectStatement}; use std::marker::PhantomData; @@ -91,6 +88,10 @@ where pub async fn all(self, db: &Database) -> Result, QueryErr> { self.into_model::().all(db).await } + + pub fn paginate<'db>(self, db: &'db Database, page_size: usize) -> Paginator<'db, SelectModel> { + self.into_model::().paginate(db, page_size) + } } impl SelectTwo @@ -156,4 +157,14 @@ where } Ok(models) } + + pub fn paginate<'db>(self, db: &'db Database, page_size: usize) -> Paginator<'db, S> { + Paginator { + query: self.query, + page: 0, + page_size, + db, + selector: PhantomData, + } + } } diff --git a/src/query/json.rs b/src/query/json.rs index a6699745..8227ae76 100644 --- a/src/query/json.rs +++ b/src/query/json.rs @@ -18,7 +18,7 @@ impl FromQueryResult for JsonValue { macro_rules! match_mysql_type { ( $type: ty ) => { if <$type as Type>::type_info().eq(col_type) { - map.insert(col.to_owned(), json!(res.try_get::<$type>(pre, &col)?)); + map.insert(col.to_owned(), json!(res.try_get::>(pre, &col)?)); continue; } };