This commit is contained in:
Billy Chan 2021-11-09 11:05:55 +08:00
parent 30a50ca75d
commit 429b920ded
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
6 changed files with 37 additions and 37 deletions

View File

@ -45,8 +45,11 @@ pub trait ConnectionTrait<'a>: Sync {
T: Send, T: Send,
E: std::error::Error + Send; E: std::error::Error + Send;
/// Check if the connection supports `RETURNING` syntax /// Check if the connection supports `RETURNING` syntax on insert
fn support_returning(&self) -> bool; fn returning_on_insert(&self) -> bool;
/// Check if the connection supports `RETURNING` syntax on update
fn returning_on_update(&self) -> bool;
/// Check if the connection is a test connection for the Mock database /// Check if the connection is a test connection for the Mock database
fn is_mock_connection(&self) -> bool { fn is_mock_connection(&self) -> bool {

View File

@ -225,7 +225,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
} }
} }
fn support_returning(&self) -> bool { fn returning_on_insert(&self) -> bool {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { DatabaseConnection::SqlxMySqlPoolConnection {
@ -235,13 +235,21 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
DatabaseConnection::SqlxPostgresPoolConnection(_) => true, DatabaseConnection::SqlxPostgresPoolConnection(_) => true,
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => false, DatabaseConnection::SqlxSqlitePoolConnection(_) => false,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"), DatabaseConnection::Disconnected => panic!("Disconnected"),
_ => unimplemented!(),
}
}
fn returning_on_update(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => false,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => true,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => false,
DatabaseConnection::Disconnected => panic!("Disconnected"),
_ => unimplemented!(),
} }
} }
@ -278,24 +286,6 @@ impl DatabaseConnection {
DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(),
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => "".to_string(), DatabaseConnection::SqlxSqlitePoolConnection(_) => "".to_string(),
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(_) => "".to_string(),
DatabaseConnection::Disconnected => panic!("Disconnected"),
_ => unimplemented!(),
}
}
/// Check if database supports `RETURNING`
pub fn support_returning(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection {
support_returning, ..
} => *support_returning,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => true,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => false,
DatabaseConnection::Disconnected => panic!("Disconnected"), DatabaseConnection::Disconnected => panic!("Disconnected"),
_ => unimplemented!(), _ => unimplemented!(),
} }

View File

@ -348,7 +348,12 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
transaction.run(_callback).await transaction.run(_callback).await
} }
fn support_returning(&self) -> bool { fn returning_on_insert(&self) -> bool {
// FIXME: How?
false
}
fn returning_on_update(&self) -> bool {
// FIXME: How? // FIXME: How?
false false
} }

View File

@ -41,7 +41,7 @@ where
{ {
// so that self is dropped before entering await // so that self is dropped before entering await
let mut query = self.query; let mut query = self.query;
if db.support_returning() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 { if db.returning_on_insert() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
let mut returning = Query::select(); let mut returning = Query::select();
returning.columns( returning.columns(
<A::Entity as EntityTrait>::PrimaryKey::iter().map(|c| c.into_column_ref()), <A::Entity as EntityTrait>::PrimaryKey::iter().map(|c| c.into_column_ref()),
@ -113,7 +113,7 @@ where
{ {
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey; type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType; type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType;
let last_insert_id_opt = match db.support_returning() { let last_insert_id_opt = match db.returning_on_insert() {
true => { true => {
let cols = PrimaryKey::<A>::iter() let cols = PrimaryKey::<A>::iter()
.map(|col| col.to_string()) .map(|col| col.to_string())
@ -147,7 +147,7 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
let db_backend = db.get_database_backend(); let db_backend = db.get_database_backend();
let found = match db.support_returning() { let found = match db.returning_on_insert() {
true => { true => {
let mut returning = Query::select(); let mut returning = Query::select();
returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| { returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| {

View File

@ -90,7 +90,7 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
C: ConnectionTrait<'a>, C: ConnectionTrait<'a>,
{ {
match db.support_returning() { match db.returning_on_update() {
true => { true => {
let mut returning = Query::select(); let mut returning = Query::select();
returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| { returning.exprs(<A::Entity as EntityTrait>::Column::iter().map(|c| {

View File

@ -31,12 +31,14 @@ async fn main() -> Result<(), DbErr> {
(Column::ProfitMargin, 0.5.into()), (Column::ProfitMargin, 0.5.into()),
]) ])
.and_where(Column::Id.eq(1)); .and_where(Column::Id.eq(1));
if db.support_returning() { let mut returning = Query::select();
let mut returning = Query::select(); returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]);
returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); if db.returning_on_insert() {
insert.returning(returning.clone()); insert.returning(returning.clone());
update.returning(returning); }
if db.returning_on_update() {
update.returning(returning.clone());
} }
create_tables(db).await?; create_tables(db).await?;