diff --git a/Cargo.toml b/Cargo.toml index a4466152..b1fa90c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +ouroboros = "0.11" [dev-dependencies] smol = { version = "^1.2" } diff --git a/examples/rocket_example/src/main.rs b/examples/rocket_example/src/main.rs index 853eaaaa..43f227c6 100644 --- a/examples/rocket_example/src/main.rs +++ b/examples/rocket_example/src/main.rs @@ -42,7 +42,7 @@ async fn create(conn: Connection, post_form: Form) -> Flash, post_form: Form) -> Flash", data = "")] async fn update(conn: Connection, id: i32, post_form: Form) -> Flash { let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .unwrap() .unwrap() @@ -65,7 +65,7 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> title: Set(form.title.to_owned()), text: Set(form.text.to_owned()), } - .save(&conn) + .save(&*conn) .await .expect("could not edit post"); @@ -89,7 +89,7 @@ async fn list( // Setup paginator let paginator = Post::find() .order_by_asc(post::Column::Id) - .paginate(&conn, posts_per_page); + .paginate(&*conn, posts_per_page); let num_pages = paginator.num_pages().await.ok().unwrap(); // Fetch paginated posts @@ -113,7 +113,7 @@ async fn list( #[get("/")] async fn edit(conn: Connection, id: i32) -> Template { let post: Option = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .expect("could not find post"); @@ -128,20 +128,20 @@ async fn edit(conn: Connection, id: i32) -> Template { #[delete("/")] async fn delete(conn: Connection, id: i32) -> Flash { let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .unwrap() .unwrap() .into(); - post.delete(&conn).await.unwrap(); + post.delete(&*conn).await.unwrap(); Flash::success(Redirect::to("/"), "Post successfully deleted.") } #[delete("/")] async fn destroy(conn: Connection) -> Result<()> { - Post::delete_many().exec(&conn).await.unwrap(); + Post::delete_many().exec(&*conn).await.unwrap(); Ok(()) } diff --git a/examples/rocket_example/src/setup.rs b/examples/rocket_example/src/setup.rs index 034e8b53..91bbb7b1 100644 --- a/examples/rocket_example/src/setup.rs +++ b/examples/rocket_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{query::*, error::*, sea_query, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/src/database/connection.rs b/src/database/connection.rs index e5ec4e2c..630b8b96 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,4 +1,5 @@ -use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder}; +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; #[cfg_attr(not(feature = "mock"), derive(Clone))] @@ -10,7 +11,7 @@ pub enum DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] - MockDatabaseConnection(crate::MockDatabaseConnection), + MockDatabaseConnection(Arc), Disconnected, } @@ -51,8 +52,11 @@ impl std::fmt::Debug for DatabaseConnection { } } -impl DatabaseConnection { - pub fn get_database_backend(&self) -> DbBackend { +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; + + fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, @@ -66,7 +70,7 @@ impl DatabaseConnection { } } - pub async fn execute(&self, stmt: Statement) -> Result { + async fn execute(&self, stmt: Statement) -> Result { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, @@ -75,12 +79,12 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } - pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, @@ -89,12 +93,12 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } - pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, @@ -103,12 +107,76 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)), + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } + + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + }, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + #[cfg(feature = "mock")] + fn is_mock_connection(&self) -> bool { + match self { + DatabaseConnection::MockDatabaseConnection(_) => true, + _ => false, + } + } +} + +#[cfg(feature = "mock")] +impl DatabaseConnection { pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { match self { DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, @@ -116,12 +184,6 @@ impl DatabaseConnection { } } - #[cfg(not(feature = "mock"))] - pub fn as_mock_connection(&self) -> Option { - None - } - - #[cfg(feature = "mock")] pub fn into_transaction_log(self) -> Vec { let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); mocker.drain_transaction_log() diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs new file mode 100644 index 00000000..569f3896 --- /dev/null +++ b/src/database/db_connection.rs @@ -0,0 +1,45 @@ +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError}; +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use sqlx::pool::PoolConnection; + +pub(crate) enum InnerConnection { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), + #[cfg(feature = "mock")] + Mock(Arc), +} + +#[async_trait::async_trait] +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; + + fn get_database_backend(&self) -> DbBackend; + + async fn execute(&self, stmt: Statement) -> Result; + + async fn query_one(&self, stmt: Statement) -> Result, DbErr>; + + async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>>; + + async fn begin(&self) -> Result; + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send; + + fn is_mock_connection(&self) -> bool { + false + } +} diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs new file mode 100644 index 00000000..b403971e --- /dev/null +++ b/src/database/db_transaction.rs @@ -0,0 +1,308 @@ +use std::{sync::Arc, future::Future, pin::Pin}; +use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print}; +use futures::lock::Mutex; +#[cfg(feature = "sqlx-dep")] +use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, TransactionManager}; + +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, +} + +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { + #[cfg(feature = "sqlx-mysql")] + pub(crate) async fn new_mysql(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await + } + + #[cfg(feature = "sqlx-postgres")] + pub(crate) async fn new_postgres(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await + } + + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await + } + + #[cfg(feature = "mock")] + pub(crate) async fn new_mock(inner: Arc) -> Result { + let backend = inner.get_database_backend(); + Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + } + + async fn build(conn: Arc>, backend: DbBackend) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + // should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(res) + } + + pub(crate) async fn run(self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); + if res.is_ok() { + self.commit().await.map_err(|e| TransactionError::Connection(e))?; + } + else { + self.rollback().await.map_err(|e| TransactionError::Connection(e))?; + } + res + } + + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + else { + //this should never happen + panic!("Dropping a locked Transaction"); + } + } + } +} + +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + + fn get_database_backend(&self) -> DbBackend { + // this way we don't need to lock + self.backend + } + + async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_exec_err) + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + }; + #[cfg(feature = "sqlx-dep")] + if let Err(sqlx::Error::RowNotFound) = _res { + Ok(None) + } + else { + _res.map_err(sqlx_error_to_query_err) + } + } + + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_query_err) + } + + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) + }) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, + { + let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + } +} + +#[derive(Debug)] +pub enum TransactionError +where E: std::error::Error { + Connection(DbErr), + Transaction(E), +} + +impl std::fmt::Display for TransactionError +where E: std::error::Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), + TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for TransactionError +where E: std::error::Error {} diff --git a/src/database/mock.rs b/src/database/mock.rs index ccb34a49..d9e1f9d0 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -4,7 +4,7 @@ use crate::{ Statement, Transaction, }; use sea_query::{Value, ValueType}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { @@ -40,7 +40,7 @@ impl MockDatabase { } pub fn into_connection(self) -> DatabaseConnection { - DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) + DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self))) } pub fn append_exec_results(mut self, mut vec: Vec) -> Self { @@ -100,7 +100,8 @@ impl MockRow { where T: ValueType, { - Ok(self.values.get(col).unwrap().clone().unwrap()) + T::try_from(self.values.get(col).unwrap().clone()) + .map_err(|e| DbErr::Query(e.to_string())) } pub fn into_column_value_tuples(self) -> impl Iterator { diff --git a/src/database/mod.rs b/src/database/mod.rs index f61343c1..ce4127e3 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -3,12 +3,18 @@ mod connection; mod mock; mod statement; mod transaction; +mod db_connection; +mod db_transaction; +mod stream; pub use connection::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; pub use transaction::*; +pub use db_connection::*; +pub use db_transaction::*; +pub use stream::*; use crate::DbErr; diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs new file mode 100644 index 00000000..774cf45f --- /dev/null +++ b/src/database/stream/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod transaction; + +pub use query::*; +pub use transaction::*; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs new file mode 100644 index 00000000..553d9f7b --- /dev/null +++ b/src/database/stream/query.rs @@ -0,0 +1,108 @@ +use std::{pin::Pin, task::Poll, sync::Arc}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, Executor}; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct QueryStream { + stmt: Statement, + conn: InnerConnection, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +#[cfg(feature = "sqlx-mysql")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::MySql(conn)) + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Postgres(conn)) + } +} + +#[cfg(feature = "sqlx-sqlite")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Sqlite(conn)) + } +} + +#[cfg(feature = "mock")] +impl From<(Arc, Statement)> for QueryStream { + fn from((conn, stmt): (Arc, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn)) + } +} + +impl std::fmt::Debug for QueryStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QueryStream") + } +} + +impl QueryStream { + fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { + QueryStreamBuilder { + stmt, + conn, + stream_builder: |conn, stmt| { + match conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }, + }.build() + } +} + +impl Stream for QueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs new file mode 100644 index 00000000..d945f409 --- /dev/null +++ b/src/database/stream/transaction.rs @@ -0,0 +1,82 @@ +use std::{ops::DerefMut, pin::Pin, task::Poll}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::Executor; + +use futures::lock::MutexGuard; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct TransactionStream<'a> { + stmt: Statement, + conn: MutexGuard<'a, InnerConnection>, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +impl<'a> std::fmt::Debug for TransactionStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TransactionStream") + } +} + +impl<'a> TransactionStream<'a> { + pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> { + TransactionStreamAsyncBuilder { + stmt, + conn, + stream_builder: |conn, stmt| Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }), + }.build().await + } +} + +impl<'a> Stream for TransactionStream<'a> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 0e398586..388ac911 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,11 +2,11 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::fmt::Debug; -use std::sync::{ +use std::{fmt::Debug, pin::Pin, sync::{Arc, atomic::{AtomicUsize, Ordering}, Mutex, -}; +}}; +use futures::Stream; #[derive(Debug)] pub struct MockDatabaseConnector; @@ -50,7 +50,7 @@ impl MockDatabaseConnector { macro_rules! connect_mock_db { ( $syntax: expr ) => { Ok(DatabaseConnection::MockDatabaseConnection( - MockDatabaseConnection::new(MockDatabase::new($syntax)), + Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))), )) }; } @@ -86,25 +86,32 @@ impl MockDatabaseConnection { &self.mocker } - pub async fn execute(&self, statement: Statement) -> Result { + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().execute(counter, statement) } - pub async fn query_one(&self, statement: Statement) -> Result, DbErr> { + pub fn query_one(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); let result = self.mocker.lock().unwrap().query(counter, statement)?; Ok(result.into_iter().next()) } - pub async fn query_all(&self, statement: Statement) -> Result, DbErr> { + pub fn query_all(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().query(counter, statement) } + pub fn fetch(&self, statement: &Statement) -> Pin>>> { + match self.query_all(statement.clone()) { + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), + Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), + } + } + pub fn get_database_backend(&self) -> DbBackend { self.mocker.lock().unwrap().get_database_backend() } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 6f6cfb64..33b6c847 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,11 +3,11 @@ mod mock; #[cfg(feature = "sqlx-dep")] mod sqlx_common; #[cfg(feature = "sqlx-mysql")] -mod sqlx_mysql; +pub(crate) mod sqlx_mysql; #[cfg(feature = "sqlx-postgres")] -mod sqlx_postgres; +pub(crate) mod sqlx_postgres; #[cfg(feature = "sqlx-sqlite")] -mod sqlx_sqlite; +pub(crate) mod sqlx_sqlite; #[cfg(feature = "mock")] pub use mock::*; diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index c542a9b4..75e6e5ff 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,12 +1,11 @@ -use sqlx::{ - mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, - MySql, MySqlPool, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxMySqlPoolConnection { impl SqlxMySqlConnector { pub fn accepts(string: &str) -> bool { - DbBackend::MySql.is_prefix_of(string) + string.starts_with("mysql://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxMySqlPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_mysql(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index fb5402eb..c9949375 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,12 +1,11 @@ -use sqlx::{ - postgres::{PgArguments, PgQueryResult, PgRow}, - PgPool, Postgres, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxPostgresPoolConnection { impl SqlxPostgresConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Postgres.is_prefix_of(string) + string.starts_with("postgres://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxPostgresPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_postgres(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b02f4408..bf06a265 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,12 +1,11 @@ -use sqlx::{ - sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, - Sqlite, SqlitePool, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxSqlitePoolConnection { impl SqlxSqliteConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Sqlite.is_prefix_of(string) + string.starts_with("sqlite:") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxSqlitePoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_sqlite(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index cfcb0bbd..32e9d77d 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, + error::*, ConnectionTrait, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, PrimaryKeyTrait, Value, }; use async_trait::async_trait; @@ -67,9 +67,11 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert(self, db: &DatabaseConnection) -> Result + async fn insert<'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, + Self: 'a, { let am = self; let exec = ::insert(am).exec(db); @@ -90,17 +92,22 @@ pub trait ActiveModelTrait: Clone + Debug { } } - async fn update(self, db: &DatabaseConnection) -> Result { + async fn update<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + Self: 'a, + { let exec = Self::Entity::update(self).exec(db); exec.await } /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save(self, db: &DatabaseConnection) -> Result + async fn save<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -122,9 +129,10 @@ pub trait ActiveModelTrait: Clone + Debug { } /// Delete an active model by its primary key - async fn delete(self, db: &DatabaseConnection) -> Result + async fn delete<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_delete(am); diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 764f2524..aef46207 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 807bc544..85b37cb0 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -1,6 +1,4 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DeleteMany, DeleteOne, EntityTrait, Statement, -}; +use crate::{ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, error::*}; use sea_query::DeleteStatement; use std::future::Future; @@ -18,10 +16,11 @@ impl<'a, A: 'a> DeleteOne where A: ActiveModelTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -31,10 +30,11 @@ impl<'a, E> DeleteMany where E: EntityTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,24 +45,27 @@ impl Deleter { Self { query } } - pub fn exec( + pub fn exec<'a, C>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a C, + ) -> impl Future> + '_ + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only( +async fn exec_delete_only<'a, C>( query: DeleteStatement, - db: &DatabaseConnection, -) -> Result { + db: &'a C, +) -> Result +where C: ConnectionTrait<'a> { Deleter::new(query).exec(db).await } // Only Statement impl Send -async fn exec_delete(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index a44867f7..7683f9bb 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,9 +1,6 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, - PrimaryKeyTrait, Statement, TryFromU64, -}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; -use std::{future::Future, marker::PhantomData}; +use std::marker::PhantomData; #[derive(Clone, Debug)] pub struct Inserter @@ -27,11 +24,12 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a>( + pub async fn exec<'a, C>( self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + db: &'a C, + ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: 'a, { // TODO: extract primary key's value from query @@ -47,7 +45,7 @@ where ); } } - Inserter::::new(query).exec(db) + Inserter::::new(query).exec(db).await // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -63,24 +61,26 @@ where } } - pub fn exec<'a>( + pub async fn exec<'a, C>( self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + db: &'a C, + ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(builder.build(&self.query), db).await } } // Only Statement impl Send -async fn exec_insert( +async fn exec_insert<'a, A, C>( statement: Statement, - db: &DatabaseConnection, + db: &C, ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; @@ -93,13 +93,13 @@ where .collect::>(); let res = db.query_one(statement).await?.unwrap(); res.try_get_many("", cols.as_ref()).unwrap_or_default() - } + }, _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); ValueTypeOf::::try_from_u64(last_insert_id) .ok() .unwrap_or_default() - } + }, }; Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 608d9dc1..a4e9bc69 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; +use crate::{ConnectionTrait, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -7,21 +7,23 @@ use std::{marker::PhantomData, pin::Pin}; pub type PinBoxStream<'db, Item> = Pin + 'db>>; #[derive(Clone, Debug)] -pub struct Paginator<'db, S> +pub struct Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { pub(crate) query: SelectStatement, pub(crate) page: usize, pub(crate) page_size: usize, - pub(crate) db: &'db DatabaseConnection, + pub(crate) db: &'db C, pub(crate) selector: PhantomData, } // LINT: warn if paginator is used without an order by clause -impl<'db, S> Paginator<'db, S> +impl<'db, C, S> Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { /// Fetch a specific page; page index starts from zero @@ -155,7 +157,7 @@ where #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; - use crate::tests_cfg::*; + use crate::{ConnectionTrait, tests_cfg::*}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/select.rs b/src/executor/select.rs index bb386722..cff4a348 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,8 +1,8 @@ -use crate::{ - error::*, DatabaseConnection, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, - ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, - SelectTwoMany, Statement, TryGetableMany, -}; +#[cfg(feature = "sqlx-dep")] +use std::pin::Pin; +use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*}; +#[cfg(feature = "sqlx-dep")] +use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; @@ -235,23 +235,35 @@ where Selector::>::with_columns(self.query) } - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &DatabaseConnection, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectModel> { + ) -> Paginator<'a, C, SelectModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -280,29 +292,41 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + pub async fn all<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &DatabaseConnection, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectTwoModel> { + ) -> Paginator<'a, C, SelectTwoModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -331,17 +355,27 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub async fn all<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -376,7 +410,8 @@ where } } - pub async fn one(mut self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -386,7 +421,8 @@ where } } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -396,7 +432,21 @@ where Ok(models) } - pub fn paginate(self, db: &DatabaseConnection, page_size: usize) -> Paginator<'_, S> { + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> + where + C: ConnectionTrait<'a>, + S: 'b, + { + let builder = db.get_database_backend(); + let stream = db.stream(builder.build(&self.query)).await?; + Ok(Box::pin(stream.and_then(|row| { + futures::future::ready(S::from_raw_query_result(row)) + }))) + } + + pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> + where C: ConnectionTrait<'a> { Paginator { query: self.query, page: 0, @@ -606,7 +656,8 @@ where /// ),] /// ); /// ``` - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -645,7 +696,8 @@ where /// ),] /// ); /// ``` - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index 6c7a9873..06cd514e 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,6 +1,4 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Statement, UpdateMany, UpdateOne, -}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, error::*}; use sea_query::UpdateStatement; use std::future::Future; @@ -18,9 +16,10 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a DatabaseConnection) -> impl Future> + 'a { + pub async fn exec<'b, C>(self, db: &'b C) -> Result + where C: ConnectionTrait<'b> { // so that self is dropped before entering await - exec_update_and_return_original(self.query, self.model, db) + exec_update_and_return_original(self.query, self.model, db).await } } @@ -28,10 +27,11 @@ impl<'a, E> UpdateMany where E: EntityTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -42,36 +42,40 @@ impl Updater { Self { query } } - pub fn exec( + pub async fn exec<'a, C>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a C, + ) -> Result + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db).await } } -async fn exec_update_only( +async fn exec_update_only<'a, C>( query: UpdateStatement, - db: &DatabaseConnection, -) -> Result { + db: &'a C, +) -> Result +where C: ConnectionTrait<'a> { Updater::new(query).exec(db).await } -async fn exec_update_and_return_original( +async fn exec_update_and_return_original<'a, A, C>( query: UpdateStatement, model: A, - db: &DatabaseConnection, + db: &'a C, ) -> Result where A: ActiveModelTrait, + C: ConnectionTrait<'a>, { Updater::new(query).exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/src/query/mod.rs b/src/query/mod.rs index 54cc12dd..c7f60049 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -20,4 +20,4 @@ pub use select::*; pub use traits::*; pub use update::*; -pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; +pub use crate::{InsertResult, Statement, UpdateResult, Value, Values, ConnectionTrait}; diff --git a/tests/basic.rs b/tests/basic.rs index a0763d45..ef379779 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,6 +1,6 @@ pub mod common; -pub use sea_orm::{entity::*, error::*, sea_query, tests_cfg::*, Database, DbConn}; +pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Database, DbConn}; // cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic #[sea_orm_macros::test] diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index d982b2b7..9deb903f 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,4 +1,4 @@ -use sea_orm::{Database, DatabaseBackend, DatabaseConnection, Statement}; +use sea_orm::{Database, DatabaseBackend, DatabaseConnection, ConnectionTrait, Statement}; pub mod schema; pub use schema::*; diff --git a/tests/common/setup/schema.rs b/tests/common/setup/schema.rs index b39f1e77..ac34304a 100644 --- a/tests/common/setup/schema.rs +++ b/tests/common/setup/schema.rs @@ -1,6 +1,6 @@ pub use super::super::bakery_chain::*; use pretty_assertions::assert_eq; -use sea_orm::{error::*, sea_query, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; use sea_query::{ Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement, }; diff --git a/tests/query_tests.rs b/tests/query_tests.rs index 2b5a2295..e688b14f 100644 --- a/tests/query_tests.rs +++ b/tests/query_tests.rs @@ -2,7 +2,7 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; pub use sea_orm::entity::*; -pub use sea_orm::QueryFilter; +pub use sea_orm::{QueryFilter, ConnectionTrait}; // Run the test locally: // DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests diff --git a/tests/sequential_op_tests.rs b/tests/sequential_op_tests.rs index 28333d84..47e69ccb 100644 --- a/tests/sequential_op_tests.rs +++ b/tests/sequential_op_tests.rs @@ -179,7 +179,7 @@ async fn find_baker_least_sales(db: &DatabaseConnection) -> Option let mut results: Vec = select .into_model::() - .all(&db) + .all(db) .await .unwrap() .into_iter() diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs new file mode 100644 index 00000000..969b93e1 --- /dev/null +++ b/tests/stream_tests.rs @@ -0,0 +1,37 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr}; +use futures::StreamExt; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn stream() -> Result<(), DbErr> { + let ctx = TestContext::new("stream").await; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&ctx.db) + .await?; + + let result = Bakery::find_by_id(bakery.id.clone().unwrap()) + .stream(&ctx.db) + .await? + .next() + .await + .unwrap()?; + + assert_eq!(result.id, bakery.id.unwrap()); + + ctx.delete().await; + + Ok(()) +} diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs new file mode 100644 index 00000000..539eaefc --- /dev/null +++ b/tests/transaction_tests.rs @@ -0,0 +1,90 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +use sea_orm::{DatabaseTransaction, DbErr}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction() { + let ctx = TestContext::new("transaction_test").await; + + ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + let _ = bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 2); + + Ok(()) + })).await.unwrap(); + + ctx.delete().await; +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_with_reference() { + let ctx = TestContext::new("transaction_with_reference_test").await; + let name1 = "SeaSide Bakery"; + let name2 = "Top Bakery"; + let search_name = "Bakery"; + ctx.db.transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)).await.unwrap(); + + ctx.delete().await; +} + +fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set(name1.to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + let _ = bakery::ActiveModel { + name: Set(name2.to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains(search_name)) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 2); + + Ok(()) + }) +}