diff --git a/src/executor/insert.rs b/src/executor/insert.rs
index 796cd362..9a686997 100644
--- a/src/executor/insert.rs
+++ b/src/executor/insert.rs
@@ -1,6 +1,7 @@
use crate::{
- error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, Insert, IntoActiveModel,
- Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64,
+ error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, ExecResult, Insert,
+ IntoActiveModel, Iterable, PrimaryKeyTrait, QueryResult, SelectModel, SelectorRaw, Statement,
+ TryFromU64,
};
use sea_query::{Expr, FromValueTuple, Iden, InsertStatement, IntoColumnRef, Query, ValueTuple};
use std::{future::Future, marker::PhantomData};
@@ -137,29 +138,49 @@ where
{
type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf = as PrimaryKeyTrait>::ValueType;
- let last_insert_id_opt = match db.support_returning() {
- true => {
- let cols = PrimaryKey::::iter()
- .map(|col| col.to_string())
- .collect::>();
- let rows = db.query_all(statement).await?;
- let res = rows.last().ok_or_else(|| {
- DbErr::RecordNotInserted("None of the records are being inserted".to_owned())
- })?;
- res.try_get_many("", cols.as_ref()).ok()
- }
- false => {
- let last_insert_id = db.execute(statement).await?.last_insert_id();
- ValueTypeOf::::try_from_u64(last_insert_id).ok()
+
+ enum QueryOrExecResult {
+ Query(QueryResult),
+ Exec(ExecResult),
+ }
+
+ let insert_result = if db.support_returning() {
+ let mut rows = db.query_all(statement).await?;
+ if rows.is_empty() {
+ return Err(DbErr::RecordNotInserted(
+ "None of the records are being inserted".to_owned(),
+ ));
}
+ QueryOrExecResult::Query(rows.remove(rows.len() - 1))
+ } else {
+ QueryOrExecResult::Exec(db.execute(statement).await?)
};
+
let last_insert_id = match primary_key {
- Some(value_tuple) => FromValueTuple::from_value_tuple(value_tuple),
- None => match last_insert_id_opt {
- Some(last_insert_id) => last_insert_id,
- None => return Err(DbErr::UnpackInsertId),
- },
- };
+ Some(value_tuple) => Ok(FromValueTuple::from_value_tuple(value_tuple)),
+ None => {
+ if db.support_returning() {
+ match insert_result {
+ QueryOrExecResult::Query(row) => {
+ let cols = PrimaryKey::::iter()
+ .map(|col| col.to_string())
+ .collect::>();
+ row.try_get_many("", cols.as_ref())
+ }
+ _ => unreachable!(),
+ }
+ } else {
+ match insert_result {
+ QueryOrExecResult::Exec(res) => {
+ let last_insert_id = res.last_insert_id();
+ ValueTypeOf::::try_from_u64(last_insert_id)
+ }
+ _ => unreachable!(),
+ }
+ }
+ }
+ }
+ .map_err(|_| DbErr::UnpackInsertId)?;
Ok(InsertResult { last_insert_id })
}