Nested transaction unit tests

This commit is contained in:
Chris Tsang 2021-10-11 18:39:46 +08:00
parent f94c33d1ea
commit 1a2bd13158

View File

@ -109,7 +109,10 @@ impl MockDatabaseTrait for MockDatabase {
fn begin(&mut self) {
if self.transaction.is_some() {
panic!("There is uncommitted transaction");
self.transaction
.as_mut()
.unwrap()
.begin_nested(self.db_backend);
} else {
self.transaction = Some(OpenTransaction::init());
}
@ -117,9 +120,10 @@ impl MockDatabaseTrait for MockDatabase {
fn commit(&mut self) {
if self.transaction.is_some() {
let mut transaction = self.transaction.take().unwrap();
transaction.push(Statement::from_string(self.db_backend, "COMMIT".to_owned()));
if self.transaction.as_mut().unwrap().commit(self.db_backend) {
let transaction = self.transaction.take().unwrap();
self.transaction_log.push(transaction.into_transaction());
}
} else {
panic!("There is no open transaction to commit");
}
@ -127,12 +131,10 @@ impl MockDatabaseTrait for MockDatabase {
fn rollback(&mut self) {
if self.transaction.is_some() {
let mut transaction = self.transaction.take().unwrap();
transaction.push(Statement::from_string(
self.db_backend,
"ROLLBACK".to_owned(),
));
if self.transaction.as_mut().unwrap().rollback(self.db_backend) {
let transaction = self.transaction.take().unwrap();
self.transaction_log.push(transaction.into_transaction());
}
} else {
panic!("There is no open transaction to rollback");
}
@ -230,11 +232,50 @@ impl OpenTransaction {
}
}
fn begin_nested(&mut self, db_backend: DbBackend) {
self.transaction_depth += 1;
self.push(Statement::from_string(
db_backend,
format!("SAVEPOINT savepoint_{}", self.transaction_depth),
));
}
fn commit(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "COMMIT".to_owned()));
true
} else {
self.push(Statement::from_string(
db_backend,
format!("RELEASE SAVEPOINT savepoint_{}", self.transaction_depth),
));
self.transaction_depth -= 1;
false
}
}
fn rollback(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned()));
true
} else {
self.push(Statement::from_string(
db_backend,
format!("ROLLBACK TO SAVEPOINT savepoint_{}", self.transaction_depth),
));
self.transaction_depth -= 1;
false
}
}
fn push(&mut self, stmt: Statement) {
self.stmts.push(stmt);
}
fn into_transaction(self) -> Transaction {
if self.transaction_depth != 0 {
panic!("There is uncommitted nested transaction.");
}
Transaction { stmts: self.stmts }
}
}
@ -246,6 +287,7 @@ mod tests {
entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, Statement,
Transaction, TransactionError,
};
use pretty_assertions::assert_eq;
#[derive(Debug, PartialEq)]
pub struct MyErr(String);
@ -333,6 +375,120 @@ mod tests {
);
}
#[smol_potat::test]
async fn test_nested_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
assert_eq!(
db.into_transaction_log(),
vec![Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
]),]
);
}
#[smol_potat::test]
async fn test_nested_transaction_2() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
assert_eq!(
db.into_transaction_log(),
vec![Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
vec![]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_2".to_owned()
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
]),]
);
}
#[smol_potat::test]
async fn test_stream_1() -> Result<(), DbErr> {
use futures::TryStreamExt;