From 429b920dedcabc697a5b354ca6acc612c861b56c Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 9 Nov 2021 11:05:55 +0800 Subject: [PATCH] Fixup --- src/database/connection.rs | 7 ++++-- src/database/db_connection.rs | 40 +++++++++++++---------------------- src/database/transaction.rs | 7 +++++- src/executor/insert.rs | 6 +++--- src/executor/update.rs | 2 +- tests/returning_tests.rs | 12 ++++++----- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/database/connection.rs b/src/database/connection.rs index e06c6e57..2a16156e 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -45,8 +45,11 @@ pub trait ConnectionTrait<'a>: Sync { T: Send, E: std::error::Error + Send; - /// Check if the connection supports `RETURNING` syntax - fn support_returning(&self) -> bool; + /// Check if the connection supports `RETURNING` syntax on insert + fn returning_on_insert(&self) -> bool; + + /// Check if the connection supports `RETURNING` syntax on update + fn returning_on_update(&self) -> bool; /// Check if the connection is a test connection for the Mock database fn is_mock_connection(&self) -> bool { diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 9c49927a..90a472e4 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -225,7 +225,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { } } - fn support_returning(&self) -> bool { + fn returning_on_insert(&self) -> bool { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection { @@ -235,13 +235,21 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::SqlxPostgresPoolConnection(_) => true, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(_) => false, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { - DbBackend::MySql => false, - DbBackend::Postgres => true, - DbBackend::Sqlite => false, - }, DatabaseConnection::Disconnected => panic!("Disconnected"), + _ => unimplemented!(), + } + } + + fn returning_on_update(&self) -> bool { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection { .. } => false, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => true, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(_) => false, + DatabaseConnection::Disconnected => panic!("Disconnected"), + _ => unimplemented!(), } } @@ -278,24 +286,6 @@ impl DatabaseConnection { DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(_) => "".to_string(), - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(_) => "".to_string(), - DatabaseConnection::Disconnected => panic!("Disconnected"), - _ => unimplemented!(), - } - } - - /// Check if database supports `RETURNING` - pub fn support_returning(&self) -> bool { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { - support_returning, .. - } => *support_returning, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => true, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => false, DatabaseConnection::Disconnected => panic!("Disconnected"), _ => unimplemented!(), } diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 5865311f..727f9bc6 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -348,7 +348,12 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { transaction.run(_callback).await } - fn support_returning(&self) -> bool { + fn returning_on_insert(&self) -> bool { + // FIXME: How? + false + } + + fn returning_on_update(&self) -> bool { // FIXME: How? false } diff --git a/src/executor/insert.rs b/src/executor/insert.rs index a6dbcbd5..fde1a3ab 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -41,7 +41,7 @@ where { // so that self is dropped before entering await let mut query = self.query; - if db.support_returning() && ::PrimaryKey::iter().count() > 0 { + if db.returning_on_insert() && ::PrimaryKey::iter().count() > 0 { let mut returning = Query::select(); returning.columns( ::PrimaryKey::iter().map(|c| c.into_column_ref()), @@ -113,7 +113,7 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id_opt = match db.support_returning() { + let last_insert_id_opt = match db.returning_on_insert() { true => { let cols = PrimaryKey::::iter() .map(|col| col.to_string()) @@ -147,7 +147,7 @@ where A: ActiveModelTrait, { let db_backend = db.get_database_backend(); - let found = match db.support_returning() { + let found = match db.returning_on_insert() { true => { let mut returning = Query::select(); returning.exprs(::Column::iter().map(|c| { diff --git a/src/executor/update.rs b/src/executor/update.rs index 9870b10d..d27aa41d 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -90,7 +90,7 @@ where A: ActiveModelTrait, C: ConnectionTrait<'a>, { - match db.support_returning() { + match db.returning_on_update() { true => { let mut returning = Query::select(); returning.exprs(::Column::iter().map(|c| { diff --git a/tests/returning_tests.rs b/tests/returning_tests.rs index 47690506..561ba2c5 100644 --- a/tests/returning_tests.rs +++ b/tests/returning_tests.rs @@ -31,12 +31,14 @@ async fn main() -> Result<(), DbErr> { (Column::ProfitMargin, 0.5.into()), ]) .and_where(Column::Id.eq(1)); - - if db.support_returning() { - let mut returning = Query::select(); - returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); + + let mut returning = Query::select(); + returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); + if db.returning_on_insert() { insert.returning(returning.clone()); - update.returning(returning); + } + if db.returning_on_update() { + update.returning(returning.clone()); } create_tables(db).await?;