This commit is contained in:
Billy Chan 2021-11-09 15:38:28 +08:00
parent 8020ae1209
commit 533c3cf175
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
3 changed files with 111 additions and 76 deletions

View File

@ -18,14 +18,7 @@ use std::sync::Arc;
pub enum DatabaseConnection { pub enum DatabaseConnection {
/// Create a MYSQL database connection and pool /// Create a MYSQL database connection and pool
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
SqlxMySqlPoolConnection { SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
/// The SQLx MySQL pool
conn: crate::SqlxMySqlPoolConnection,
/// The MySQL version
version: String,
/// The flag indicating whether `RETURNING` syntax is supported
support_returning: bool,
},
/// Create a PostgreSQL database connection and pool /// Create a PostgreSQL database connection and pool
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
@ -80,7 +73,7 @@ impl std::fmt::Debug for DatabaseConnection {
"{}", "{}",
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection", Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -100,7 +93,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn get_database_backend(&self) -> DbBackend { fn get_database_backend(&self) -> DbBackend {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql, DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -114,7 +107,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> { async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -128,7 +121,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> { async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -142,7 +135,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> { async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -160,9 +153,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
Box::pin(async move { Box::pin(async move {
Ok(match self { Ok(match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
conn.stream(stmt).await?
}
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -179,7 +170,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> { async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -205,9 +196,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
{ {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => { DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
conn.transaction(_callback).await
}
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => { DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback).await conn.transaction(_callback).await
@ -228,12 +217,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn returning_on_insert(&self) -> bool { fn returning_on_insert(&self) -> bool {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
support_returning, ..
} => {
// Supported if it's MariaDB on or after version 10.5.0 // Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions // Not supported in all MySQL versions
*support_returning conn.support_returning
} }
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => { DatabaseConnection::SqlxPostgresPoolConnection(_) => {
@ -258,7 +245,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn returning_on_update(&self) -> bool { fn returning_on_update(&self) -> bool {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => { DatabaseConnection::SqlxMySqlPoolConnection(_) => {
// Not supported in all MySQL & MariaDB versions // Not supported in all MySQL & MariaDB versions
false false
} }
@ -310,7 +297,7 @@ impl DatabaseConnection {
pub fn version(&self) -> String { pub fn version(&self) -> String {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { version, .. } => version.to_string(), DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.version.to_string(),
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(),
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]

View File

@ -16,6 +16,7 @@ pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>, conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend, backend: DbBackend,
open: bool, open: bool,
support_returning: bool,
} }
impl std::fmt::Debug for DatabaseTransaction { impl std::fmt::Debug for DatabaseTransaction {
@ -28,10 +29,12 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql( pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>, inner: PoolConnection<sqlx::MySql>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> { ) -> Result<DatabaseTransaction, DbErr> {
Self::begin( Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))), Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql, DbBackend::MySql,
support_returning,
) )
.await .await
} }
@ -43,6 +46,7 @@ impl DatabaseTransaction {
Self::begin( Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))), Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres, DbBackend::Postgres,
true,
) )
.await .await
} }
@ -54,6 +58,7 @@ impl DatabaseTransaction {
Self::begin( Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite, DbBackend::Sqlite,
false,
) )
.await .await
} }
@ -63,17 +68,28 @@ impl DatabaseTransaction {
inner: Arc<crate::MockDatabaseConnection>, inner: Arc<crate::MockDatabaseConnection>,
) -> Result<DatabaseTransaction, DbErr> { ) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend(); let backend = inner.get_database_backend();
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await Self::begin(
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
backend,
match backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
)
.await
} }
async fn begin( async fn begin(
conn: Arc<Mutex<InnerConnection>>, conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend, backend: DbBackend,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> { ) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction { let res = DatabaseTransaction {
conn, conn,
backend, backend,
open: true, open: true,
support_returning,
}; };
match *res.conn.lock().await { match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
@ -330,7 +346,8 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
} }
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> { async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning)
.await
} }
/// Execute the function inside a transaction. /// Execute the function inside a transaction.
@ -349,14 +366,39 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
} }
fn returning_on_insert(&self) -> bool { fn returning_on_insert(&self) -> bool {
// FIXME: How? match self.backend {
DbBackend::MySql => {
// Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions
self.support_returning
}
DbBackend::Postgres => {
// Supported by all Postgres versions
true
}
DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
false false
} }
}
}
fn returning_on_update(&self) -> bool { fn returning_on_update(&self) -> bool {
// FIXME: How? match self.backend {
DbBackend::MySql => {
// Not supported in all MySQL & MariaDB versions
false false
} }
DbBackend::Postgres => {
// Supported by all Postgres versions
true
}
DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
false
}
}
}
} }
/// Defines errors for handling transaction failures /// Defines errors for handling transaction failures

View File

@ -3,7 +3,7 @@ use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow}, mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow},
MySql, MySqlPool, MySql, MySqlPool, Row,
}; };
sea_query::sea_query_driver_mysql!(); sea_query::sea_query_driver_mysql!();
@ -24,6 +24,8 @@ pub struct SqlxMySqlConnector;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SqlxMySqlPoolConnection { pub struct SqlxMySqlPoolConnection {
pool: MySqlPool, pool: MySqlPool,
pub(crate) version: String,
pub(crate) support_returning: bool,
} }
impl SqlxMySqlConnector { impl SqlxMySqlConnector {
@ -128,7 +130,7 @@ impl SqlxMySqlPoolConnection {
/// Bundle a set of SQL statements that execute together. /// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> { pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await { if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn).await DatabaseTransaction::new_mysql(conn, self.support_returning).await
} else { } else {
Err(DbErr::Query( Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(), "Failed to acquire connection from pool.".to_owned(),
@ -147,7 +149,7 @@ impl SqlxMySqlPoolConnection {
E: std::error::Error + Send, E: std::error::Error + Send,
{ {
if let Ok(conn) = self.pool.acquire().await { if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn) let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning)
.await .await
.map_err(|e| TransactionError::Connection(e))?; .map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await transaction.run(callback).await
@ -184,15 +186,27 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
} }
async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> { async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let conn = SqlxMySqlPoolConnection { pool }; let (version, support_returning) = parse_support_returning(&pool).await?;
let res = conn Ok(DatabaseConnection::SqlxMySqlPoolConnection(
.query_one(Statement::from_string( SqlxMySqlPoolConnection {
pool,
version,
support_returning,
},
))
}
async fn parse_support_returning(pool: &MySqlPool) -> Result<(String, bool), DbErr> {
let stmt = Statement::from_string(
DbBackend::MySql, DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(), r#"SHOW VARIABLES LIKE "version""#.to_owned(),
)) );
.await?; let query = sqlx_query(&stmt);
let (version, support_returning) = if let Some(query_result) = res { let row = query
let version: String = query_result.try_get("", "Value")?; .fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?;
let support_returning = if !version.contains("MariaDB") { let support_returning = if !version.contains("MariaDB") {
// This is MySQL // This is MySQL
// Not supported in all MySQL versions // Not supported in all MySQL versions
@ -216,13 +230,5 @@ async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr
// Supported if it's MariaDB with version 10.5.0 or after // Supported if it's MariaDB with version 10.5.0 or after
ver_major >= 10 && ver_minor >= 5 ver_major >= 10 && ver_minor >= 5
}; };
(version, support_returning) Ok((version, support_returning))
} else {
return Err(DbErr::Conn("Fail to parse MySQL version".to_owned()));
};
Ok(DatabaseConnection::SqlxMySqlPoolConnection {
conn,
version,
support_returning,
})
} }