From 1a2bd13158edaa58d2bfe8739bdd15b27676cd32 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 11 Oct 2021 18:39:46 +0800 Subject: [PATCH] Nested transaction unit tests --- src/database/mock.rs | 176 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 166 insertions(+), 10 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index bdab864c..6bb710b7 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -109,7 +109,10 @@ impl MockDatabaseTrait for MockDatabase { fn begin(&mut self) { if self.transaction.is_some() { - panic!("There is uncommitted transaction"); + self.transaction + .as_mut() + .unwrap() + .begin_nested(self.db_backend); } else { self.transaction = Some(OpenTransaction::init()); } @@ -117,9 +120,10 @@ impl MockDatabaseTrait for MockDatabase { fn commit(&mut self) { if self.transaction.is_some() { - let mut transaction = self.transaction.take().unwrap(); - transaction.push(Statement::from_string(self.db_backend, "COMMIT".to_owned())); - self.transaction_log.push(transaction.into_transaction()); + if self.transaction.as_mut().unwrap().commit(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } } else { panic!("There is no open transaction to commit"); } @@ -127,12 +131,10 @@ impl MockDatabaseTrait for MockDatabase { fn rollback(&mut self) { if self.transaction.is_some() { - let mut transaction = self.transaction.take().unwrap(); - transaction.push(Statement::from_string( - self.db_backend, - "ROLLBACK".to_owned(), - )); - self.transaction_log.push(transaction.into_transaction()); + if self.transaction.as_mut().unwrap().rollback(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } } else { panic!("There is no open transaction to rollback"); } @@ -230,11 +232,50 @@ impl OpenTransaction { } } + fn begin_nested(&mut self, db_backend: DbBackend) { + self.transaction_depth += 1; + self.push(Statement::from_string( + db_backend, + format!("SAVEPOINT savepoint_{}", self.transaction_depth), + )); + } + + fn commit(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "COMMIT".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("RELEASE SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + + fn rollback(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("ROLLBACK TO SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + fn push(&mut self, stmt: Statement) { self.stmts.push(stmt); } fn into_transaction(self) -> Transaction { + if self.transaction_depth != 0 { + panic!("There is uncommitted nested transaction."); + } Transaction { stmts: self.stmts } } } @@ -246,6 +287,7 @@ mod tests { entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, Statement, Transaction, TransactionError, }; + use pretty_assertions::assert_eq; #[derive(Debug, PartialEq)] pub struct MyErr(String); @@ -333,6 +375,120 @@ mod tests { ); } + #[smol_potat::test] + async fn test_nested_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + + #[smol_potat::test] + async fn test_nested_transaction_2() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_2".to_owned() + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + #[smol_potat::test] async fn test_stream_1() -> Result<(), DbErr> { use futures::TryStreamExt;