diff --git a/src/database/connection.rs b/src/database/connection.rs index 4440ad84..17cc9d0f 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,4 +1,4 @@ -use crate::{ExecErr, ExecResult, QueryErr, QueryResult, Statement}; +use crate::{ExecErr, ExecResult, MockDatabaseConnection, QueryErr, QueryResult, Statement}; use sea_query::{ DeleteStatement, InsertStatement, MysqlQueryBuilder, PostgresQueryBuilder, SelectStatement, UpdateStatement, @@ -91,6 +91,13 @@ impl DatabaseConnection { DatabaseConnection::Disconnected => panic!("Disconnected"), } } + + pub fn as_mock_connection(&self) -> &MockDatabaseConnection { + match self { + DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, + _ => panic!("not mock connection"), + } + } } impl QueryBuilderBackend { diff --git a/src/database/mock.rs b/src/database/mock.rs index a8ba80e6..5d197f4d 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -51,10 +51,6 @@ impl MockDatabase { } self } - - pub fn into_transaction_log(self) -> Vec { - self.transaction_log - } } impl MockDatabaseTrait for MockDatabase { @@ -87,18 +83,9 @@ impl MockDatabaseTrait for MockDatabase { } } - fn into_transaction_log(&mut self) -> Vec { + fn drain_transaction_log(&mut self) -> Vec { std::mem::take(&mut self.transaction_log) } - - fn assert_transaction_log(&mut self, stmts: Vec) { - for stmt in stmts.iter() { - assert!(!self.transaction_log.is_empty()); - let log = self.transaction_log.first().unwrap(); - assert_eq!(log.to_string(), stmt.to_string()); - self.transaction_log = self.transaction_log.drain(1..).collect(); - } - } } impl MockRow { diff --git a/src/database/statement.rs b/src/database/statement.rs index 61604b38..f27578bd 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -1,7 +1,7 @@ use sea_query::{inject_parameters, MySqlQueryBuilder, Values}; use std::fmt; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Statement { pub sql: String, pub values: Option, diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 8be8478b..39cec2f9 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -10,8 +10,8 @@ use std::sync::{ pub struct MockDatabaseConnector; pub struct MockDatabaseConnection { - pub(crate) counter: AtomicUsize, - pub(crate) mocker: Mutex>, + counter: AtomicUsize, + mocker: Mutex>, } pub trait MockDatabaseTrait: Send { @@ -19,9 +19,7 @@ pub trait MockDatabaseTrait: Send { fn query(&mut self, counter: usize, stmt: Statement) -> Result, QueryErr>; - fn into_transaction_log(&mut self) -> Vec; - - fn assert_transaction_log(&mut self, stmts: Vec); + fn drain_transaction_log(&mut self) -> Vec; } impl MockDatabaseConnector { @@ -46,9 +44,11 @@ impl MockDatabaseConnection { mocker: Mutex::new(Box::new(m)), } } -} -impl MockDatabaseConnection { + pub fn get_mocker_mutex(&self) -> &Mutex> { + &self.mocker + } + pub async fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index b159ee7b..19654d63 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -104,7 +104,7 @@ where #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; - use crate::tests_cfg::{util::*, *}; + use crate::tests_cfg::*; use crate::{Database, MockDatabase, QueryErr}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; @@ -174,8 +174,13 @@ mod tests { query_builder.build_select_statement(select.clone().offset(2).limit(2)), query_builder.build_select_statement(select.clone().offset(4).limit(2)), ]; - let mut mocker = get_mock_db_connection(&db).mocker.lock().unwrap(); - mocker.assert_transaction_log(stmts); + let mut mocker = db + .get_connection() + .as_mock_connection() + .get_mocker_mutex() + .lock() + .unwrap(); + assert_eq!(mocker.drain_transaction_log(), stmts); Ok(()) } @@ -209,8 +214,13 @@ mod tests { query_builder.build_select_statement(select.clone().offset(2).limit(2)), query_builder.build_select_statement(select.clone().offset(4).limit(2)), ]; - let mut mocker = get_mock_db_connection(&db).mocker.lock().unwrap(); - mocker.assert_transaction_log(stmts); + let mut mocker = db + .get_connection() + .as_mock_connection() + .get_mocker_mutex() + .lock() + .unwrap(); + assert_eq!(mocker.drain_transaction_log(), stmts); Ok(()) } @@ -242,8 +252,13 @@ mod tests { let query_builder = db.get_query_builder_backend(); let stmts = vec![query_builder.build_select_statement(&select)]; - let mut mocker = get_mock_db_connection(&db).mocker.lock().unwrap(); - mocker.assert_transaction_log(stmts); + let mut mocker = db + .get_connection() + .as_mock_connection() + .get_mocker_mutex() + .lock() + .unwrap(); + assert_eq!(mocker.drain_transaction_log(), stmts); Ok(()) } @@ -295,8 +310,13 @@ mod tests { query_builder.build_select_statement(select.clone().offset(2).limit(2)), query_builder.build_select_statement(select.clone().offset(4).limit(2)), ]; - let mut mocker = get_mock_db_connection(&db).mocker.lock().unwrap(); - mocker.assert_transaction_log(stmts); + let mut mocker = db + .get_connection() + .as_mock_connection() + .get_mocker_mutex() + .lock() + .unwrap(); + assert_eq!(mocker.drain_transaction_log(), stmts); Ok(()) } @@ -328,9 +348,13 @@ mod tests { query_builder.build_select_statement(select.clone().offset(2).limit(2)), query_builder.build_select_statement(select.clone().offset(4).limit(2)), ]; - let mut mocker = get_mock_db_connection(&db).mocker.lock().unwrap(); - mocker.assert_transaction_log(stmts[0..1].to_vec()); - mocker.assert_transaction_log(stmts[1..].to_vec()); + let mut mocker = db + .get_connection() + .as_mock_connection() + .get_mocker_mutex() + .lock() + .unwrap(); + assert_eq!(mocker.drain_transaction_log(), stmts); Ok(()) } diff --git a/src/tests_cfg/util.rs b/src/tests_cfg/util.rs index 1927e9ad..177fdf14 100644 --- a/src/tests_cfg/util.rs +++ b/src/tests_cfg/util.rs @@ -1,6 +1,4 @@ -use crate::{ - tests_cfg::*, Database, DatabaseConnection, IntoMockRow, MockDatabaseConnection, MockRow, -}; +use crate::{tests_cfg::*, IntoMockRow, MockRow}; use sea_query::Value; impl From for MockRow { @@ -43,10 +41,3 @@ impl From for MockRow { map.into_mock_row() } } - -pub fn get_mock_db_connection(db: &Database) -> &MockDatabaseConnection { - match db.get_connection() { - DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, - _ => unreachable!(), - } -}