diff --git a/src/executor/update.rs b/src/executor/update.rs index 2bc5ed80..b564165c 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -7,9 +7,10 @@ use std::future::Future; #[derive(Clone, Debug)] pub struct Updater { query: UpdateStatement, + check_record_exists: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct UpdateResult { pub rows_affected: u64, } @@ -39,7 +40,15 @@ where impl Updater { pub fn new(query: UpdateStatement) -> Self { - Self { query } + Self { + query, + check_record_exists: false, + } + } + + pub fn check_record_exists(mut self) -> Self { + self.check_record_exists = true; + self } pub fn exec( @@ -47,7 +56,7 @@ impl Updater { db: &DatabaseConnection, ) -> impl Future> + '_ { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db, self.check_record_exists) } } @@ -66,14 +75,18 @@ async fn exec_update_and_return_original( where A: ActiveModelTrait, { - Updater::new(query).exec(db).await?; + Updater::new(query).check_record_exists().exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_update( + statement: Statement, + db: &DatabaseConnection, + check_record_exists: bool, +) -> Result { let result = db.execute(statement).await?; - if result.rows_affected() == 0 { + if check_record_exists && result.rows_affected() == 0 { return Err(DbErr::RecordNotFound( "None of the database rows are affected".to_owned(), )); @@ -87,6 +100,7 @@ async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result Result<(), DbErr> { @@ -104,6 +118,14 @@ mod tests { last_insert_id: 0, rows_affected: 0, }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, ]) .into_connection(); @@ -145,6 +167,18 @@ mod tests { assert_eq!( cake::Entity::update(cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.clone().into_active_model() + }) + .exec(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + Update::one(cake::ActiveModel { name: Set("Cheese Cake".to_owned()), ..model.into_active_model() }) @@ -155,6 +189,15 @@ mod tests { )) ); + assert_eq!( + Update::many(cake::Entity) + .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned())) + .filter(cake::Column::Id.eq(2)) + .exec(&db) + .await, + Ok(UpdateResult { rows_affected: 0 }) + ); + assert_eq!( db.into_transaction_log(), vec![ @@ -173,6 +216,16 @@ mod tests { r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, vec!["Cheese Cake".into(), 2i32.into()] ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), ] );