From 03207fbda9fc9eb7239d6c558d6d031539fc7730 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Wed, 18 Jan 2023 21:11:35 +0800 Subject: [PATCH] Use `rows_affected` when DB does not support returning --- src/error.rs | 4 ++-- src/executor/execute.rs | 2 +- src/executor/insert.rs | 14 ++++++++------ tests/upsert_tests.rs | 7 +------ 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/error.rs b/src/error.rs index 7988bd1b..2ed87ffd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -58,8 +58,8 @@ pub enum DbErr { /// None of the records are being inserted into the database, /// if you insert with upsert expression that means /// all of them conflict with existing records in the database - #[error("RecordNotInserted Error: {0}")] - RecordNotInserted(String), + #[error("None of the records are being inserted")] + RecordNotInserted, } /// Runtime error diff --git a/src/executor/execute.rs b/src/executor/execute.rs index 3da4ec8a..f3a7150d 100644 --- a/src/executor/execute.rs +++ b/src/executor/execute.rs @@ -52,7 +52,7 @@ impl ExecResult { } } - /// Get the number of rows affedted by the operation + /// Get the number of rows affected by the operation pub fn rows_affected(&self) -> u64 { match &self.result { #[cfg(feature = "sqlx-mysql")] diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 882f3df6..c4c988ba 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -140,18 +140,17 @@ where let last_insert_id = match (primary_key, db.support_returning()) { (Some(value_tuple), _) => { - db.execute(statement).await?; + let res = db.execute(statement).await?; + if res.rows_affected() == 0 { + return Err(DbErr::RecordNotInserted); + } FromValueTuple::from_value_tuple(value_tuple) } (None, true) => { let mut rows = db.query_all(statement).await?; let row = match rows.pop() { Some(row) => row, - None => { - return Err(DbErr::RecordNotInserted( - "None of the records are being inserted".to_owned(), - )) - } + None => return Err(DbErr::RecordNotInserted), }; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) @@ -161,6 +160,9 @@ where } (None, false) => { let res = db.execute(statement).await?; + if res.rows_affected() == 0 { + return Err(DbErr::RecordNotInserted); + } let last_insert_id = res.last_insert_id(); ValueTypeOf::::try_from_u64(last_insert_id).map_err(|_| DbErr::UnpackInsertId)? } diff --git a/tests/upsert_tests.rs b/tests/upsert_tests.rs index ad874e79..748d5b0a 100644 --- a/tests/upsert_tests.rs +++ b/tests/upsert_tests.rs @@ -54,12 +54,7 @@ pub async fn create_insert_default(db: &DatabaseConnection) -> Result<(), DbErr> .exec(db) .await; - assert_eq!( - res.err(), - Some(DbErr::RecordNotInserted( - "None of the records are being inserted".to_owned() - )) - ); + assert_eq!(res.err(), Some(DbErr::RecordNotInserted)); Ok(()) }