diff --git a/Cargo.toml b/Cargo.toml index db5abc69..32589dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } ouroboros = "0.11" url = "^2.2" +regex = "^1" [dev-dependencies] smol = { version = "^1.2" } diff --git a/src/database/connection.rs b/src/database/connection.rs index be47f9da..e06c6e57 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -45,6 +45,9 @@ pub trait ConnectionTrait<'a>: Sync { T: Send, E: std::error::Error + Send; + /// Check if the connection supports `RETURNING` syntax + fn support_returning(&self) -> bool; + /// Check if the connection is a test connection for the Mock database fn is_mock_connection(&self) -> bool { false diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 681903dd..9038dc20 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -18,7 +18,12 @@ use std::sync::Arc; pub enum DatabaseConnection { /// Create a MYSQL database connection and pool #[cfg(feature = "sqlx-mysql")] - SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), + SqlxMySqlPoolConnection { + /// A SQLx MySQL pool + conn: crate::SqlxMySqlPoolConnection, + /// A flag indicating whether `RETURNING` syntax is supported + support_returning: bool, + }, /// Create a PostgreSQL database connection and pool #[cfg(feature = "sqlx-postgres")] SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), @@ -73,7 +78,7 @@ impl std::fmt::Debug for DatabaseConnection { "{}", match self { #[cfg(feature = "sqlx-mysql")] - Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", + Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection", #[cfg(feature = "sqlx-postgres")] Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", #[cfg(feature = "sqlx-sqlite")] @@ -93,7 +98,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, + DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, #[cfg(feature = "sqlx-sqlite")] @@ -107,7 +112,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn execute(&self, stmt: Statement) -> Result { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -121,7 +126,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn query_one(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -135,7 +140,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn query_all(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "sqlx-sqlite")] @@ -153,7 +158,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { Box::pin(async move { Ok(match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { + conn.stream(stmt).await? + } #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "sqlx-sqlite")] @@ -170,7 +177,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { async fn begin(&self) -> Result { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, #[cfg(feature = "sqlx-sqlite")] @@ -196,7 +203,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { + conn.transaction(_callback).await + } #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { conn.transaction(_callback).await @@ -214,6 +223,24 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { } } + fn support_returning(&self) -> bool { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection { .. } => false, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => true, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(_) => false, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { + DbBackend::MySql => false, + DbBackend::Postgres => true, + DbBackend::Sqlite => false, + }, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { matches!(self, DatabaseConnection::MockDatabaseConnection(_)) diff --git a/src/database/transaction.rs b/src/database/transaction.rs index f4a1b678..77394acd 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -347,6 +347,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let transaction = self.begin().await.map_err(TransactionError::Connection)?; transaction.run(_callback).await } + + fn support_returning(&self) -> bool { + match self.backend { + DbBackend::MySql => false, + DbBackend::Postgres => true, + DbBackend::Sqlite => false, + } + } } /// Defines errors for handling transaction failures diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index b2b89c68..b8803edb 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,3 +1,4 @@ +use regex::Regex; use std::{future::Future, pin::Pin}; use sqlx::{ @@ -10,7 +11,7 @@ use sea_query_driver_mysql::bind_query; use crate::{ debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, - QueryStream, Statement, TransactionError, + DbBackend, QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -42,9 +43,7 @@ impl SqlxMySqlConnector { opt.disable_statement_logging(); } if let Ok(pool) = options.pool_options().connect_with(opt).await { - Ok(DatabaseConnection::SqlxMySqlPoolConnection( - SqlxMySqlPoolConnection { pool }, - )) + into_db_connection(pool).await } else { Err(DbErr::Conn("Failed to connect.".to_owned())) } @@ -53,8 +52,8 @@ impl SqlxMySqlConnector { impl SqlxMySqlConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] - pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection { - DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool }) + pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result { + into_db_connection(pool).await } } @@ -183,3 +182,43 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq } query } + +async fn into_db_connection(pool: MySqlPool) -> Result { + let conn = SqlxMySqlPoolConnection { pool }; + let res = conn + .query_one(Statement::from_string( + DbBackend::MySql, + r#"SHOW VARIABLES LIKE "version""#.to_owned(), + )) + .await?; + let support_returning = if let Some(query_result) = res { + let version: String = query_result.try_get("", "Value")?; + if !version.contains("MariaDB") { + // This is MySQL + false + } else { + // This is MariaDB + let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); + let captures = regex.captures(&version).unwrap(); + macro_rules! parse_captures { + ( $idx: expr ) => { + captures.get($idx).map_or(0, |m| { + m.as_str() + .parse::() + .map_err(|e| DbErr::Conn(e.to_string())) + .unwrap() + }) + }; + } + let ver_major = parse_captures!(1); + let ver_minor = parse_captures!(2); + ver_major >= 10 && ver_minor >= 5 + } + } else { + return Err(DbErr::Conn("Fail to parse MySQL version".to_owned())); + }; + Ok(DatabaseConnection::SqlxMySqlPoolConnection { + conn, + support_returning, + }) +} diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 6117e782..9f371373 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,6 +1,6 @@ use crate::{ - error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, IntoActiveModel, - Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64, + error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, IntoActiveModel, Iterable, + PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64, }; use sea_query::{FromValueTuple, Iden, InsertStatement, IntoColumnRef, Returning, ValueTuple}; use std::{future::Future, marker::PhantomData}; @@ -39,9 +39,7 @@ where { // so that self is dropped before entering await let mut query = self.query; - if db.get_database_backend() == DbBackend::Postgres - && ::PrimaryKey::iter().count() > 0 - { + if db.support_returning() && ::PrimaryKey::iter().count() > 0 { query.returning(Returning::Columns( ::PrimaryKey::iter() .map(|c| c.into_column_ref()) @@ -113,15 +111,15 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id_opt = match db.get_database_backend() { - DbBackend::Postgres => { + let last_insert_id_opt = match db.support_returning() { + true => { let cols = PrimaryKey::::iter() .map(|col| col.to_string()) .collect::>(); let res = db.query_one(statement).await?.unwrap(); res.try_get_many("", cols.as_ref()).ok() } - _ => { + false => { let last_insert_id = db.execute(statement).await?.last_insert_id(); ValueTypeOf::::try_from_u64(last_insert_id).ok() } @@ -147,8 +145,8 @@ where A: ActiveModelTrait, { let db_backend = db.get_database_backend(); - let found = match db_backend { - DbBackend::Postgres => { + let found = match db.support_returning() { + true => { insert_statement.returning(Returning::Columns( ::Column::iter() .map(|c| c.into_column_ref()) @@ -160,7 +158,7 @@ where .one(db) .await? } - _ => { + false => { let insert_res = exec_insert::(primary_key, db_backend.build(&insert_statement), db).await?; ::find_by_id(insert_res.last_insert_id) diff --git a/src/executor/update.rs b/src/executor/update.rs index c16d4644..f83e8efb 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, IntoActiveModel, Iterable, + error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, IntoActiveModel, Iterable, SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne, }; use sea_query::{FromValueTuple, IntoColumnRef, Returning, UpdateStatement}; @@ -90,14 +90,14 @@ where A: ActiveModelTrait, C: ConnectionTrait<'a>, { - let db_backend = db.get_database_backend(); - match db_backend { - DbBackend::Postgres => { + match db.support_returning() { + true => { query.returning(Returning::Columns( ::Column::iter() .map(|c| c.into_column_ref()) .collect(), )); + let db_backend = db.get_database_backend(); let found: Option<::Model> = SelectorRaw::::Model>>::from_statement( db_backend.build(&query), @@ -112,7 +112,7 @@ where )), } } - _ => { + false => { // If we updating a row that does not exist then an error will be thrown here. Updater::new(query).check_record_exists().exec(db).await?; let primary_key_value = match model.get_primary_key_value() {