Try returning on MariaDB

This commit is contained in:
Billy Chan 2021-11-08 17:36:30 +08:00
parent 732d080020
commit 0eafacc2a1
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
7 changed files with 107 additions and 31 deletions

View File

@ -38,6 +38,7 @@ sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11" ouroboros = "0.11"
url = "^2.2" url = "^2.2"
regex = "^1"
[dev-dependencies] [dev-dependencies]
smol = { version = "^1.2" } smol = { version = "^1.2" }

View File

@ -45,6 +45,9 @@ pub trait ConnectionTrait<'a>: Sync {
T: Send, T: Send,
E: std::error::Error + Send; E: std::error::Error + Send;
/// Check if the connection supports `RETURNING` syntax
fn support_returning(&self) -> bool;
/// Check if the connection is a test connection for the Mock database /// Check if the connection is a test connection for the Mock database
fn is_mock_connection(&self) -> bool { fn is_mock_connection(&self) -> bool {
false false

View File

@ -18,7 +18,12 @@ use std::sync::Arc;
pub enum DatabaseConnection { pub enum DatabaseConnection {
/// Create a MYSQL database connection and pool /// Create a MYSQL database connection and pool
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), SqlxMySqlPoolConnection {
/// A SQLx MySQL pool
conn: crate::SqlxMySqlPoolConnection,
/// A flag indicating whether `RETURNING` syntax is supported
support_returning: bool,
},
/// Create a PostgreSQL database connection and pool /// Create a PostgreSQL database connection and pool
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
@ -73,7 +78,7 @@ impl std::fmt::Debug for DatabaseConnection {
"{}", "{}",
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -93,7 +98,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn get_database_backend(&self) -> DbBackend { fn get_database_backend(&self) -> DbBackend {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -107,7 +112,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> { async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -121,7 +126,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> { async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -135,7 +140,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> { async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -153,7 +158,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
Box::pin(async move { Box::pin(async move {
Ok(match self { Ok(match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.stream(stmt).await?
}
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -170,7 +177,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> { async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await,
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
@ -196,7 +203,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
{ {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.transaction(_callback).await
}
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => { DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback).await conn.transaction(_callback).await
@ -214,6 +223,24 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
} }
} }
fn support_returning(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => false,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => true,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => false,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool { fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_)) matches!(self, DatabaseConnection::MockDatabaseConnection(_))

View File

@ -347,6 +347,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
let transaction = self.begin().await.map_err(TransactionError::Connection)?; let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await transaction.run(_callback).await
} }
fn support_returning(&self) -> bool {
match self.backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
}
}
} }
/// Defines errors for handling transaction failures /// Defines errors for handling transaction failures

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
@ -10,7 +11,7 @@ use sea_query_driver_mysql::bind_query;
use crate::{ use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
QueryStream, Statement, TransactionError, DbBackend, QueryStream, Statement, TransactionError,
}; };
use super::sqlx_common::*; use super::sqlx_common::*;
@ -42,9 +43,7 @@ impl SqlxMySqlConnector {
opt.disable_statement_logging(); opt.disable_statement_logging();
} }
if let Ok(pool) = options.pool_options().connect_with(opt).await { if let Ok(pool) = options.pool_options().connect_with(opt).await {
Ok(DatabaseConnection::SqlxMySqlPoolConnection( into_db_connection(pool).await
SqlxMySqlPoolConnection { pool },
))
} else { } else {
Err(DbErr::Conn("Failed to connect.".to_owned())) Err(DbErr::Conn("Failed to connect.".to_owned()))
} }
@ -53,8 +52,8 @@ impl SqlxMySqlConnector {
impl SqlxMySqlConnector { impl SqlxMySqlConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection] /// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection { pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool }) into_db_connection(pool).await
} }
} }
@ -183,3 +182,43 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
} }
query query
} }
async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let conn = SqlxMySqlPoolConnection { pool };
let res = conn
.query_one(Statement::from_string(
DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
))
.await?;
let support_returning = if let Some(query_result) = res {
let version: String = query_result.try_get("", "Value")?;
if !version.contains("MariaDB") {
// This is MySQL
false
} else {
// This is MariaDB
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
ver_major >= 10 && ver_minor >= 5
}
} else {
return Err(DbErr::Conn("Fail to parse MySQL version".to_owned()));
};
Ok(DatabaseConnection::SqlxMySqlPoolConnection {
conn,
support_returning,
})
}

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, IntoActiveModel, error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, IntoActiveModel, Iterable,
Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64,
}; };
use sea_query::{FromValueTuple, Iden, InsertStatement, IntoColumnRef, Returning, ValueTuple}; use sea_query::{FromValueTuple, Iden, InsertStatement, IntoColumnRef, Returning, ValueTuple};
use std::{future::Future, marker::PhantomData}; use std::{future::Future, marker::PhantomData};
@ -39,9 +39,7 @@ where
{ {
// so that self is dropped before entering await // so that self is dropped before entering await
let mut query = self.query; let mut query = self.query;
if db.get_database_backend() == DbBackend::Postgres if db.support_returning() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
&& <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0
{
query.returning(Returning::Columns( query.returning(Returning::Columns(
<A::Entity as EntityTrait>::PrimaryKey::iter() <A::Entity as EntityTrait>::PrimaryKey::iter()
.map(|c| c.into_column_ref()) .map(|c| c.into_column_ref())
@ -113,15 +111,15 @@ where
{ {
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey; type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType; type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType;
let last_insert_id_opt = match db.get_database_backend() { let last_insert_id_opt = match db.support_returning() {
DbBackend::Postgres => { true => {
let cols = PrimaryKey::<A>::iter() let cols = PrimaryKey::<A>::iter()
.map(|col| col.to_string()) .map(|col| col.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let res = db.query_one(statement).await?.unwrap(); let res = db.query_one(statement).await?.unwrap();
res.try_get_many("", cols.as_ref()).ok() res.try_get_many("", cols.as_ref()).ok()
} }
_ => { false => {
let last_insert_id = db.execute(statement).await?.last_insert_id(); let last_insert_id = db.execute(statement).await?.last_insert_id();
ValueTypeOf::<A>::try_from_u64(last_insert_id).ok() ValueTypeOf::<A>::try_from_u64(last_insert_id).ok()
} }
@ -147,8 +145,8 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
let db_backend = db.get_database_backend(); let db_backend = db.get_database_backend();
let found = match db_backend { let found = match db.support_returning() {
DbBackend::Postgres => { true => {
insert_statement.returning(Returning::Columns( insert_statement.returning(Returning::Columns(
<A::Entity as EntityTrait>::Column::iter() <A::Entity as EntityTrait>::Column::iter()
.map(|c| c.into_column_ref()) .map(|c| c.into_column_ref())
@ -160,7 +158,7 @@ where
.one(db) .one(db)
.await? .await?
} }
_ => { false => {
let insert_res = let insert_res =
exec_insert::<A, _>(primary_key, db_backend.build(&insert_statement), db).await?; exec_insert::<A, _>(primary_key, db_backend.build(&insert_statement), db).await?;
<A::Entity as EntityTrait>::find_by_id(insert_res.last_insert_id) <A::Entity as EntityTrait>::find_by_id(insert_res.last_insert_id)

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, IntoActiveModel, Iterable, error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, IntoActiveModel, Iterable,
SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne, SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne,
}; };
use sea_query::{FromValueTuple, IntoColumnRef, Returning, UpdateStatement}; use sea_query::{FromValueTuple, IntoColumnRef, Returning, UpdateStatement};
@ -90,14 +90,14 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
C: ConnectionTrait<'a>, C: ConnectionTrait<'a>,
{ {
let db_backend = db.get_database_backend(); match db.support_returning() {
match db_backend { true => {
DbBackend::Postgres => {
query.returning(Returning::Columns( query.returning(Returning::Columns(
<A::Entity as EntityTrait>::Column::iter() <A::Entity as EntityTrait>::Column::iter()
.map(|c| c.into_column_ref()) .map(|c| c.into_column_ref())
.collect(), .collect(),
)); ));
let db_backend = db.get_database_backend();
let found: Option<<A::Entity as EntityTrait>::Model> = let found: Option<<A::Entity as EntityTrait>::Model> =
SelectorRaw::<SelectModel<<A::Entity as EntityTrait>::Model>>::from_statement( SelectorRaw::<SelectModel<<A::Entity as EntityTrait>::Model>>::from_statement(
db_backend.build(&query), db_backend.build(&query),
@ -112,7 +112,7 @@ where
)), )),
} }
} }
_ => { false => {
// If we updating a row that does not exist then an error will be thrown here. // If we updating a row that does not exist then an error will be thrown here.
Updater::new(query).check_record_exists().exec(db).await?; Updater::new(query).check_record_exists().exec(db).await?;
let primary_key_value = match model.get_primary_key_value() { let primary_key_value = match model.get_primary_key_value() {