From 6b98a6f3955155ab062b739ddda86f35c8d3cf83 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 20:40:27 +0800 Subject: [PATCH] Move code --- src/database/connection.rs | 236 ++------------------ src/database/db_connection.rs | 239 +++++++++++++++++++-- src/database/db_transaction.rs | 370 -------------------------------- src/database/mock.rs | 44 +++- src/database/mod.rs | 2 - src/database/transaction.rs | 380 ++++++++++++++++++++++++++++++--- 6 files changed, 634 insertions(+), 637 deletions(-) delete mode 100644 src/database/db_transaction.rs diff --git a/src/database/connection.rs b/src/database/connection.rs index 2c65fa28..d90c72a9 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,250 +1,40 @@ use crate::{ - error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, - StatementBuilder, TransactionError, + DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, }; -use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use futures::Stream; use std::{future::Future, pin::Pin}; -#[cfg(feature = "mock")] -use std::sync::Arc; - -#[cfg_attr(not(feature = "mock"), derive(Clone))] -pub enum DatabaseConnection { - #[cfg(feature = "sqlx-mysql")] - SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), - #[cfg(feature = "sqlx-postgres")] - SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), - #[cfg(feature = "sqlx-sqlite")] - SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), - #[cfg(feature = "mock")] - MockDatabaseConnection(Arc), - Disconnected, -} - -pub type DbConn = DatabaseConnection; - -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum DatabaseBackend { - MySql, - Postgres, - Sqlite, -} - -pub type DbBackend = DatabaseBackend; - -impl Default for DatabaseConnection { - fn default() -> Self { - Self::Disconnected - } -} - -impl std::fmt::Debug for DatabaseConnection { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{}", - match self { - #[cfg(feature = "sqlx-mysql")] - Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", - #[cfg(feature = "sqlx-postgres")] - Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", - #[cfg(feature = "sqlx-sqlite")] - Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", - #[cfg(feature = "mock")] - Self::MockDatabaseConnection(_) => "MockDatabaseConnection", - Self::Disconnected => "Disconnected", - } - ) - } -} - #[async_trait::async_trait] -impl<'a> ConnectionTrait<'a> for DatabaseConnection { - type Stream = crate::QueryStream; +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; - fn get_database_backend(&self) -> DbBackend { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + fn get_database_backend(&self) -> DbBackend; - async fn execute(&self, stmt: Statement) -> Result { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn execute(&self, stmt: Statement) -> Result; - async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn query_one(&self, stmt: Statement) -> Result, DbErr>; - async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn query_all(&self, stmt: Statement) -> Result, DbErr>; fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>> { - Box::pin(async move { - Ok(match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - crate::QueryStream::from((Arc::clone(conn), stmt)) - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - }) - }) - } + ) -> Pin> + 'a>>; - async fn begin(&self) -> Result { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - DatabaseTransaction::new_mock(Arc::clone(conn)).await - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + async fn begin(&self) -> Result; /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> + async fn transaction(&self, callback: F) -> Result> where F: for<'c> FnOnce( &'c DatabaseTransaction, ) -> Pin> + Send + 'c>> + Send, T: Send, - E: std::error::Error + Send, - { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => { - conn.transaction(_callback).await - } - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) - .await - .map_err(|e| TransactionError::Connection(e))?; - transaction.run(_callback).await - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + E: std::error::Error + Send; - #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { - match self { - DatabaseConnection::MockDatabaseConnection(_) => true, - _ => false, - } - } -} - -#[cfg(feature = "mock")] -impl DatabaseConnection { - pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { - match self { - DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, - _ => panic!("not mock connection"), - } - } - - pub fn into_transaction_log(self) -> Vec { - let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); - mocker.drain_transaction_log() - } -} - -impl DbBackend { - pub fn is_prefix_of(self, base_url: &str) -> bool { - match self { - Self::Postgres => { - base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") - } - Self::MySql => base_url.starts_with("mysql://"), - Self::Sqlite => base_url.starts_with("sqlite:"), - } - } - - pub fn build(&self, statement: &S) -> Statement - where - S: StatementBuilder, - { - statement.build(self) - } - - pub fn get_query_builder(&self) -> Box { - match self { - Self::MySql => Box::new(MysqlQueryBuilder), - Self::Postgres => Box::new(PostgresQueryBuilder), - Self::Sqlite => Box::new(SqliteQueryBuilder), - } - } -} - -#[cfg(test)] -mod tests { - use crate::DatabaseConnection; - - #[test] - fn assert_database_connection_traits() { - fn assert_send_sync() {} - - assert_send_sync::(); + false } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 6040e452..60d70ac7 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,10 +1,39 @@ use crate::{ - DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, + error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, + StatementBuilder, TransactionError, }; -use futures::Stream; +use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use std::{future::Future, pin::Pin}; + #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; -use std::{future::Future, pin::Pin}; + +#[cfg(feature = "mock")] +use std::sync::Arc; + +#[cfg_attr(not(feature = "mock"), derive(Clone))] +pub enum DatabaseConnection { + #[cfg(feature = "sqlx-mysql")] + SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), + #[cfg(feature = "sqlx-sqlite")] + SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), + #[cfg(feature = "mock")] + MockDatabaseConnection(Arc), + Disconnected, +} + +pub type DbConn = DatabaseConnection; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum DatabaseBackend { + MySql, + Postgres, + Sqlite, +} + +pub type DbBackend = DatabaseBackend; pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-mysql")] @@ -17,37 +46,219 @@ pub(crate) enum InnerConnection { Mock(std::sync::Arc), } +impl Default for DatabaseConnection { + fn default() -> Self { + Self::Disconnected + } +} + +impl std::fmt::Debug for DatabaseConnection { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "sqlx-mysql")] + Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", + #[cfg(feature = "sqlx-postgres")] + Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", + #[cfg(feature = "sqlx-sqlite")] + Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", + #[cfg(feature = "mock")] + Self::MockDatabaseConnection(_) => "MockDatabaseConnection", + Self::Disconnected => "Disconnected", + } + ) + } +} + #[async_trait::async_trait] -pub trait ConnectionTrait<'a>: Sync { - type Stream: Stream>; +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; - fn get_database_backend(&self) -> DbBackend; + fn get_database_backend(&self) -> DbBackend { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } - async fn execute(&self, stmt: Statement) -> Result; + async fn execute(&self, stmt: Statement) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } - async fn query_one(&self, stmt: Statement) -> Result, DbErr>; + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } - async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + crate::QueryStream::from((Arc::clone(conn), stmt)) + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } - async fn begin(&self) -> Result; + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + DatabaseTransaction::new_mock(Arc::clone(conn)).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, callback: F) -> Result> + async fn transaction(&self, _callback: F) -> Result> where F: for<'c> FnOnce( &'c DatabaseTransaction, ) -> Pin> + Send + 'c>> + Send, T: Send, - E: std::error::Error + Send; + E: std::error::Error + Send, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + conn.transaction(_callback).await + } + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { - false + match self { + DatabaseConnection::MockDatabaseConnection(_) => true, + _ => false, + } + } +} + +#[cfg(feature = "mock")] +impl DatabaseConnection { + pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { + match self { + DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, + _ => panic!("not mock connection"), + } + } + + pub fn into_transaction_log(self) -> Vec { + let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); + mocker.drain_transaction_log() + } +} + +impl DbBackend { + pub fn is_prefix_of(self, base_url: &str) -> bool { + match self { + Self::Postgres => { + base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") + } + Self::MySql => base_url.starts_with("mysql://"), + Self::Sqlite => base_url.starts_with("sqlite:"), + } + } + + pub fn build(&self, statement: &S) -> Statement + where + S: StatementBuilder, + { + statement.build(self) + } + + pub fn get_query_builder(&self) -> Box { + match self { + Self::MySql => Box::new(MysqlQueryBuilder), + Self::Postgres => Box::new(PostgresQueryBuilder), + Self::Sqlite => Box::new(SqliteQueryBuilder), + } + } +} + +#[cfg(test)] +mod tests { + use crate::DatabaseConnection; + + #[test] + fn assert_database_connection_traits() { + fn assert_send_sync() {} + + assert_send_sync::(); } } diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs deleted file mode 100644 index ae954097..00000000 --- a/src/database/db_transaction.rs +++ /dev/null @@ -1,370 +0,0 @@ -use crate::{ - debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, - Statement, TransactionStream, -}; -#[cfg(feature = "sqlx-dep")] -use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; -use futures::lock::Mutex; -#[cfg(feature = "sqlx-dep")] -use sqlx::{pool::PoolConnection, TransactionManager}; -use std::{future::Future, pin::Pin, sync::Arc}; - -// a Transaction is just a sugar for a connection where START TRANSACTION has been executed -pub struct DatabaseTransaction { - conn: Arc>, - backend: DbBackend, - open: bool, -} - -impl std::fmt::Debug for DatabaseTransaction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DatabaseTransaction") - } -} - -impl DatabaseTransaction { - #[cfg(feature = "sqlx-mysql")] - pub(crate) async fn new_mysql( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::MySql(inner))), - DbBackend::MySql, - ) - .await - } - - #[cfg(feature = "sqlx-postgres")] - pub(crate) async fn new_postgres( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::Postgres(inner))), - DbBackend::Postgres, - ) - .await - } - - #[cfg(feature = "sqlx-sqlite")] - pub(crate) async fn new_sqlite( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), - DbBackend::Sqlite, - ) - .await - } - - #[cfg(feature = "mock")] - pub(crate) async fn new_mock( - inner: Arc, - ) -> Result { - let backend = inner.get_database_backend(); - Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await - } - - async fn build( - conn: Arc>, - backend: DbBackend, - ) -> Result { - let res = DatabaseTransaction { - conn, - backend, - open: true, - }; - match *res.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - // should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(res) - } - - pub(crate) async fn run(self, callback: F) -> Result> - where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, - T: Send, - E: std::error::Error + Send, - { - let res = callback(&self) - .await - .map_err(|e| TransactionError::Transaction(e)); - if res.is_ok() { - self.commit() - .await - .map_err(|e| TransactionError::Connection(e))?; - } else { - self.rollback() - .await - .map_err(|e| TransactionError::Connection(e))?; - } - res - } - - pub async fn commit(mut self) -> Result<(), DbErr> { - self.open = false; - match *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(()) - } - - pub async fn rollback(mut self) -> Result<(), DbErr> { - self.open = false; - match *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(()) - } - - // the rollback is queued and will be performed on next async operation, like returning the connection to the pool - fn start_rollback(&mut self) { - if self.open { - if let Some(mut conn) = self.conn.try_lock() { - match &mut *conn { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { - ::TransactionManager::start_rollback(c); - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { - ::TransactionManager::start_rollback(c); - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { - ::TransactionManager::start_rollback(c); - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - } else { - //this should never happen - panic!("Dropping a locked Transaction"); - } - } - } -} - -impl Drop for DatabaseTransaction { - fn drop(&mut self) { - self.start_rollback(); - } -} - -#[async_trait::async_trait] -impl<'a> ConnectionTrait<'a> for DatabaseTransaction { - type Stream = TransactionStream<'a>; - - fn get_database_backend(&self) -> DbBackend { - // this way we don't need to lock - self.backend - } - - async fn execute(&self, stmt: Statement) -> Result { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.execute(stmt), - }; - #[cfg(feature = "sqlx-dep")] - _res.map_err(sqlx_error_to_exec_err) - } - - async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.query_one(stmt), - }; - #[cfg(feature = "sqlx-dep")] - if let Err(sqlx::Error::RowNotFound) = _res { - Ok(None) - } else { - _res.map_err(sqlx_error_to_query_err) - } - } - - async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.query_all(stmt), - }; - #[cfg(feature = "sqlx-dep")] - _res.map_err(sqlx_error_to_query_err) - } - - fn stream( - &'a self, - stmt: Statement, - ) -> Pin> + 'a>> { - Box::pin( - async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, - ) - } - - async fn begin(&self) -> Result { - DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await - } - - /// Execute the function inside a transaction. - /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> - where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, - T: Send, - E: std::error::Error + Send, - { - let transaction = self - .begin() - .await - .map_err(|e| TransactionError::Connection(e))?; - transaction.run(_callback).await - } -} - -#[derive(Debug)] -pub enum TransactionError -where - E: std::error::Error, -{ - Connection(DbErr), - Transaction(E), -} - -impl std::fmt::Display for TransactionError -where - E: std::error::Error, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), - TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for TransactionError where E: std::error::Error {} diff --git a/src/database/mock.rs b/src/database/mock.rs index f42add7a..f98c4b9d 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -1,9 +1,9 @@ use crate::{ error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden, Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow, - Statement, Transaction, + Statement, }; -use sea_query::{Value, ValueType}; +use sea_query::{Value, ValueType, Values}; use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] @@ -29,6 +29,11 @@ pub trait IntoMockRow { fn into_mock_row(self) -> MockRow; } +#[derive(Debug, Clone, PartialEq)] +pub struct Transaction { + stmts: Vec, +} + impl MockDatabase { pub fn new(db_backend: DbBackend) -> Self { Self { @@ -134,3 +139,38 @@ impl IntoMockRow for BTreeMap<&str, Value> { } } } + +impl Transaction { + pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self + where + I: IntoIterator, + { + Self::one(Statement::from_string_values_tuple( + db_backend, + (sql.to_string(), Values(values.into_iter().collect())), + )) + } + + /// Create a Transaction with one statement + pub fn one(stmt: Statement) -> Self { + Self { stmts: vec![stmt] } + } + + /// Create a Transaction with many statements + pub fn many(stmts: I) -> Self + where + I: IntoIterator, + { + Self { + stmts: stmts.into_iter().collect(), + } + } + + /// Wrap each Statement as a single-statement Transaction + pub fn wrap(stmts: I) -> Vec + where + I: IntoIterator, + { + stmts.into_iter().map(Self::one).collect() + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 369ed539..a1dfea93 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,5 @@ mod connection; mod db_connection; -mod db_transaction; #[cfg(feature = "mock")] mod mock; mod statement; @@ -9,7 +8,6 @@ mod transaction; pub use connection::*; pub use db_connection::*; -pub use db_transaction::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 6bf06491..3757052e 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -1,42 +1,370 @@ -use crate::{DbBackend, Statement}; -use sea_query::{Value, Values}; +use crate::{ + debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, + Statement, TransactionStream, +}; +#[cfg(feature = "sqlx-dep")] +use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +use futures::lock::Mutex; +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, TransactionManager}; +use std::{future::Future, pin::Pin, sync::Arc}; -#[derive(Debug, Clone, PartialEq)] -pub struct Transaction { - stmts: Vec, +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, } -impl Transaction { - pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self - where - I: IntoIterator, - { - Self::one(Statement::from_string_values_tuple( - db_backend, - (sql.to_string(), Values(values.into_iter().collect())), - )) +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { + #[cfg(feature = "sqlx-mysql")] + pub(crate) async fn new_mysql( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::MySql(inner))), + DbBackend::MySql, + ) + .await } - /// Create a Transaction with one statement - pub fn one(stmt: Statement) -> Self { - Self { stmts: vec![stmt] } + #[cfg(feature = "sqlx-postgres")] + pub(crate) async fn new_postgres( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Postgres(inner))), + DbBackend::Postgres, + ) + .await } - /// Create a Transaction with many statements - pub fn many(stmts: I) -> Self + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), + DbBackend::Sqlite, + ) + .await + } + + #[cfg(feature = "mock")] + pub(crate) async fn new_mock( + inner: Arc, + ) -> Result { + let backend = inner.get_database_backend(); + Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + } + + async fn begin( + conn: Arc>, + backend: DbBackend, + ) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + // should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(res) + } + + pub(crate) async fn run(self, callback: F) -> Result> where - I: IntoIterator, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, { - Self { - stmts: stmts.into_iter().collect(), + let res = callback(&self) + .await + .map_err(|e| TransactionError::Transaction(e)); + if res.is_ok() { + self.commit() + .await + .map_err(|e| TransactionError::Connection(e))?; + } else { + self.rollback() + .await + .map_err(|e| TransactionError::Connection(e))?; + } + res + } + + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + } else { + //this should never happen + panic!("Dropping a locked Transaction"); + } + } + } +} + +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + + fn get_database_backend(&self) -> DbBackend { + // this way we don't need to lock + self.backend + } + + async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_exec_err) + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + }; + #[cfg(feature = "sqlx-dep")] + if let Err(sqlx::Error::RowNotFound) = _res { + Ok(None) + } else { + _res.map_err(sqlx_error_to_query_err) } } - /// Wrap each Statement as a single-statement Transaction - pub fn wrap(stmts: I) -> Vec + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_query_err) + } + + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { + Box::pin( + async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, + ) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> where - I: IntoIterator, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, + T: Send, + E: std::error::Error + Send, { - stmts.into_iter().map(Self::one).collect() + let transaction = self + .begin() + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await } } + +#[derive(Debug)] +pub enum TransactionError +where + E: std::error::Error, +{ + Connection(DbErr), + Transaction(E), +} + +impl std::fmt::Display for TransactionError +where + E: std::error::Error, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), + TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for TransactionError where E: std::error::Error {}