diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index f04b67c7..2f78bdd7 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -18,14 +18,7 @@ use std::sync::Arc; pub enum DatabaseConnection { /// Create a MYSQL database connection and pool #[cfg(feature = "sqlx-mysql")] - SqlxMySqlPoolConnection { - /// The SQLx MySQL pool - conn: crate::SqlxMySqlPoolConnection, - /// The MySQL version - version: String, - /// The flag indicating whether `RETURNING` syntax is supported - support_returning: bool, - }, + SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), /// Create a PostgreSQL database connection and pool #[cfg(feature = "sqlx-postgres")] SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), @@ -80,7 +73,7 @@ impl std::fmt::Debug for DatabaseConnection { "{}", match self { #[cfg(feature = "sqlx-mysql")] - Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection", + Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", #[cfg(feature = "sqlx-postgres")] Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", #[cfg(feature = "sqlx-sqlite")] @@ -100,7 +93,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql, + DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, #[cfg(feature = "sqlx-sqlite")] @@ -114,7 +107,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn execute(&self, stmt: Statement) -> Result { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -128,7 +121,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn query_one(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -142,7 +135,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn query_all(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -160,9 +153,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { Box::pin(async move { Ok(match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { - conn.stream(stmt).await? - } + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "sqlx-sqlite")] @@ -179,7 +170,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn begin(&self) -> Result { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await, + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, #[cfg(feature = "sqlx-sqlite")] @@ -205,9 +196,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { - conn.transaction(_callback).await - } + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { conn.transaction(_callback).await @@ -228,12 +217,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn returning_on_insert(&self) -> bool { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { - support_returning, .. - } => { + DatabaseConnection::SqlxMySqlPoolConnection(conn) => { // Supported if it's MariaDB on or after version 10.5.0 // Not supported in all MySQL versions - *support_returning + conn.support_returning } #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => { @@ -258,7 +245,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn returning_on_update(&self) -> bool { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { .. } => { + DatabaseConnection::SqlxMySqlPoolConnection(_) => { // Not supported in all MySQL & MariaDB versions false } @@ -310,7 +297,7 @@ impl DatabaseConnection { pub fn version(&self) -> String { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { version, .. } => version.to_string(), + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.version.to_string(), #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), #[cfg(feature = "sqlx-sqlite")] diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 727f9bc6..71573f0d 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -16,6 +16,7 @@ pub struct DatabaseTransaction { conn: Arc>, backend: DbBackend, open: bool, + support_returning: bool, } impl std::fmt::Debug for DatabaseTransaction { @@ -28,10 +29,12 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] pub(crate) async fn new_mysql( inner: PoolConnection, + support_returning: bool, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql, + support_returning, ) .await } @@ -43,6 +46,7 @@ impl DatabaseTransaction { Self::begin( Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres, + true, ) .await } @@ -54,6 +58,7 @@ impl DatabaseTransaction { Self::begin( Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite, + false, ) .await } @@ -63,17 +68,28 @@ impl DatabaseTransaction { inner: Arc, ) -> Result { let backend = inner.get_database_backend(); - Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + Self::begin( + Arc::new(Mutex::new(InnerConnection::Mock(inner))), + backend, + match backend { + DbBackend::MySql => false, + DbBackend::Postgres => true, + DbBackend::Sqlite => false, + }, + ) + .await } async fn begin( conn: Arc>, backend: DbBackend, + support_returning: bool, ) -> Result { let res = DatabaseTransaction { conn, backend, open: true, + support_returning, }; match *res.conn.lock().await { #[cfg(feature = "sqlx-mysql")] @@ -330,7 +346,8 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { } async fn begin(&self) -> Result { - DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await + DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning) + .await } /// Execute the function inside a transaction. @@ -349,13 +366,38 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { } fn returning_on_insert(&self) -> bool { - // FIXME: How? - false + match self.backend { + DbBackend::MySql => { + // Supported if it's MariaDB on or after version 10.5.0 + // Not supported in all MySQL versions + self.support_returning + } + DbBackend::Postgres => { + // Supported by all Postgres versions + true + } + DbBackend::Sqlite => { + // Supported by SQLite on or after version 3.35.0 (2021-03-12) + false + } + } } fn returning_on_update(&self) -> bool { - // FIXME: How? - false + match self.backend { + DbBackend::MySql => { + // Not supported in all MySQL & MariaDB versions + false + } + DbBackend::Postgres => { + // Supported by all Postgres versions + true + } + DbBackend::Sqlite => { + // Supported by SQLite on or after version 3.35.0 (2021-03-12) + false + } + } } } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 55f46a94..e31307ee 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -3,7 +3,7 @@ use std::{future::Future, pin::Pin}; use sqlx::{ mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow}, - MySql, MySqlPool, + MySql, MySqlPool, Row, }; sea_query::sea_query_driver_mysql!(); @@ -24,6 +24,8 @@ pub struct SqlxMySqlConnector; #[derive(Debug, Clone)] pub struct SqlxMySqlPoolConnection { pool: MySqlPool, + pub(crate) version: String, + pub(crate) support_returning: bool, } impl SqlxMySqlConnector { @@ -128,7 +130,7 @@ impl SqlxMySqlPoolConnection { /// Bundle a set of SQL statements that execute together. pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_mysql(conn).await + DatabaseTransaction::new_mysql(conn, self.support_returning).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -147,7 +149,7 @@ impl SqlxMySqlPoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_mysql(conn) + let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -184,45 +186,49 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq } async fn into_db_connection(pool: MySqlPool) -> Result { - let conn = SqlxMySqlPoolConnection { pool }; - let res = conn - .query_one(Statement::from_string( - DbBackend::MySql, - r#"SHOW VARIABLES LIKE "version""#.to_owned(), - )) - .await?; - let (version, support_returning) = if let Some(query_result) = res { - let version: String = query_result.try_get("", "Value")?; - let support_returning = if !version.contains("MariaDB") { - // This is MySQL - // Not supported in all MySQL versions - false - } else { - // This is MariaDB - let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); - let captures = regex.captures(&version).unwrap(); - macro_rules! parse_captures { - ( $idx: expr ) => { - captures.get($idx).map_or(0, |m| { - m.as_str() - .parse::() - .map_err(|e| DbErr::Conn(e.to_string())) - .unwrap() - }) - }; - } - let ver_major = parse_captures!(1); - let ver_minor = parse_captures!(2); - // Supported if it's MariaDB with version 10.5.0 or after - ver_major >= 10 && ver_minor >= 5 - }; - (version, support_returning) - } else { - return Err(DbErr::Conn("Fail to parse MySQL version".to_owned())); - }; - Ok(DatabaseConnection::SqlxMySqlPoolConnection { - conn, - version, - support_returning, - }) + let (version, support_returning) = parse_support_returning(&pool).await?; + Ok(DatabaseConnection::SqlxMySqlPoolConnection( + SqlxMySqlPoolConnection { + pool, + version, + support_returning, + }, + )) +} + +async fn parse_support_returning(pool: &MySqlPool) -> Result<(String, bool), DbErr> { + let stmt = Statement::from_string( + DbBackend::MySql, + r#"SHOW VARIABLES LIKE "version""#.to_owned(), + ); + let query = sqlx_query(&stmt); + let row = query + .fetch_one(pool) + .await + .map_err(sqlx_error_to_query_err)?; + let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?; + let support_returning = if !version.contains("MariaDB") { + // This is MySQL + // Not supported in all MySQL versions + false + } else { + // This is MariaDB + let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); + let captures = regex.captures(&version).unwrap(); + macro_rules! parse_captures { + ( $idx: expr ) => { + captures.get($idx).map_or(0, |m| { + m.as_str() + .parse::() + .map_err(|e| DbErr::Conn(e.to_string())) + .unwrap() + }) + }; + } + let ver_major = parse_captures!(1); + let ver_minor = parse_captures!(2); + // Supported if it's MariaDB with version 10.5.0 or after + ver_major >= 10 && ver_minor >= 5 + }; + Ok((version, support_returning)) }