Noop when update without providing any values (#1384)

* Noop when update without providing any values

* Add tests

* Update Cargo.toml

* Bump SeaQuery's version

* Fixup

Co-authored-by: Chris Tsang <chris.2y3@outlook.com>
This commit is contained in:
Billy Chan 2023-01-26 17:12:50 +08:00 committed by GitHub
parent b84c2ffdcb
commit 036edf9d70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 83 deletions

View File

@ -34,7 +34,7 @@ tracing = { version = "0.1", default-features = false, features = ["attributes",
rust_decimal = { version = "1", default-features = false, optional = true } rust_decimal = { version = "1", default-features = false, optional = true }
bigdecimal = { version = "0.3", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true }
sea-orm-macros = { version = "0.10.3", path = "sea-orm-macros", default-features = false, optional = true } sea-orm-macros = { version = "0.10.3", path = "sea-orm-macros", default-features = false, optional = true }
sea-query = { version = "0.28", features = ["thread-safe"] } sea-query = { version = "0.28.3", features = ["thread-safe"] }
sea-query-binder = { version = "0.3", default-features = false, optional = true } sea-query-binder = { version = "0.3", default-features = false, optional = true }
sea-strum = { version = "0.23", default-features = false, features = ["derive", "sea-orm"] } sea-strum = { version = "0.23", default-features = false, features = ["derive", "sea-orm"] }
serde = { version = "1.0", default-features = false } serde = { version = "1.0", default-features = false }

View File

@ -17,7 +17,7 @@ name = "sea_orm_codegen"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
sea-query = { version = "0.28", default-features = false, features = ["thread-safe"] } sea-query = { version = "0.28.3", default-features = false, features = ["thread-safe"] }
syn = { version = "1", default-features = false } syn = { version = "1", default-features = false }
quote = { version = "1", default-features = false } quote = { version = "1", default-features = false }
heck = { version = "0.3", default-features = false } heck = { version = "0.3", default-features = false }

View File

@ -1423,8 +1423,8 @@ mod tests {
vec![ vec![
Transaction::from_sql_and_values( Transaction::from_sql_and_values(
DbBackend::Postgres, DbBackend::Postgres,
r#"UPDATE "fruit" SET WHERE "fruit"."id" = $1 RETURNING "id", "name", "cake_id""#, r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit" WHERE "fruit"."id" = $1 LIMIT $2"#,
vec![1i32.into()], vec![1i32.into(), 1u64.into()],
), ),
Transaction::from_sql_and_values( Transaction::from_sql_and_values(
DbBackend::Postgres, DbBackend::Postgres,

View File

@ -1,9 +1,8 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel,
Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne, Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, UpdateMany, UpdateOne,
}; };
use sea_query::{Expr, FromValueTuple, Query, UpdateStatement}; use sea_query::{Expr, FromValueTuple, Query, UpdateStatement};
use std::future::Future;
/// Defines an update operation /// Defines an update operation
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -13,7 +12,7 @@ pub struct Updater {
} }
/// The result of an update operation on an ActiveModel /// The result of an update operation on an ActiveModel
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct UpdateResult { pub struct UpdateResult {
/// The rows affected by the update operation /// The rows affected by the update operation
pub rows_affected: u64, pub rows_affected: u64,
@ -29,8 +28,9 @@ where
<A::Entity as EntityTrait>::Model: IntoActiveModel<A>, <A::Entity as EntityTrait>::Model: IntoActiveModel<A>,
C: ConnectionTrait, C: ConnectionTrait,
{ {
// so that self is dropped before entering await Updater::new(self.query)
exec_update_and_return_updated(self.query, self.model, db).await .exec_update_and_return_updated(self.model, db)
.await
} }
} }
@ -39,12 +39,11 @@ where
E: EntityTrait, E: EntityTrait,
{ {
/// Execute an update operation on multiple ActiveModels /// Execute an update operation on multiple ActiveModels
pub fn exec<C>(self, db: &'a C) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_ pub async fn exec<C>(self, db: &'a C) -> Result<UpdateResult, DbErr>
where where
C: ConnectionTrait, C: ConnectionTrait,
{ {
// so that self is dropped before entering await Updater::new(self.query).exec(db).await
exec_update_only(self.query, db)
} }
} }
@ -64,24 +63,76 @@ impl Updater {
} }
/// Execute an update operation /// Execute an update operation
pub fn exec<C>(self, db: &C) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_ pub async fn exec<C>(self, db: &C) -> Result<UpdateResult, DbErr>
where where
C: ConnectionTrait, C: ConnectionTrait,
{ {
if self.is_noop() {
return Ok(UpdateResult::default());
}
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_update(builder.build(&self.query), db, self.check_record_exists) let statement = builder.build(&self.query);
let result = db.execute(statement).await?;
if self.check_record_exists && result.rows_affected() == 0 {
return Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned(),
));
}
Ok(UpdateResult {
rows_affected: result.rows_affected(),
})
}
async fn exec_update_and_return_updated<A, C>(
mut self,
model: A,
db: &C,
) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
where
A: ActiveModelTrait,
C: ConnectionTrait,
{
type Entity<A> = <A as ActiveModelTrait>::Entity;
type Model<A> = <Entity<A> as EntityTrait>::Model;
type Column<A> = <Entity<A> as EntityTrait>::Column;
if self.is_noop() {
return find_updated_model_by_id(model, db).await;
}
match db.support_returning() {
true => {
let returning = Query::returning()
.exprs(Column::<A>::iter().map(|c| c.select_as(Expr::col(c))));
self.query.returning(returning);
let db_backend = db.get_database_backend();
let found: Option<Model<A>> = SelectorRaw::<SelectModel<Model<A>>>::from_statement(
db_backend.build(&self.query),
)
.one(db)
.await?;
// If we got `None` then we are updating a row that does not exist.
match found {
Some(model) => Ok(model),
None => Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned(),
)),
}
}
false => {
// If we updating a row that does not exist then an error will be thrown here.
self.check_record_exists().exec(db).await?;
find_updated_model_by_id(model, db).await
}
}
}
fn is_noop(&self) -> bool {
self.query.get_values().is_empty()
} }
} }
async fn exec_update_only<C>(query: UpdateStatement, db: &C) -> Result<UpdateResult, DbErr> async fn find_updated_model_by_id<A, C>(
where
C: ConnectionTrait,
{
Updater::new(query).exec(db).await
}
async fn exec_update_and_return_updated<A, C>(
mut query: UpdateStatement,
model: A, model: A,
db: &C, db: &C,
) -> Result<<A::Entity as EntityTrait>::Model, DbErr> ) -> Result<<A::Entity as EntityTrait>::Model, DbErr>
@ -90,63 +141,20 @@ where
C: ConnectionTrait, C: ConnectionTrait,
{ {
type Entity<A> = <A as ActiveModelTrait>::Entity; type Entity<A> = <A as ActiveModelTrait>::Entity;
type Model<A> = <Entity<A> as EntityTrait>::Model;
type Column<A> = <Entity<A> as EntityTrait>::Column;
type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType; type ValueType<A> = <<Entity<A> as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType;
match db.support_returning() {
true => {
let returning =
Query::returning().exprs(Column::<A>::iter().map(|c| c.select_as(Expr::col(c))));
query.returning(returning);
let db_backend = db.get_database_backend();
let found: Option<Model<A>> =
SelectorRaw::<SelectModel<Model<A>>>::from_statement(db_backend.build(&query))
.one(db)
.await?;
// If we got `None` then we are updating a row that does not exist.
match found {
Some(model) => Ok(model),
None => Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned(),
)),
}
}
false => {
// If we updating a row that does not exist then an error will be thrown here.
Updater::new(query).check_record_exists().exec(db).await?;
let primary_key_value = match model.get_primary_key_value() {
Some(val) => ValueType::<A>::from_value_tuple(val),
None => return Err(DbErr::UpdateGetPrimaryKey),
};
let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
// If we cannot select the updated row from db by the cached primary key
match found {
Some(model) => Ok(model),
None => Err(DbErr::RecordNotFound(
"Failed to find updated item".to_owned(),
)),
}
}
}
}
async fn exec_update<C>( let primary_key_value = match model.get_primary_key_value() {
statement: Statement, Some(val) => ValueType::<A>::from_value_tuple(val),
db: &C, None => return Err(DbErr::UpdateGetPrimaryKey),
check_record_exists: bool, };
) -> Result<UpdateResult, DbErr> let found = Entity::<A>::find_by_id(primary_key_value).one(db).await?;
where // If we cannot select the updated row from db by the cached primary key
C: ConnectionTrait, match found {
{ Some(model) => Ok(model),
let result = db.execute(statement).await?; None => Err(DbErr::RecordNotFound(
if check_record_exists && result.rows_affected() == 0 { "Failed to find updated item".to_owned(),
return Err(DbErr::RecordNotFound( )),
"None of the database rows are affected".to_owned(),
));
} }
Ok(UpdateResult {
rows_affected: result.rows_affected(),
})
} }
#[cfg(test)] #[cfg(test)]
@ -157,15 +165,20 @@ mod tests {
#[smol_potat::test] #[smol_potat::test]
async fn update_record_not_found_1() -> Result<(), DbErr> { async fn update_record_not_found_1() -> Result<(), DbErr> {
let updated_cake = cake::Model {
id: 1,
name: "Cheese Cake".to_owned(),
};
let db = MockDatabase::new(DbBackend::Postgres) let db = MockDatabase::new(DbBackend::Postgres)
.append_query_results([ .append_query_results([
vec![cake::Model { vec![updated_cake.clone()],
id: 1,
name: "Cheese Cake".to_owned(),
}],
vec![], vec![],
vec![], vec![],
vec![], vec![],
vec![updated_cake.clone()],
vec![updated_cake.clone()],
vec![updated_cake.clone()],
]) ])
.append_exec_results([MockExecResult { .append_exec_results([MockExecResult {
last_insert_id: 0, last_insert_id: 0,
@ -181,7 +194,7 @@ mod tests {
assert_eq!( assert_eq!(
cake::ActiveModel { cake::ActiveModel {
name: Set("Cheese Cake".to_owned()), name: Set("Cheese Cake".to_owned()),
..model.into_active_model() ..model.clone().into_active_model()
} }
.update(&db) .update(&db)
.await?, .await?,
@ -223,7 +236,7 @@ mod tests {
assert_eq!( assert_eq!(
Update::one(cake::ActiveModel { Update::one(cake::ActiveModel {
name: Set("Cheese Cake".to_owned()), name: Set("Cheese Cake".to_owned()),
..model.into_active_model() ..model.clone().into_active_model()
}) })
.exec(&db) .exec(&db)
.await, .await,
@ -241,6 +254,28 @@ mod tests {
Ok(UpdateResult { rows_affected: 0 }) Ok(UpdateResult { rows_affected: 0 })
); );
assert_eq!(
updated_cake.clone().into_active_model().save(&db).await?,
updated_cake.clone().into_active_model()
);
assert_eq!(
updated_cake.clone().into_active_model().update(&db).await?,
updated_cake
);
assert_eq!(
cake::Entity::update(updated_cake.clone().into_active_model())
.exec(&db)
.await?,
updated_cake
);
assert_eq!(
cake::Entity::update_many().exec(&db).await?.rows_affected,
0
);
assert_eq!( assert_eq!(
db.into_transaction_log(), db.into_transaction_log(),
[ [
@ -269,6 +304,21 @@ mod tests {
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
["Cheese Cake".into(), 2i32.into()] ["Cheese Cake".into(), 2i32.into()]
), ),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
[1.into(), 1u64.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
[1.into(), 1u64.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#,
[1.into(), 1u64.into()]
),
] ]
); );