Try
This commit is contained in:
parent
8020ae1209
commit
533c3cf175
@ -18,14 +18,7 @@ use std::sync::Arc;
|
||||
pub enum DatabaseConnection {
|
||||
/// Create a MYSQL database connection and pool
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
SqlxMySqlPoolConnection {
|
||||
/// The SQLx MySQL pool
|
||||
conn: crate::SqlxMySqlPoolConnection,
|
||||
/// The MySQL version
|
||||
version: String,
|
||||
/// The flag indicating whether `RETURNING` syntax is supported
|
||||
support_returning: bool,
|
||||
},
|
||||
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
|
||||
/// Create a PostgreSQL database connection and pool
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
|
||||
@ -80,7 +73,7 @@ impl std::fmt::Debug for DatabaseConnection {
|
||||
"{}",
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection",
|
||||
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -100,7 +93,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
fn get_database_backend(&self) -> DbBackend {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql,
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -114,7 +107,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await,
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -128,7 +121,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await,
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -142,7 +135,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await,
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -160,9 +153,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
Box::pin(async move {
|
||||
Ok(match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
|
||||
conn.stream(stmt).await?
|
||||
}
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -179,7 +170,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await,
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
@ -205,9 +196,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
{
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
|
||||
conn.transaction(_callback).await
|
||||
}
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
|
||||
conn.transaction(_callback).await
|
||||
@ -228,12 +217,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
fn returning_on_insert(&self) -> bool {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection {
|
||||
support_returning, ..
|
||||
} => {
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
|
||||
// Supported if it's MariaDB on or after version 10.5.0
|
||||
// Not supported in all MySQL versions
|
||||
*support_returning
|
||||
conn.support_returning
|
||||
}
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
|
||||
@ -258,7 +245,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
|
||||
fn returning_on_update(&self) -> bool {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { .. } => {
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(_) => {
|
||||
// Not supported in all MySQL & MariaDB versions
|
||||
false
|
||||
}
|
||||
@ -310,7 +297,7 @@ impl DatabaseConnection {
|
||||
pub fn version(&self) -> String {
|
||||
match self {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
DatabaseConnection::SqlxMySqlPoolConnection { version, .. } => version.to_string(),
|
||||
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.version.to_string(),
|
||||
#[cfg(feature = "sqlx-postgres")]
|
||||
DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(),
|
||||
#[cfg(feature = "sqlx-sqlite")]
|
||||
|
@ -16,6 +16,7 @@ pub struct DatabaseTransaction {
|
||||
conn: Arc<Mutex<InnerConnection>>,
|
||||
backend: DbBackend,
|
||||
open: bool,
|
||||
support_returning: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DatabaseTransaction {
|
||||
@ -28,10 +29,12 @@ impl DatabaseTransaction {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
pub(crate) async fn new_mysql(
|
||||
inner: PoolConnection<sqlx::MySql>,
|
||||
support_returning: bool,
|
||||
) -> Result<DatabaseTransaction, DbErr> {
|
||||
Self::begin(
|
||||
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
|
||||
DbBackend::MySql,
|
||||
support_returning,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@ -43,6 +46,7 @@ impl DatabaseTransaction {
|
||||
Self::begin(
|
||||
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
|
||||
DbBackend::Postgres,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@ -54,6 +58,7 @@ impl DatabaseTransaction {
|
||||
Self::begin(
|
||||
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
|
||||
DbBackend::Sqlite,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@ -63,17 +68,28 @@ impl DatabaseTransaction {
|
||||
inner: Arc<crate::MockDatabaseConnection>,
|
||||
) -> Result<DatabaseTransaction, DbErr> {
|
||||
let backend = inner.get_database_backend();
|
||||
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
|
||||
Self::begin(
|
||||
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
|
||||
backend,
|
||||
match backend {
|
||||
DbBackend::MySql => false,
|
||||
DbBackend::Postgres => true,
|
||||
DbBackend::Sqlite => false,
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn begin(
|
||||
conn: Arc<Mutex<InnerConnection>>,
|
||||
backend: DbBackend,
|
||||
support_returning: bool,
|
||||
) -> Result<DatabaseTransaction, DbErr> {
|
||||
let res = DatabaseTransaction {
|
||||
conn,
|
||||
backend,
|
||||
open: true,
|
||||
support_returning,
|
||||
};
|
||||
match *res.conn.lock().await {
|
||||
#[cfg(feature = "sqlx-mysql")]
|
||||
@ -330,7 +346,8 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
|
||||
}
|
||||
|
||||
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
|
||||
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await
|
||||
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute the function inside a transaction.
|
||||
@ -349,14 +366,39 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
|
||||
}
|
||||
|
||||
fn returning_on_insert(&self) -> bool {
|
||||
// FIXME: How?
|
||||
match self.backend {
|
||||
DbBackend::MySql => {
|
||||
// Supported if it's MariaDB on or after version 10.5.0
|
||||
// Not supported in all MySQL versions
|
||||
self.support_returning
|
||||
}
|
||||
DbBackend::Postgres => {
|
||||
// Supported by all Postgres versions
|
||||
true
|
||||
}
|
||||
DbBackend::Sqlite => {
|
||||
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn returning_on_update(&self) -> bool {
|
||||
// FIXME: How?
|
||||
match self.backend {
|
||||
DbBackend::MySql => {
|
||||
// Not supported in all MySQL & MariaDB versions
|
||||
false
|
||||
}
|
||||
DbBackend::Postgres => {
|
||||
// Supported by all Postgres versions
|
||||
true
|
||||
}
|
||||
DbBackend::Sqlite => {
|
||||
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines errors for handling transaction failures
|
||||
|
@ -3,7 +3,7 @@ use std::{future::Future, pin::Pin};
|
||||
|
||||
use sqlx::{
|
||||
mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow},
|
||||
MySql, MySqlPool,
|
||||
MySql, MySqlPool, Row,
|
||||
};
|
||||
|
||||
sea_query::sea_query_driver_mysql!();
|
||||
@ -24,6 +24,8 @@ pub struct SqlxMySqlConnector;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SqlxMySqlPoolConnection {
|
||||
pool: MySqlPool,
|
||||
pub(crate) version: String,
|
||||
pub(crate) support_returning: bool,
|
||||
}
|
||||
|
||||
impl SqlxMySqlConnector {
|
||||
@ -128,7 +130,7 @@ impl SqlxMySqlPoolConnection {
|
||||
/// Bundle a set of SQL statements that execute together.
|
||||
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
|
||||
if let Ok(conn) = self.pool.acquire().await {
|
||||
DatabaseTransaction::new_mysql(conn).await
|
||||
DatabaseTransaction::new_mysql(conn, self.support_returning).await
|
||||
} else {
|
||||
Err(DbErr::Query(
|
||||
"Failed to acquire connection from pool.".to_owned(),
|
||||
@ -147,7 +149,7 @@ impl SqlxMySqlPoolConnection {
|
||||
E: std::error::Error + Send,
|
||||
{
|
||||
if let Ok(conn) = self.pool.acquire().await {
|
||||
let transaction = DatabaseTransaction::new_mysql(conn)
|
||||
let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning)
|
||||
.await
|
||||
.map_err(|e| TransactionError::Connection(e))?;
|
||||
transaction.run(callback).await
|
||||
@ -184,15 +186,27 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
|
||||
}
|
||||
|
||||
async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
|
||||
let conn = SqlxMySqlPoolConnection { pool };
|
||||
let res = conn
|
||||
.query_one(Statement::from_string(
|
||||
let (version, support_returning) = parse_support_returning(&pool).await?;
|
||||
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
|
||||
SqlxMySqlPoolConnection {
|
||||
pool,
|
||||
version,
|
||||
support_returning,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
async fn parse_support_returning(pool: &MySqlPool) -> Result<(String, bool), DbErr> {
|
||||
let stmt = Statement::from_string(
|
||||
DbBackend::MySql,
|
||||
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
|
||||
))
|
||||
.await?;
|
||||
let (version, support_returning) = if let Some(query_result) = res {
|
||||
let version: String = query_result.try_get("", "Value")?;
|
||||
);
|
||||
let query = sqlx_query(&stmt);
|
||||
let row = query
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(sqlx_error_to_query_err)?;
|
||||
let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?;
|
||||
let support_returning = if !version.contains("MariaDB") {
|
||||
// This is MySQL
|
||||
// Not supported in all MySQL versions
|
||||
@ -216,13 +230,5 @@ async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr
|
||||
// Supported if it's MariaDB with version 10.5.0 or after
|
||||
ver_major >= 10 && ver_minor >= 5
|
||||
};
|
||||
(version, support_returning)
|
||||
} else {
|
||||
return Err(DbErr::Conn("Fail to parse MySQL version".to_owned()));
|
||||
};
|
||||
Ok(DatabaseConnection::SqlxMySqlPoolConnection {
|
||||
conn,
|
||||
version,
|
||||
support_returning,
|
||||
})
|
||||
Ok((version, support_returning))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user