use crate::{ error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, }; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; use std::{future::Future, pin::Pin}; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; #[cfg(feature = "mock")] use std::sync::Arc; #[cfg_attr(not(feature = "mock"), derive(Clone))] pub enum DatabaseConnection { #[cfg(feature = "sqlx-mysql")] SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), #[cfg(feature = "sqlx-postgres")] SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] MockDatabaseConnection(Arc), Disconnected, } pub type DbConn = DatabaseConnection; #[derive(Debug, Copy, Clone, PartialEq)] pub enum DatabaseBackend { MySql, Postgres, Sqlite, } pub type DbBackend = DatabaseBackend; pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-mysql")] MySql(PoolConnection), #[cfg(feature = "sqlx-postgres")] Postgres(PoolConnection), #[cfg(feature = "sqlx-sqlite")] Sqlite(PoolConnection), #[cfg(feature = "mock")] Mock(std::sync::Arc), } impl Default for DatabaseConnection { fn default() -> Self { Self::Disconnected } } impl std::fmt::Debug for DatabaseConnection { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!( f, "{}", match self { #[cfg(feature = "sqlx-mysql")] Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", #[cfg(feature = "sqlx-postgres")] Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", #[cfg(feature = "sqlx-sqlite")] Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", #[cfg(feature = "mock")] Self::MockDatabaseConnection(_) => "MockDatabaseConnection", Self::Disconnected => "Disconnected", } ) } } #[async_trait::async_trait] impl<'a> ConnectionTrait<'a> for DatabaseConnection { type Stream = crate::QueryStream; fn 
get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), DatabaseConnection::Disconnected => panic!("Disconnected"), } } async fn execute(&self, stmt: Statement) -> Result { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } async fn query_one(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } async fn query_all(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) 
=> conn.query_all(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } fn stream( &'a self, stmt: Statement, ) -> Pin> + 'a>> { Box::pin(async move { Ok(match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { crate::QueryStream::from((Arc::clone(conn), stmt)) } DatabaseConnection::Disconnected => panic!("Disconnected"), }) }) } async fn begin(&self) -> Result { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { DatabaseTransaction::new_mock(Arc::clone(conn)).await } DatabaseConnection::Disconnected => panic!("Disconnected"), } } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. 
async fn transaction(&self, _callback: F) -> Result> where F: for<'c> FnOnce( &'c DatabaseTransaction, ) -> Pin> + Send + 'c>> + Send, T: Send, E: std::error::Error + Send, { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { conn.transaction(_callback).await } #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await } DatabaseConnection::Disconnected => panic!("Disconnected"), } } #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { match self { DatabaseConnection::MockDatabaseConnection(_) => true, _ => false, } } } #[cfg(feature = "mock")] impl DatabaseConnection { pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { match self { DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, _ => panic!("not mock connection"), } } pub fn into_transaction_log(self) -> Vec { let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); mocker.drain_transaction_log() } } impl DbBackend { pub fn is_prefix_of(self, base_url: &str) -> bool { match self { Self::Postgres => { base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") } Self::MySql => base_url.starts_with("mysql://"), Self::Sqlite => base_url.starts_with("sqlite:"), } } pub fn build(&self, statement: &S) -> Statement where S: StatementBuilder, { statement.build(self) } pub fn get_query_builder(&self) -> Box { match self { Self::MySql => Box::new(MysqlQueryBuilder), Self::Postgres => Box::new(PostgresQueryBuilder), Self::Sqlite => Box::new(SqliteQueryBuilder), } } } #[cfg(test)] 
mod tests {
    use crate::DatabaseConnection;

    /// Compile-time check that `DatabaseConnection` is `Send + Sync`; if either bound stops
    /// holding, this test fails to compile.
    #[test]
    fn assert_database_connection_traits() {
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<DatabaseConnection>();
    }
}