This commit is contained in:
Billy Chan 2021-11-09 15:38:28 +08:00
parent 8020ae1209
commit 533c3cf175
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
3 changed files with 111 additions and 76 deletions

View File

@ -18,14 +18,7 @@ use std::sync::Arc;
pub enum DatabaseConnection {
/// Create a MYSQL database connection and pool
#[cfg(feature = "sqlx-mysql")]
SqlxMySqlPoolConnection {
/// The SQLx MySQL pool
conn: crate::SqlxMySqlPoolConnection,
/// The MySQL version
version: String,
/// The flag indicating whether `RETURNING` syntax is supported
support_returning: bool,
},
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
/// Create a PostgreSQL database connection and pool
#[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
@ -80,7 +73,7 @@ impl std::fmt::Debug for DatabaseConnection {
"{}",
match self {
#[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection",
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")]
@ -100,7 +93,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql,
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")]
@ -114,7 +107,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
@ -128,7 +121,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
@ -142,7 +135,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
@ -160,9 +153,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.stream(stmt).await?
}
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")]
@ -179,7 +170,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await,
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")]
@ -205,9 +196,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.transaction(_callback).await
}
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback).await
@ -228,12 +217,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn returning_on_insert(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection {
support_returning, ..
} => {
DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
// Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions
*support_returning
conn.support_returning
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
@ -258,7 +245,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn returning_on_update(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => {
DatabaseConnection::SqlxMySqlPoolConnection(_) => {
// Not supported in all MySQL & MariaDB versions
false
}
@ -310,7 +297,7 @@ impl DatabaseConnection {
pub fn version(&self) -> String {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { version, .. } => version.to_string(),
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.version.to_string(),
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(),
#[cfg(feature = "sqlx-sqlite")]

View File

@ -16,6 +16,7 @@ pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
support_returning: bool,
}
impl std::fmt::Debug for DatabaseTransaction {
@ -28,10 +29,12 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql,
support_returning,
)
.await
}
@ -43,6 +46,7 @@ impl DatabaseTransaction {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres,
true,
)
.await
}
@ -54,6 +58,7 @@ impl DatabaseTransaction {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite,
false,
)
.await
}
@ -63,17 +68,28 @@ impl DatabaseTransaction {
inner: Arc<crate::MockDatabaseConnection>,
) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
Self::begin(
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
backend,
match backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
)
.await
}
async fn begin(
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
support_returning,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
@ -330,7 +346,8 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning)
.await
}
/// Execute the function inside a transaction.
@ -349,14 +366,39 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
}
fn returning_on_insert(&self) -> bool {
// FIXME: How?
match self.backend {
DbBackend::MySql => {
// Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions
self.support_returning
}
DbBackend::Postgres => {
// Supported by all Postgres versions
true
}
DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
false
}
}
}
fn returning_on_update(&self) -> bool {
// FIXME: How?
match self.backend {
DbBackend::MySql => {
// Not supported in all MySQL & MariaDB versions
false
}
DbBackend::Postgres => {
// Supported by all Postgres versions
true
}
DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
false
}
}
}
}
/// Defines errors for handling transaction failures

View File

@ -3,7 +3,7 @@ use std::{future::Future, pin::Pin};
use sqlx::{
mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow},
MySql, MySqlPool,
MySql, MySqlPool, Row,
};
sea_query::sea_query_driver_mysql!();
@ -24,6 +24,8 @@ pub struct SqlxMySqlConnector;
#[derive(Debug, Clone)]
pub struct SqlxMySqlPoolConnection {
pool: MySqlPool,
pub(crate) version: String,
pub(crate) support_returning: bool,
}
impl SqlxMySqlConnector {
@ -128,7 +130,7 @@ impl SqlxMySqlPoolConnection {
/// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn).await
DatabaseTransaction::new_mysql(conn, self.support_returning).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
@ -147,7 +149,7 @@ impl SqlxMySqlPoolConnection {
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn)
let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
@ -184,15 +186,27 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
}
async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let conn = SqlxMySqlPoolConnection { pool };
let res = conn
.query_one(Statement::from_string(
let (version, support_returning) = parse_support_returning(&pool).await?;
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection {
pool,
version,
support_returning,
},
))
}
async fn parse_support_returning(pool: &MySqlPool) -> Result<(String, bool), DbErr> {
let stmt = Statement::from_string(
DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
))
.await?;
let (version, support_returning) = if let Some(query_result) = res {
let version: String = query_result.try_get("", "Value")?;
);
let query = sqlx_query(&stmt);
let row = query
.fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?;
let support_returning = if !version.contains("MariaDB") {
// This is MySQL
// Not supported in all MySQL versions
@ -216,13 +230,5 @@ async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr
// Supported if it's MariaDB with version 10.5.0 or after
ver_major >= 10 && ver_minor >= 5
};
(version, support_returning)
} else {
return Err(DbErr::Conn("Fail to parse MySQL version".to_owned()));
};
Ok(DatabaseConnection::SqlxMySqlPoolConnection {
conn,
version,
support_returning,
})
Ok((version, support_returning))
}