Revert MySQL & SQLite returning support

This commit is contained in:
Billy Chan 2021-11-10 14:40:44 +08:00
parent cc035d7aa7
commit 66c23c85db
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
10 changed files with 43 additions and 213 deletions

View File

@ -288,6 +288,7 @@ jobs:
name: Examples
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
path: [basic, actix_example, actix4_example, axum_example, rocket_example]
@ -312,6 +313,7 @@ jobs:
if: ${{ (needs.init.outputs.run-partial == 'true' && needs.init.outputs.run-issues == 'true') }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
path: [86, 249, 262]

View File

@ -38,7 +38,6 @@ sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
url = "^2.2"
regex = "^1"
[dev-dependencies]
smol = { version = "^1.2" }

View File

@ -45,11 +45,11 @@ pub trait ConnectionTrait<'a>: Sync {
T: Send,
E: std::error::Error + Send;
/// Check if the connection supports `RETURNING` syntax on insert
fn returning_on_insert(&self) -> bool;
/// Check if the connection supports `RETURNING` syntax on update
fn returning_on_update(&self) -> bool;
/// Check if the connection supports `RETURNING` syntax on insert and update
fn support_returning(&self) -> bool {
let db_backend = self.get_database_backend();
db_backend.support_returning()
}
/// Check if the connection is a test connection for the Mock database
fn is_mock_connection(&self) -> bool {

View File

@ -214,61 +214,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}
fn returning_on_insert(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
// Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions
conn.support_returning
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
// Supported by all Postgres versions
true
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
conn.support_returning
}
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
fn returning_on_update(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => {
// Not supported in all MySQL & MariaDB versions
false
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
// Supported by all Postgres versions
true
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
conn.support_returning
}
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_))
@ -322,6 +267,11 @@ impl DbBackend {
Self::Sqlite => Box::new(SqliteQueryBuilder),
}
}
/// Check if the database supports `RETURNING` syntax on insert and update
pub fn support_returning(&self) -> bool {
matches!(self, Self::Postgres)
}
}
#[cfg(test)]

View File

@ -16,7 +16,6 @@ pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
support_returning: bool,
}
impl std::fmt::Debug for DatabaseTransaction {
@ -29,12 +28,10 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql,
support_returning,
)
.await
}
@ -46,7 +43,6 @@ impl DatabaseTransaction {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres,
true,
)
.await
}
@ -54,12 +50,10 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-sqlite")]
pub(crate) async fn new_sqlite(
inner: PoolConnection<sqlx::Sqlite>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite,
support_returning,
)
.await
}
@ -69,28 +63,17 @@ impl DatabaseTransaction {
inner: Arc<crate::MockDatabaseConnection>,
) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::begin(
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
backend,
match backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
)
.await
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
}
async fn begin(
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
support_returning,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
@ -347,8 +330,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning)
.await
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await
}
/// Execute the function inside a transaction.
@ -365,17 +347,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
}
fn returning_on_insert(&self) -> bool {
self.support_returning
}
fn returning_on_update(&self) -> bool {
match self.backend {
DbBackend::MySql => false,
_ => self.support_returning,
}
}
}
/// Defines errors for handling transaction failures

View File

@ -1,9 +1,8 @@
use regex::Regex;
use std::{future::Future, pin::Pin};
use sqlx::{
mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow},
MySql, MySqlPool, Row,
MySql, MySqlPool,
};
sea_query::sea_query_driver_mysql!();
@ -11,7 +10,7 @@ use sea_query_driver_mysql::bind_query;
use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
DbBackend, QueryStream, Statement, TransactionError,
QueryStream, Statement, TransactionError,
};
use super::sqlx_common::*;
@ -24,7 +23,6 @@ pub struct SqlxMySqlConnector;
#[derive(Debug, Clone)]
pub struct SqlxMySqlPoolConnection {
pool: MySqlPool,
pub(crate) support_returning: bool,
}
impl SqlxMySqlConnector {
@ -44,7 +42,9 @@ impl SqlxMySqlConnector {
opt.disable_statement_logging();
}
if let Ok(pool) = options.pool_options().connect_with(opt).await {
into_db_connection(pool).await
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection { pool },
))
} else {
Err(DbErr::Conn("Failed to connect.".to_owned()))
}
@ -53,8 +53,8 @@ impl SqlxMySqlConnector {
impl SqlxMySqlConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
into_db_connection(pool).await
pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool })
}
}
@ -129,7 +129,7 @@ impl SqlxMySqlPoolConnection {
/// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn, self.support_returning).await
DatabaseTransaction::new_mysql(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
@ -148,7 +148,7 @@ impl SqlxMySqlPoolConnection {
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning)
let transaction = DatabaseTransaction::new_mysql(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
@ -183,52 +183,3 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
}
query
}
async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let support_returning = parse_support_returning(&pool).await?;
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection {
pool,
support_returning,
},
))
}
async fn parse_support_returning(pool: &MySqlPool) -> Result<bool, DbErr> {
let stmt = Statement::from_string(
DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
);
let query = sqlx_query(&stmt);
let row = query
.fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?;
let support_returning = if !version.contains("MariaDB") {
// This is MySQL
// Not supported in all MySQL versions
false
} else {
// This is MariaDB
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
// Supported if it's MariaDB with version 10.5.0 or after
ver_major >= 10 && ver_minor >= 5
};
debug_print!("db_version: {}", version);
debug_print!("db_support_returning: {}", support_returning);
Ok(support_returning)
}

View File

@ -1,9 +1,8 @@
use regex::Regex;
use std::{future::Future, pin::Pin};
use sqlx::{
sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow},
Row, Sqlite, SqlitePool,
Sqlite, SqlitePool,
};
sea_query::sea_query_driver_sqlite!();
@ -11,7 +10,7 @@ use sea_query_driver_sqlite::bind_query;
use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
DbBackend, QueryStream, Statement, TransactionError,
QueryStream, Statement, TransactionError,
};
use super::sqlx_common::*;
@ -24,7 +23,6 @@ pub struct SqlxSqliteConnector;
#[derive(Debug, Clone)]
pub struct SqlxSqlitePoolConnection {
pool: SqlitePool,
pub(crate) support_returning: bool,
}
impl SqlxSqliteConnector {
@ -48,7 +46,9 @@ impl SqlxSqliteConnector {
options.max_connections(1);
}
if let Ok(pool) = options.pool_options().connect_with(opt).await {
into_db_connection(pool).await
Ok(DatabaseConnection::SqlxSqlitePoolConnection(
SqlxSqlitePoolConnection { pool },
))
} else {
Err(DbErr::Conn("Failed to connect.".to_owned()))
}
@ -57,8 +57,8 @@ impl SqlxSqliteConnector {
impl SqlxSqliteConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub async fn from_sqlx_sqlite_pool(pool: SqlitePool) -> Result<DatabaseConnection, DbErr> {
into_db_connection(pool).await
pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool })
}
}
@ -133,7 +133,7 @@ impl SqlxSqlitePoolConnection {
/// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_sqlite(conn, self.support_returning).await
DatabaseTransaction::new_sqlite(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
@ -152,7 +152,7 @@ impl SqlxSqlitePoolConnection {
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_sqlite(conn, self.support_returning)
let transaction = DatabaseTransaction::new_sqlite(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
@ -187,45 +187,3 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, Sql
}
query
}
async fn into_db_connection(pool: SqlitePool) -> Result<DatabaseConnection, DbErr> {
let support_returning = parse_support_returning(&pool).await?;
Ok(DatabaseConnection::SqlxSqlitePoolConnection(
SqlxSqlitePoolConnection {
pool,
support_returning,
},
))
}
async fn parse_support_returning(pool: &SqlitePool) -> Result<bool, DbErr> {
let stmt = Statement::from_string(
DbBackend::Sqlite,
r#"SELECT sqlite_version() AS version"#.to_owned(),
);
let query = sqlx_query(&stmt);
let row = query
.fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("version").map_err(sqlx_error_to_query_err)?;
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
// Supported if it's version 3.35.0 (2021-03-12) or after
let support_returning = ver_major >= 3 && ver_minor >= 35;
debug_print!("db_version: {}", version);
debug_print!("db_support_returning: {}", support_returning);
Ok(support_returning)
}

View File

@ -41,7 +41,7 @@ where
{
// so that self is dropped before entering await
let mut query = self.query;
if db.returning_on_insert() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
if db.support_returning() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
let mut returning = Query::select();
returning.columns(
<A::Entity as EntityTrait>::PrimaryKey::iter().map(|c| c.into_column_ref()),
@ -113,7 +113,7 @@ where
{
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType;
let last_insert_id_opt = match db.returning_on_insert() {
let last_insert_id_opt = match db.support_returning() {
true => {
let cols = PrimaryKey::<A>::iter()
.map(|col| col.to_string())
@ -147,7 +147,7 @@ where
A: ActiveModelTrait,
{
let db_backend = db.get_database_backend();
let found = match db.returning_on_insert() {
let found = match db.support_returning() {
true => {
let mut returning = Query::select();
returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| {

View File

@ -90,7 +90,7 @@ where
A: ActiveModelTrait,
C: ConnectionTrait<'a>,
{
match db.returning_on_update() {
match db.support_returning() {
true => {
let mut returning = Query::select();
returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| {

View File

@ -1,8 +1,8 @@
pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext};
use sea_orm::{entity::prelude::*, *};
use sea_query::Query;
pub use sea_orm::{entity::prelude::*, *};
pub use sea_query::Query;
#[sea_orm_macros::test]
#[cfg(any(
@ -37,7 +37,7 @@ async fn main() -> Result<(), DbErr> {
create_tables(db).await?;
if db.returning_on_insert() {
if db.support_returning() {
insert.returning(returning.clone());
let insert_res = db
.query_one(builder.build(&insert))
@ -46,11 +46,7 @@ async fn main() -> Result<(), DbErr> {
let _id: i32 = insert_res.try_get("", "id")?;
let _name: String = insert_res.try_get("", "name")?;
let _profit_margin: f64 = insert_res.try_get("", "profit_margin")?;
} else {
let insert_res = db.execute(builder.build(&insert)).await?;
assert!(insert_res.rows_affected() > 0);
}
if db.returning_on_update() {
update.returning(returning.clone());
let update_res = db
.query_one(builder.build(&update))
@ -60,6 +56,9 @@ async fn main() -> Result<(), DbErr> {
let _name: String = update_res.try_get("", "name")?;
let _profit_margin: f64 = update_res.try_get("", "profit_margin")?;
} else {
let insert_res = db.execute(builder.build(&insert)).await?;
assert!(insert_res.rows_affected() > 0);
let update_res = db.execute(builder.build(&update)).await?;
assert!(update_res.rows_affected() > 0);
}