From c7532bcc08c26a644816e81e77f2ef9befa3d514 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 5 Oct 2021 19:21:05 +0800 Subject: [PATCH] Basic MockTransaction implementation TODO: nested transaction --- src/database/mock.rs | 113 +++++++++++++++++++++++++++++++++++- src/database/statement.rs | 2 +- src/database/transaction.rs | 33 +++++------ src/driver/mock.rs | 24 +++++++- tests/stream_tests.rs | 3 +- 5 files changed, 150 insertions(+), 25 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index 012c941f..d1bb65a4 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -9,6 +9,7 @@ use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { db_backend: DbBackend, + transaction: Option, transaction_log: Vec, exec_results: Vec, query_results: Vec>, @@ -29,6 +30,12 @@ pub trait IntoMockRow { fn into_mock_row(self) -> MockRow; } +#[derive(Debug)] +pub struct OpenTransaction { + stmts: Vec, + transaction_depth: usize, +} + #[derive(Debug, Clone, PartialEq)] pub struct MockTransaction { stmts: Vec, @@ -38,6 +45,7 @@ impl MockDatabase { pub fn new(db_backend: DbBackend) -> Self { Self { db_backend, + transaction: None, transaction_log: Vec::new(), exec_results: Vec::new(), query_results: Vec::new(), @@ -67,7 +75,11 @@ impl MockDatabase { impl MockDatabaseTrait for MockDatabase { fn execute(&mut self, counter: usize, statement: Statement) -> Result { - self.transaction_log.push(MockTransaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(MockTransaction::one(statement)); + } if counter < self.exec_results.len() { Ok(ExecResult { result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), @@ -78,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase { } fn query(&mut self, counter: usize, statement: Statement) -> Result, DbErr> { - self.transaction_log.push(MockTransaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(MockTransaction::one(statement)); + } if counter < self.query_results.len() { Ok(std::mem::take(&mut self.query_results[counter]) .into_iter() @@ -91,6 +107,32 @@ impl MockDatabaseTrait for MockDatabase { } } + fn begin(&mut self) { + if self.transaction.is_some() { + panic!("There is uncommitted transaction"); + } else { + self.transaction = Some(OpenTransaction::init()); + } + } + + fn commit(&mut self) { + if self.transaction.is_some() { + let transaction = self.transaction.take().unwrap(); + self.transaction_log + .push(transaction.into_mock_transaction()); + } else { + panic!("There is no open transaction to commit"); + } + } + + fn rollback(&mut self) { + if self.transaction.is_some() { + self.transaction = None; + } else { + panic!("There is no open transaction to rollback"); + } + } + fn drain_transaction_log(&mut self) -> Vec { std::mem::take(&mut self.transaction_log) } @@ -174,3 +216,70 @@ impl MockTransaction { stmts.into_iter().map(Self::one).collect() } } + +impl OpenTransaction { + fn init() -> Self { + Self { + stmts: Vec::new(), + transaction_depth: 0, + } + } + + fn push(&mut self, stmt: Statement) { + self.stmts.push(stmt); + } + + fn into_mock_transaction(self) -> MockTransaction { + MockTransaction { stmts: self.stmts } + } +} + +#[cfg(test)] +#[cfg(feature = "mock")] +mod tests { + use crate::{ + entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, MockTransaction, + Statement, + }; + + #[smol_potat::test] + async fn test_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _1 = cake::Entity::find().one(txn).await; + let _2 = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + let _ = cake::Entity::find().all(&db).await; + + assert_eq!( + db.into_transaction_log(), + vec![ + MockTransaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + ]), + MockTransaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + ] + ); + } +} diff --git a/src/database/statement.rs b/src/database/statement.rs index 63a1d57f..12b07487 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -104,4 +104,4 @@ build_schema_stmt!(sea_query::TableCreateStatement); build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); -build_schema_stmt!(sea_query::TableTruncateStatement); \ No newline at end of file +build_schema_stmt!(sea_query::TableTruncateStatement); diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 7230488b..d7bbc058 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -92,9 +92,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - // should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.begin(); + } } Ok(res) } @@ -108,13 +109,9 @@ impl DatabaseTransaction { T: Send, E: std::error::Error + Send, { - let res = callback(&self) - .await - .map_err(TransactionError::Transaction); + let res = callback(&self).await.map_err(TransactionError::Transaction); if res.is_ok() { - self.commit() - .await - .map_err(TransactionError::Connection)?; + self.commit().await.map_err(TransactionError::Connection)?; } else { self.rollback() .await @@ -144,9 +141,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.commit(); + } } Ok(()) } @@ -172,9 +170,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.rollback(); + } } Ok(()) } @@ -196,9 +195,10 @@ impl DatabaseTransaction { InnerConnection::Sqlite(c) => { ::TransactionManager::start_rollback(c); } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(c) => { + c.rollback(); + } } } else { //this should never happen @@ -338,10 +338,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { T: Send, E: std::error::Error + Send, { - let transaction = self - .begin() - .await - .map_err(TransactionError::Connection)?; + let transaction = self.begin().await.map_err(TransactionError::Connection)?; transaction.run(_callback).await } } diff --git a/src/driver/mock.rs b/src/driver/mock.rs index ad6b2298..d3c36262 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -26,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug { fn query(&mut self, counter: usize, stmt: Statement) -> Result, DbErr>; + fn begin(&mut self); + + fn commit(&mut self); + + fn rollback(&mut self); + fn drain_transaction_log(&mut self) -> Vec; fn get_database_backend(&self) -> DbBackend; @@ -86,10 +92,14 @@ impl MockDatabaseConnection { } } - pub fn get_mocker_mutex(&self) -> &Mutex> { + pub(crate) fn get_mocker_mutex(&self) -> &Mutex> { &self.mocker } + pub fn get_database_backend(&self) -> DbBackend { + self.mocker.lock().unwrap().get_database_backend() + } + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); @@ -119,7 +129,15 @@ impl MockDatabaseConnection { } } - pub fn get_database_backend(&self) -> DbBackend { - self.mocker.lock().unwrap().get_database_backend() + pub fn begin(&self) { + self.mocker.lock().unwrap().begin() + } + + pub fn commit(&self) { + self.mocker.lock().unwrap().commit() + } + + pub fn rollback(&self) { + self.mocker.lock().unwrap().rollback() } } diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index d30063e5..560fc01e 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -1,7 +1,6 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -use futures::StreamExt; pub use sea_orm::entity::*; pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; @@ -12,6 +11,8 @@ pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; feature = "sqlx-postgres" ))] pub async fn stream() -> Result<(), DbErr> { + use futures::StreamExt; + let ctx = TestContext::new("stream").await; let bakery = bakery::ActiveModel {