diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 32e96980..35ef5cc1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -395,7 +395,7 @@ jobs: strategy: fail-fast: false matrix: - version: [8.0.27, 5.7.36] + version: [8.0, 5.7] runtime: [async-std, actix, tokio] tls: [native-tls] services: @@ -456,7 +456,7 @@ jobs: strategy: fail-fast: false matrix: - version: [10.6, 10.5, 10.0, 5.5] + version: [10.6, 10.5, 10.4] runtime: [async-std, actix, tokio] tls: [native-tls] services: diff --git a/Cargo.toml b/Cargo.toml index 46d8c617..3f0e0ea3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.3.1", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.18.2", features = ["thread-safe"] } +sea-query = { version = "^0.18.2", git = "https://github.com/SeaQL/sea-query.git", branch = "sea-orm/returning", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 59aaf5f2..66385876 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -228,7 +228,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn support_returning(&self) -> bool { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { support_returning, .. } => *support_returning, + DatabaseConnection::SqlxMySqlPoolConnection { + support_returning, .. + } => *support_returning, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => true, #[cfg(feature = "sqlx-sqlite")] @@ -267,27 +269,13 @@ impl DatabaseConnection { } impl DatabaseConnection { - /// Get database version - pub fn db_version(&self) -> String { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { version, .. } => version.to_string(), - // #[cfg(feature = "sqlx-postgres")] - // DatabaseConnection::SqlxPostgresPoolConnection(conn) => , - // #[cfg(feature = "sqlx-sqlite")] - // DatabaseConnection::SqlxSqlitePoolConnection(conn) => , - // #[cfg(feature = "mock")] - // DatabaseConnection::MockDatabaseConnection(conn) => , - DatabaseConnection::Disconnected => panic!("Disconnected"), - _ => unimplemented!(), - } - } - /// Check if database supports `RETURNING` - pub fn db_support_returning(&self) -> bool { + pub fn support_returning(&self) -> bool { match self { #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { support_returning, .. } => *support_returning, + DatabaseConnection::SqlxMySqlPoolConnection { + support_returning, .. + } => *support_returning, #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => true, // #[cfg(feature = "sqlx-sqlite")] diff --git a/src/executor/insert.rs b/src/executor/insert.rs index b5c2cf98..a6dbcbd5 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -43,8 +43,9 @@ where let mut query = self.query; if db.support_returning() && ::PrimaryKey::iter().count() > 0 { let mut returning = Query::select(); - returning - .columns(::PrimaryKey::iter().map(|c| c.into_column_ref())); + returning.columns( + ::PrimaryKey::iter().map(|c| c.into_column_ref()), + ); query.returning(returning); } Inserter::::new(self.primary_key, query).exec(db) diff --git a/tests/returning_tests.rs b/tests/returning_tests.rs index 0a1e02c9..df8fc1a9 100644 --- a/tests/returning_tests.rs +++ b/tests/returning_tests.rs @@ -1,7 +1,8 @@ pub mod common; -pub use common::{features::*, setup::*, TestContext}; -use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; +pub use common::{bakery_chain::*, setup::*, TestContext}; +use sea_orm::{entity::prelude::*, *}; +use sea_query::Query; #[sea_orm_macros::test] #[cfg(any( @@ -10,27 +11,38 @@ use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; feature = "sqlx-postgres" ))] async fn main() -> Result<(), DbErr> { + use bakery::*; + let ctx = TestContext::new("returning_tests").await; let db = &ctx.db; + let builder = db.get_database_backend(); - match db { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection { .. } => { - let version = db.db_version(); - match version.as_str() { - "5.7.26" => assert!(!db.db_support_returning()), - _ => unimplemented!("Version {} is not included", version), - }; - }, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => { - assert!(db.db_support_returning()); - }, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => {}, - _ => unreachable!(), + let mut insert = Query::insert(); + insert + .into_table(Entity) + .columns(vec![Column::Name, Column::ProfitMargin]) + .values_panic(vec!["Bakery Shop".into(), 0.5.into()]); + + let mut update = Query::update(); + update + .table(Entity) + .values(vec![ + (Column::Name, "Bakery Shop".into()), + (Column::ProfitMargin, 0.5.into()), + ]) + .and_where(Column::Id.eq(1)); + + if db.support_returning() { + let mut returning = Query::select(); + returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); + insert.returning(returning.clone()); + update.returning(returning); } + create_tables(db).await?; + db.query_one(builder.build(&insert)).await?; + db.query_one(builder.build(&update)).await?; + assert!(false); ctx.delete().await; Ok(())