Basic MockTransaction implementation
TODO: nested transaction
This commit is contained in:
parent
8d0ba28b7e
commit
c7532bcc08
@ -9,6 +9,7 @@ use std::{collections::BTreeMap, sync::Arc};
|
||||
#[derive(Debug)]
|
||||
pub struct MockDatabase {
|
||||
db_backend: DbBackend,
|
||||
transaction: Option<OpenTransaction>,
|
||||
transaction_log: Vec<MockTransaction>,
|
||||
exec_results: Vec<MockExecResult>,
|
||||
query_results: Vec<Vec<MockRow>>,
|
||||
@ -29,6 +30,12 @@ pub trait IntoMockRow {
|
||||
fn into_mock_row(self) -> MockRow;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenTransaction {
|
||||
stmts: Vec<Statement>,
|
||||
transaction_depth: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct MockTransaction {
|
||||
stmts: Vec<Statement>,
|
||||
@ -38,6 +45,7 @@ impl MockDatabase {
|
||||
pub fn new(db_backend: DbBackend) -> Self {
|
||||
Self {
|
||||
db_backend,
|
||||
transaction: None,
|
||||
transaction_log: Vec::new(),
|
||||
exec_results: Vec::new(),
|
||||
query_results: Vec::new(),
|
||||
@ -67,7 +75,11 @@ impl MockDatabase {
|
||||
|
||||
impl MockDatabaseTrait for MockDatabase {
|
||||
fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> {
|
||||
if let Some(transaction) = &mut self.transaction {
|
||||
transaction.push(statement);
|
||||
} else {
|
||||
self.transaction_log.push(MockTransaction::one(statement));
|
||||
}
|
||||
if counter < self.exec_results.len() {
|
||||
Ok(ExecResult {
|
||||
result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])),
|
||||
@ -78,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase {
|
||||
}
|
||||
|
||||
fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
|
||||
if let Some(transaction) = &mut self.transaction {
|
||||
transaction.push(statement);
|
||||
} else {
|
||||
self.transaction_log.push(MockTransaction::one(statement));
|
||||
}
|
||||
if counter < self.query_results.len() {
|
||||
Ok(std::mem::take(&mut self.query_results[counter])
|
||||
.into_iter()
|
||||
@ -91,6 +107,32 @@ impl MockDatabaseTrait for MockDatabase {
|
||||
}
|
||||
}
|
||||
|
||||
fn begin(&mut self) {
|
||||
if self.transaction.is_some() {
|
||||
panic!("There is uncommitted transaction");
|
||||
} else {
|
||||
self.transaction = Some(OpenTransaction::init());
|
||||
}
|
||||
}
|
||||
|
||||
fn commit(&mut self) {
|
||||
if self.transaction.is_some() {
|
||||
let transaction = self.transaction.take().unwrap();
|
||||
self.transaction_log
|
||||
.push(transaction.into_mock_transaction());
|
||||
} else {
|
||||
panic!("There is no open transaction to commit");
|
||||
}
|
||||
}
|
||||
|
||||
fn rollback(&mut self) {
|
||||
if self.transaction.is_some() {
|
||||
self.transaction = None;
|
||||
} else {
|
||||
panic!("There is no open transaction to rollback");
|
||||
}
|
||||
}
|
||||
|
||||
fn drain_transaction_log(&mut self) -> Vec<MockTransaction> {
|
||||
std::mem::take(&mut self.transaction_log)
|
||||
}
|
||||
@ -174,3 +216,70 @@ impl MockTransaction {
|
||||
stmts.into_iter().map(Self::one).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenTransaction {
|
||||
fn init() -> Self {
|
||||
Self {
|
||||
stmts: Vec::new(),
|
||||
transaction_depth: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, stmt: Statement) {
|
||||
self.stmts.push(stmt);
|
||||
}
|
||||
|
||||
fn into_mock_transaction(self) -> MockTransaction {
|
||||
MockTransaction { stmts: self.stmts }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "mock")]
|
||||
mod tests {
|
||||
use crate::{
|
||||
entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, MockTransaction,
|
||||
Statement,
|
||||
};
|
||||
|
||||
#[smol_potat::test]
|
||||
async fn test_transaction_1() {
|
||||
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
|
||||
|
||||
db.transaction::<_, _, DbErr>(|txn| {
|
||||
Box::pin(async move {
|
||||
let _1 = cake::Entity::find().one(txn).await;
|
||||
let _2 = fruit::Entity::find().all(txn).await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _ = cake::Entity::find().all(&db).await;
|
||||
|
||||
assert_eq!(
|
||||
db.into_transaction_log(),
|
||||
vec![
|
||||
MockTransaction::many(vec![
|
||||
Statement::from_sql_and_values(
|
||||
DbBackend::Postgres,
|
||||
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
|
||||
vec![1u64.into()]
|
||||
),
|
||||
Statement::from_sql_and_values(
|
||||
DbBackend::Postgres,
|
||||
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
|
||||
vec![]
|
||||
),
|
||||
]),
|
||||
MockTransaction::from_sql_and_values(
|
||||
DbBackend::Postgres,
|
||||
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
|
||||
vec![]
|
||||
),
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -92,9 +92,10 @@ impl DatabaseTransaction {
|
||||
.await
|
||||
.map_err(sqlx_error_to_query_err)?
|
||||
}
|
||||
// should we do something for mocked connections?
|
||||
#[cfg(feature = "mock")]
|
||||
InnerConnection::Mock(_) => {}
|
||||
InnerConnection::Mock(ref mut c) => {
|
||||
c.begin();
|
||||
}
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
@ -108,13 +109,9 @@ impl DatabaseTransaction {
|
||||
T: Send,
|
||||
E: std::error::Error + Send,
|
||||
{
|
||||
let res = callback(&self)
|
||||
.await
|
||||
.map_err(TransactionError::Transaction);
|
||||
let res = callback(&self).await.map_err(TransactionError::Transaction);
|
||||
if res.is_ok() {
|
||||
self.commit()
|
||||
.await
|
||||
.map_err(TransactionError::Connection)?;
|
||||
self.commit().await.map_err(TransactionError::Connection)?;
|
||||
} else {
|
||||
self.rollback()
|
||||
.await
|
||||
@ -144,9 +141,10 @@ impl DatabaseTransaction {
|
||||
.await
|
||||
.map_err(sqlx_error_to_query_err)?
|
||||
}
|
||||
//Should we do something for mocked connections?
|
||||
#[cfg(feature = "mock")]
|
||||
InnerConnection::Mock(_) => {}
|
||||
InnerConnection::Mock(ref mut c) => {
|
||||
c.commit();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -172,9 +170,10 @@ impl DatabaseTransaction {
|
||||
.await
|
||||
.map_err(sqlx_error_to_query_err)?
|
||||
}
|
||||
//Should we do something for mocked connections?
|
||||
#[cfg(feature = "mock")]
|
||||
InnerConnection::Mock(_) => {}
|
||||
InnerConnection::Mock(ref mut c) => {
|
||||
c.rollback();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -196,9 +195,10 @@ impl DatabaseTransaction {
|
||||
InnerConnection::Sqlite(c) => {
|
||||
<sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
|
||||
}
|
||||
//Should we do something for mocked connections?
|
||||
#[cfg(feature = "mock")]
|
||||
InnerConnection::Mock(_) => {}
|
||||
InnerConnection::Mock(c) => {
|
||||
c.rollback();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//this should never happen
|
||||
@ -338,10 +338,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
|
||||
T: Send,
|
||||
E: std::error::Error + Send,
|
||||
{
|
||||
let transaction = self
|
||||
.begin()
|
||||
.await
|
||||
.map_err(TransactionError::Connection)?;
|
||||
let transaction = self.begin().await.map_err(TransactionError::Connection)?;
|
||||
transaction.run(_callback).await
|
||||
}
|
||||
}
|
||||
|
@ -26,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug {
|
||||
|
||||
fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
|
||||
|
||||
fn begin(&mut self);
|
||||
|
||||
fn commit(&mut self);
|
||||
|
||||
fn rollback(&mut self);
|
||||
|
||||
fn drain_transaction_log(&mut self) -> Vec<MockTransaction>;
|
||||
|
||||
fn get_database_backend(&self) -> DbBackend;
|
||||
@ -86,10 +92,14 @@ impl MockDatabaseConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
|
||||
pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
|
||||
&self.mocker
|
||||
}
|
||||
|
||||
pub fn get_database_backend(&self) -> DbBackend {
|
||||
self.mocker.lock().unwrap().get_database_backend()
|
||||
}
|
||||
|
||||
pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
|
||||
debug_print!("{}", statement);
|
||||
let counter = self.counter.fetch_add(1, Ordering::SeqCst);
|
||||
@ -119,7 +129,15 @@ impl MockDatabaseConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_database_backend(&self) -> DbBackend {
|
||||
self.mocker.lock().unwrap().get_database_backend()
|
||||
pub fn begin(&self) {
|
||||
self.mocker.lock().unwrap().begin()
|
||||
}
|
||||
|
||||
pub fn commit(&self) {
|
||||
self.mocker.lock().unwrap().commit()
|
||||
}
|
||||
|
||||
pub fn rollback(&self) {
|
||||
self.mocker.lock().unwrap().rollback()
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
pub mod common;
|
||||
|
||||
pub use common::{bakery_chain::*, setup::*, TestContext};
|
||||
use futures::StreamExt;
|
||||
pub use sea_orm::entity::*;
|
||||
pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter};
|
||||
|
||||
@ -12,6 +11,8 @@ pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter};
|
||||
feature = "sqlx-postgres"
|
||||
))]
|
||||
pub async fn stream() -> Result<(), DbErr> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let ctx = TestContext::new("stream").await;
|
||||
|
||||
let bakery = bakery::ActiveModel {
|
||||
|
Loading…
x
Reference in New Issue
Block a user