Merge pull request #471 from sebpuetz/send-streams

Add Send bound to streams.
This commit is contained in:
Chris Tsang 2022-03-06 22:11:12 +08:00 committed by GitHub
commit ccd0d97eb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 177 additions and 65 deletions

View File

@ -0,0 +1,3 @@
HOST=127.0.0.1
PORT=8000
DATABASE_URL="postgres://postgres:password@localhost/axum_exmaple"

View File

@ -0,0 +1,23 @@
[package]
name = "sea-orm-axum-example"
version = "0.1.0"
authors = ["Sebastian Pütz <seb.puetz@gmail.com>"]
edition = "2021"
publish = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[workspace]
[dependencies]
tokio = { version = "1.14", features = ["full"] }
anyhow = "1"
dotenv = "0.15"
futures-util = "0.3"
serde = "1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dependencies.sea-orm]
path = "../../" # remove this line in your own project
# version = "^0.5.0"
features = ["macros", "mock", "sqlx-all", "runtime-tokio-rustls", "debug-print"]
default-features = false

View File

@ -0,0 +1 @@
Demonstrator for using streaming queries with `tokio::spawn` or in contexts that require `Send` futures.

View File

@ -0,0 +1,29 @@
mod post;
mod setup;
use futures_util::StreamExt;
use post::Entity as Post;
use sea_orm::{prelude::*, Database};
use std::env;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env::set_var("RUST_LOG", "debug");
tracing_subscriber::fmt::init();
dotenv::dotenv().ok();
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file");
let db = Database::connect(db_url)
.await
.expect("Database connection failed");
let _ = setup::create_post_table(&db);
tokio::task::spawn(async move {
let mut stream = Post::find().stream(&db).await.unwrap();
while let Some(item) = stream.next().await {
let item = item?;
println!("got something: {}", item.text);
}
Ok::<(), anyhow::Error>(())
})
.await?
}

View File

@ -0,0 +1,26 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.3.2
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "posts")]
pub struct Model {
#[sea_orm(primary_key)]
#[serde(skip_deserializing)]
pub id: i32,
pub title: String,
#[sea_orm(column_type = "Text")]
pub text: String,
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
panic!("No RelationDef")
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,33 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult};
async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend();
db.execute(builder.build(stmt)).await
}
pub async fn create_post_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let stmt = sea_query::Table::create()
.table(super::post::Entity)
.if_not_exists()
.col(
ColumnDef::new(super::post::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(super::post::Column::Title)
.string()
.not_null(),
)
.col(
ColumnDef::new(super::post::Column::Text)
.string()
.not_null(),
)
.to_owned();
create_table(db, &stmt).await
}

View File

@ -9,7 +9,7 @@ use std::{future::Future, pin::Pin};
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait ConnectionTrait<'a>: Sync { pub trait ConnectionTrait<'a>: Sync {
/// Create a stream for the [QueryResult] /// Create a stream for the [QueryResult]
type Stream: Stream<Item = Result<QueryResult, DbErr>>; type Stream: Stream<Item = Result<QueryResult, DbErr>> + Send;
/// Fetch the database backend as specified in [DbBackend]. /// Fetch the database backend as specified in [DbBackend].
/// This depends on feature flags enabled. /// This depends on feature flags enabled.
@ -28,7 +28,7 @@ pub trait ConnectionTrait<'a>: Sync {
fn stream( fn stream(
&'a self, &'a self,
stmt: Statement, stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>>; ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>>;
/// Execute SQL `BEGIN` transaction. /// Execute SQL `BEGIN` transaction.
/// Returns a Transaction that can be committed or rolled back /// Returns a Transaction that can be committed or rolled back

View File

@ -155,7 +155,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn stream( fn stream(
&'a self, &'a self,
stmt: Statement, stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> { ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>> {
Box::pin(async move { Box::pin(async move {
Ok(match self { Ok(match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]

View File

@ -24,7 +24,7 @@ pub struct QueryStream {
metric_callback: Option<crate::metric::Callback>, metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)] #[borrows(mut conn, stmt, metric_callback)]
#[not_covariant] #[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>, stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send + 'this>>,
} }
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]

View File

@ -24,7 +24,7 @@ pub struct TransactionStream<'a> {
metric_callback: Option<crate::metric::Callback>, metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)] #[borrows(mut conn, stmt, metric_callback)]
#[not_covariant] #[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>, stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this + Send>>,
} }
impl<'a> std::fmt::Debug for TransactionStream<'a> { impl<'a> std::fmt::Debug for TransactionStream<'a> {
@ -35,62 +35,57 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> {
impl<'a> TransactionStream<'a> { impl<'a> TransactionStream<'a> {
#[instrument(level = "trace", skip(metric_callback))] #[instrument(level = "trace", skip(metric_callback))]
pub(crate) async fn build( pub(crate) fn build(
conn: MutexGuard<'a, InnerConnection>, conn: MutexGuard<'a, InnerConnection>,
stmt: Statement, stmt: Statement,
metric_callback: Option<crate::metric::Callback>, metric_callback: Option<crate::metric::Callback>,
) -> TransactionStream<'a> { ) -> TransactionStream<'a> {
TransactionStreamAsyncBuilder { TransactionStreamBuilder {
stmt, stmt,
conn, conn,
metric_callback, metric_callback,
stream_builder: |conn, stmt, _metric_callback| { stream_builder: |conn, stmt, _metric_callback| match conn.deref_mut() {
Box::pin(async move { #[cfg(feature = "sqlx-mysql")]
match conn.deref_mut() { InnerConnection::MySql(c) => {
#[cfg(feature = "sqlx-mysql")] let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
InnerConnection::MySql(c) => { crate::metric::metric_ok!(_metric_callback, stmt, {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt); Box::pin(
crate::metric::metric_ok!(_metric_callback, stmt, { c.fetch(query)
Box::pin( .map_ok(Into::into)
c.fetch(query) .map_err(crate::sqlx_error_to_query_err),
.map_ok(Into::into) )
.map_err(crate::sqlx_error_to_query_err), as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
) })
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> }
}) #[cfg(feature = "sqlx-postgres")]
} InnerConnection::Postgres(c) => {
#[cfg(feature = "sqlx-postgres")] let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
InnerConnection::Postgres(c) => { crate::metric::metric_ok!(_metric_callback, stmt, {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt); Box::pin(
crate::metric::metric_ok!(_metric_callback, stmt, { c.fetch(query)
Box::pin( .map_ok(Into::into)
c.fetch(query) .map_err(crate::sqlx_error_to_query_err),
.map_ok(Into::into) )
.map_err(crate::sqlx_error_to_query_err), as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
) })
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> }
}) #[cfg(feature = "sqlx-sqlite")]
} InnerConnection::Sqlite(c) => {
#[cfg(feature = "sqlx-sqlite")] let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
InnerConnection::Sqlite(c) => { crate::metric::metric_ok!(_metric_callback, stmt, {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); Box::pin(
crate::metric::metric_ok!(_metric_callback, stmt, { c.fetch(query)
Box::pin( .map_ok(Into::into)
c.fetch(query) .map_err(crate::sqlx_error_to_query_err),
.map_ok(Into::into) )
.map_err(crate::sqlx_error_to_query_err), as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
) })
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> }
}) #[cfg(feature = "mock")]
} InnerConnection::Mock(c) => c.fetch(stmt),
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
}
})
}, },
} }
.build() .build()
.await
} }
} }

View File

@ -354,14 +354,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
fn stream( fn stream(
&'a self, &'a self,
stmt: Statement, stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> { ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>> {
Box::pin(async move { Box::pin(async move {
let conn = self.conn.lock().await;
Ok(crate::TransactionStream::build( Ok(crate::TransactionStream::build(
self.conn.lock().await, conn,
stmt, stmt,
self.metric_callback.clone(), self.metric_callback.clone(),
) ))
.await)
}) })
} }

View File

@ -148,7 +148,7 @@ impl MockDatabaseConnection {
pub fn fetch( pub fn fetch(
&self, &self,
statement: &Statement, statement: &Statement,
) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> { ) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
match self.query_all(statement.clone()) { match self.query_all(statement.clone()) {
Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))), Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))),
Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())),

View File

@ -273,9 +273,9 @@ where
pub async fn stream<'a: 'b, 'b, C>( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &'a C, db: &'a C,
) -> Result<impl Stream<Item = Result<E::Model, DbErr>> + 'b, DbErr> ) -> Result<impl Stream<Item = Result<E::Model, DbErr>> + 'b + Send, DbErr>
where where
C: ConnectionTrait<'a>, C: ConnectionTrait<'a> + Send,
{ {
self.into_model().stream(db).await self.into_model().stream(db).await
} }
@ -329,7 +329,7 @@ where
db: &'a C, db: &'a C,
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr> ) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where where
C: ConnectionTrait<'a>, C: ConnectionTrait<'a> + Send,
{ {
self.into_model().stream(db).await self.into_model().stream(db).await
} }
@ -373,9 +373,9 @@ where
pub async fn stream<'a: 'b, 'b, C>( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &'a C, db: &'a C,
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr> ) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b + Send, DbErr>
where where
C: ConnectionTrait<'a>, C: ConnectionTrait<'a> + Send,
{ {
self.into_model().stream(db).await self.into_model().stream(db).await
} }
@ -452,10 +452,11 @@ where
pub async fn stream<'a: 'b, 'b, C>( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &'a C, db: &'a C,
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b>>, DbErr> ) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b + Send>>, DbErr>
where where
C: ConnectionTrait<'a>, C: ConnectionTrait<'a> + Send,
S: 'b, S: 'b,
S::Item: Send,
{ {
self.into_selector_raw(db).stream(db).await self.into_selector_raw(db).stream(db).await
} }
@ -737,10 +738,11 @@ where
pub async fn stream<'a: 'b, 'b, C>( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &'a C, db: &'a C,
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b>>, DbErr> ) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b + Send>>, DbErr>
where where
C: ConnectionTrait<'a>, C: ConnectionTrait<'a> + Send,
S: 'b, S: 'b,
S::Item: Send,
{ {
let stream = db.stream(self.stmt).await?; let stream = db.stream(self.stmt).await?;
Ok(Box::pin(stream.and_then(|row| { Ok(Box::pin(stream.and_then(|row| {