Basic MockTransaction implementation

TODO: nested transaction
This commit is contained in:
Chris Tsang 2021-10-05 19:21:05 +08:00
parent 8d0ba28b7e
commit c7532bcc08
5 changed files with 150 additions and 25 deletions

View File

@ -9,6 +9,7 @@ use std::{collections::BTreeMap, sync::Arc};
#[derive(Debug)] #[derive(Debug)]
pub struct MockDatabase { pub struct MockDatabase {
db_backend: DbBackend, db_backend: DbBackend,
transaction: Option<OpenTransaction>,
transaction_log: Vec<MockTransaction>, transaction_log: Vec<MockTransaction>,
exec_results: Vec<MockExecResult>, exec_results: Vec<MockExecResult>,
query_results: Vec<Vec<MockRow>>, query_results: Vec<Vec<MockRow>>,
@ -29,6 +30,12 @@ pub trait IntoMockRow {
fn into_mock_row(self) -> MockRow; fn into_mock_row(self) -> MockRow;
} }
#[derive(Debug)]
pub struct OpenTransaction {
stmts: Vec<Statement>,
transaction_depth: usize,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct MockTransaction { pub struct MockTransaction {
stmts: Vec<Statement>, stmts: Vec<Statement>,
@ -38,6 +45,7 @@ impl MockDatabase {
pub fn new(db_backend: DbBackend) -> Self { pub fn new(db_backend: DbBackend) -> Self {
Self { Self {
db_backend, db_backend,
transaction: None,
transaction_log: Vec::new(), transaction_log: Vec::new(),
exec_results: Vec::new(), exec_results: Vec::new(),
query_results: Vec::new(), query_results: Vec::new(),
@ -67,7 +75,11 @@ impl MockDatabase {
impl MockDatabaseTrait for MockDatabase { impl MockDatabaseTrait for MockDatabase {
fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> { fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> {
self.transaction_log.push(MockTransaction::one(statement)); if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
} else {
self.transaction_log.push(MockTransaction::one(statement));
}
if counter < self.exec_results.len() { if counter < self.exec_results.len() {
Ok(ExecResult { Ok(ExecResult {
result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])),
@ -78,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase {
} }
fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> { fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
self.transaction_log.push(MockTransaction::one(statement)); if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
} else {
self.transaction_log.push(MockTransaction::one(statement));
}
if counter < self.query_results.len() { if counter < self.query_results.len() {
Ok(std::mem::take(&mut self.query_results[counter]) Ok(std::mem::take(&mut self.query_results[counter])
.into_iter() .into_iter()
@ -91,6 +107,32 @@ impl MockDatabaseTrait for MockDatabase {
} }
} }
fn begin(&mut self) {
if self.transaction.is_some() {
panic!("There is uncommitted transaction");
} else {
self.transaction = Some(OpenTransaction::init());
}
}
fn commit(&mut self) {
if self.transaction.is_some() {
let transaction = self.transaction.take().unwrap();
self.transaction_log
.push(transaction.into_mock_transaction());
} else {
panic!("There is no open transaction to commit");
}
}
fn rollback(&mut self) {
if self.transaction.is_some() {
self.transaction = None;
} else {
panic!("There is no open transaction to rollback");
}
}
fn drain_transaction_log(&mut self) -> Vec<MockTransaction> { fn drain_transaction_log(&mut self) -> Vec<MockTransaction> {
std::mem::take(&mut self.transaction_log) std::mem::take(&mut self.transaction_log)
} }
@ -174,3 +216,70 @@ impl MockTransaction {
stmts.into_iter().map(Self::one).collect() stmts.into_iter().map(Self::one).collect()
} }
} }
impl OpenTransaction {
fn init() -> Self {
Self {
stmts: Vec::new(),
transaction_depth: 0,
}
}
fn push(&mut self, stmt: Statement) {
self.stmts.push(stmt);
}
fn into_mock_transaction(self) -> MockTransaction {
MockTransaction { stmts: self.stmts }
}
}
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
use crate::{
entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, MockTransaction,
Statement,
};
#[smol_potat::test]
async fn test_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _1 = cake::Entity::find().one(txn).await;
let _2 = fruit::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
let _ = cake::Entity::find().all(&db).await;
assert_eq!(
db.into_transaction_log(),
vec![
MockTransaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
]),
MockTransaction::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
vec![]
),
]
);
}
}

View File

@ -104,4 +104,4 @@ build_schema_stmt!(sea_query::TableCreateStatement);
build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableDropStatement);
build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableAlterStatement);
build_schema_stmt!(sea_query::TableRenameStatement); build_schema_stmt!(sea_query::TableRenameStatement);
build_schema_stmt!(sea_query::TableTruncateStatement); build_schema_stmt!(sea_query::TableTruncateStatement);

View File

@ -92,9 +92,10 @@ impl DatabaseTransaction {
.await .await
.map_err(sqlx_error_to_query_err)? .map_err(sqlx_error_to_query_err)?
} }
// should we do something for mocked connections?
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
InnerConnection::Mock(_) => {} InnerConnection::Mock(ref mut c) => {
c.begin();
}
} }
Ok(res) Ok(res)
} }
@ -108,13 +109,9 @@ impl DatabaseTransaction {
T: Send, T: Send,
E: std::error::Error + Send, E: std::error::Error + Send,
{ {
let res = callback(&self) let res = callback(&self).await.map_err(TransactionError::Transaction);
.await
.map_err(TransactionError::Transaction);
if res.is_ok() { if res.is_ok() {
self.commit() self.commit().await.map_err(TransactionError::Connection)?;
.await
.map_err(TransactionError::Connection)?;
} else { } else {
self.rollback() self.rollback()
.await .await
@ -144,9 +141,10 @@ impl DatabaseTransaction {
.await .await
.map_err(sqlx_error_to_query_err)? .map_err(sqlx_error_to_query_err)?
} }
//Should we do something for mocked connections?
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
InnerConnection::Mock(_) => {} InnerConnection::Mock(ref mut c) => {
c.commit();
}
} }
Ok(()) Ok(())
} }
@ -172,9 +170,10 @@ impl DatabaseTransaction {
.await .await
.map_err(sqlx_error_to_query_err)? .map_err(sqlx_error_to_query_err)?
} }
//Should we do something for mocked connections?
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
InnerConnection::Mock(_) => {} InnerConnection::Mock(ref mut c) => {
c.rollback();
}
} }
Ok(()) Ok(())
} }
@ -196,9 +195,10 @@ impl DatabaseTransaction {
InnerConnection::Sqlite(c) => { InnerConnection::Sqlite(c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c); <sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
} }
//Should we do something for mocked connections?
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
InnerConnection::Mock(_) => {} InnerConnection::Mock(c) => {
c.rollback();
}
} }
} else { } else {
//this should never happen //this should never happen
@ -338,10 +338,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
T: Send, T: Send,
E: std::error::Error + Send, E: std::error::Error + Send,
{ {
let transaction = self let transaction = self.begin().await.map_err(TransactionError::Connection)?;
.begin()
.await
.map_err(TransactionError::Connection)?;
transaction.run(_callback).await transaction.run(_callback).await
} }
} }

View File

@ -26,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug {
fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>; fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
fn begin(&mut self);
fn commit(&mut self);
fn rollback(&mut self);
fn drain_transaction_log(&mut self) -> Vec<MockTransaction>; fn drain_transaction_log(&mut self) -> Vec<MockTransaction>;
fn get_database_backend(&self) -> DbBackend; fn get_database_backend(&self) -> DbBackend;
@ -86,10 +92,14 @@ impl MockDatabaseConnection {
} }
} }
pub fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> { pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
&self.mocker &self.mocker
} }
pub fn get_database_backend(&self) -> DbBackend {
self.mocker.lock().unwrap().get_database_backend()
}
pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> { pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
@ -119,7 +129,15 @@ impl MockDatabaseConnection {
} }
} }
pub fn get_database_backend(&self) -> DbBackend { pub fn begin(&self) {
self.mocker.lock().unwrap().get_database_backend() self.mocker.lock().unwrap().begin()
}
pub fn commit(&self) {
self.mocker.lock().unwrap().commit()
}
pub fn rollback(&self) {
self.mocker.lock().unwrap().rollback()
} }
} }

View File

@ -1,7 +1,6 @@
pub mod common; pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext}; pub use common::{bakery_chain::*, setup::*, TestContext};
use futures::StreamExt;
pub use sea_orm::entity::*; pub use sea_orm::entity::*;
pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter};
@ -12,6 +11,8 @@ pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter};
feature = "sqlx-postgres" feature = "sqlx-postgres"
))] ))]
pub async fn stream() -> Result<(), DbErr> { pub async fn stream() -> Result<(), DbErr> {
use futures::StreamExt;
let ctx = TestContext::new("stream").await; let ctx = TestContext::new("stream").await;
let bakery = bakery::ActiveModel { let bakery = bakery::ActiveModel {