diff --git a/Cargo.toml b/Cargo.toml index 0c8de3ef..15c2b240 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.2.6", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.17.0", features = ["thread-safe"] } +sea-query = { version = "^0.17.1", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b02f4408..db3fa4e2 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,5 +1,5 @@ use sqlx::{ - sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, + sqlite::{SqliteArguments, SqlitePoolOptions, SqliteQueryResult, SqliteRow}, Sqlite, SqlitePool, }; @@ -24,7 +24,11 @@ impl SqlxSqliteConnector { } pub async fn connect(string: &str) -> Result { - if let Ok(pool) = SqlitePool::connect(string).await { + if let Ok(pool) = SqlitePoolOptions::new() + .max_connections(1) + .connect(string) + .await + { Ok(DatabaseConnection::SqlxSqlitePoolConnection( SqlxSqlitePoolConnection { pool }, )) diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 6f4b86f6..c5b1db97 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,8 +1,8 @@ use crate::{ - error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, - PrimaryKeyTrait, Value, + error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, Value, }; use async_trait::async_trait; +use sea_query::ValueTuple; use std::fmt::Debug; #[derive(Clone, Debug, Default)] @@ -10,7 +10,8 @@ pub struct ActiveValue where V: Into, { - value: Option, + // Don't want to call ActiveValue::unwrap() and cause panic + pub(self) value: Option, state: ActiveValueState, } @@ -67,6 +68,42 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; + #[allow(clippy::question_mark)] + fn get_primary_key_value(&self) -> Option { + let mut cols = ::PrimaryKey::iter(); + macro_rules! next { + () => { + if let Some(col) = cols.next() { + if let Some(val) = self.get(col.into_column()).value { + val + } else { + return None; + } + } else { + return None; + } + }; + } + match ::PrimaryKey::iter().count() { + 1 => { + let s1 = next!(); + Some(ValueTuple::One(s1)) + } + 2 => { + let s1 = next!(); + let s2 = next!(); + Some(ValueTuple::Two(s1, s2)) + } + 3 => { + let s1 = next!(); + let s2 = next!(); + let s3 = next!(); + Some(ValueTuple::Three(s1, s2, s3)) + } + _ => panic!("The arity cannot be larger than 3"), + } + } + async fn insert(self, db: &DatabaseConnection) -> Result where ::Model: IntoActiveModel, @@ -74,19 +111,12 @@ pub trait ActiveModelTrait: Clone + Debug { let am = self; let exec = ::insert(am).exec(db); let res = exec.await?; - // Assume valid last_insert_id is not equals to Default::default() - if res.last_insert_id - != <::PrimaryKey as PrimaryKeyTrait>::ValueType::default() - { - let found = ::find_by_id(res.last_insert_id) - .one(db) - .await?; - match found { - Some(model) => Ok(model.into_active_model()), - None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), - } - } else { - Ok(Self::default()) + let found = ::find_by_id(res.last_insert_id) + .one(db) + .await?; + match found { + Some(model) => Ok(model.into_active_model()), + None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), } } diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index 463f1482..a5e4cde0 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,16 +1,16 @@ use super::{ColumnTrait, IdenStatic, Iterable}; use crate::{TryFromU64, TryGetableMany}; -use sea_query::IntoValueTuple; +use sea_query::{FromValueTuple, IntoValueTuple}; use std::fmt::Debug; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { type ValueType: Sized + Send - + Default + Debug + PartialEq + IntoValueTuple + + FromValueTuple + TryGetableMany + TryFromU64; diff --git a/src/executor/insert.rs b/src/executor/insert.rs index a44867f7..02d02c0b 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -2,14 +2,15 @@ use crate::{ error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, }; -use sea_query::InsertStatement; +use sea_query::{FromValueTuple, InsertStatement, ValueTuple}; use std::{future::Future, marker::PhantomData}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Inserter where A: ActiveModelTrait, { + primary_key: Option, query: InsertStatement, model: PhantomData, } @@ -34,7 +35,6 @@ where where A: 'a, { - // TODO: extract primary key's value from query // so that self is dropped before entering await let mut query = self.query; if db.get_database_backend() == DbBackend::Postgres { @@ -47,8 +47,7 @@ where ); } } - Inserter::::new(query).exec(db) - // TODO: return primary key if extracted before, otherwise use InsertResult + Inserter::::new(self.primary_key, query).exec(db) } } @@ -56,8 +55,9 @@ impl Inserter where A: ActiveModelTrait, { - pub fn new(query: InsertStatement) -> Self { + pub fn new(primary_key: Option, query: InsertStatement) -> Self { Self { + primary_key, query, model: PhantomData, } @@ -71,12 +71,13 @@ where A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(self.primary_key, builder.build(&self.query), db) } } // Only Statement impl Send async fn exec_insert( + primary_key: Option, statement: Statement, db: &DatabaseConnection, ) -> Result, DbErr> @@ -85,21 +86,26 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id = match db.get_database_backend() { + let last_insert_id_opt = match db.get_database_backend() { DbBackend::Postgres => { use crate::{sea_query::Iden, Iterable}; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) .collect::>(); let res = db.query_one(statement).await?.unwrap(); - res.try_get_many("", cols.as_ref()).unwrap_or_default() + res.try_get_many("", cols.as_ref()).ok() } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); - ValueTypeOf::::try_from_u64(last_insert_id) - .ok() - .unwrap_or_default() + ValueTypeOf::::try_from_u64(last_insert_id).ok() } }; + let last_insert_id = match last_insert_id_opt { + Some(last_insert_id) => last_insert_id, + None => match primary_key { + Some(value_tuple) => FromValueTuple::from_value_tuple(value_tuple), + None => return Err(DbErr::Exec("Fail to unpack last_insert_id".to_owned())), + }, + }; Ok(InsertResult { last_insert_id }) } diff --git a/src/query/insert.rs b/src/query/insert.rs index f2d60dc8..5e504a0c 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,14 +1,18 @@ -use crate::{ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, QueryTrait}; +use crate::{ + ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, + QueryTrait, +}; use core::marker::PhantomData; -use sea_query::InsertStatement; +use sea_query::{InsertStatement, ValueTuple}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Insert where A: ActiveModelTrait, { pub(crate) query: InsertStatement, pub(crate) columns: Vec, + pub(crate) primary_key: Option, pub(crate) model: PhantomData, } @@ -31,6 +35,7 @@ where .into_table(A::Entity::default().table_ref()) .to_owned(), columns: Vec::new(), + primary_key: None, model: PhantomData, } } @@ -107,6 +112,12 @@ where M: IntoActiveModel, { let mut am: A = m.into_active_model(); + self.primary_key = + if !<::PrimaryKey as PrimaryKeyTrait>::auto_increment() { + am.get_primary_key_value() + } else { + None + }; let mut columns = Vec::new(); let mut values = Vec::new(); let columns_empty = self.columns.is_empty(); diff --git a/tests/crud/create_cake.rs b/tests/crud/create_cake.rs index 4fa914a5..df5130aa 100644 --- a/tests/crud/create_cake.rs +++ b/tests/crud/create_cake.rs @@ -58,11 +58,7 @@ pub async fn test_create_cake(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); assert!(cake.is_some()); diff --git a/tests/crud/create_lineitem.rs b/tests/crud/create_lineitem.rs index da82cc82..0ba9c7d3 100644 --- a/tests/crud/create_lineitem.rs +++ b/tests/crud/create_lineitem.rs @@ -57,11 +57,7 @@ pub async fn test_create_lineitem(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); // Customer diff --git a/tests/crud/create_order.rs b/tests/crud/create_order.rs index ba8ff09b..6de3d46f 100644 --- a/tests/crud/create_order.rs +++ b/tests/crud/create_order.rs @@ -57,11 +57,7 @@ pub async fn test_create_order(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); // Customer diff --git a/tests/sequential_op_tests.rs b/tests/sequential_op_tests.rs index 28333d84..286e856a 100644 --- a/tests/sequential_op_tests.rs +++ b/tests/sequential_op_tests.rs @@ -84,11 +84,7 @@ async fn init_setup(db: &DatabaseConnection) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); let customer_kate = customer::ActiveModel { @@ -225,11 +221,7 @@ async fn create_cake(db: &DatabaseConnection, baker: baker::Model) -> Option Result<(), DbErr> { assert_eq!(Metadata::find().one(db).await?, Some(metadata.clone())); - assert_eq!( - res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - metadata.uuid - } else { - Default::default() - } - ); + assert_eq!(res.last_insert_id, metadata.uuid); let update_res = Metadata::update(metadata::ActiveModel { value: Set("0.22".to_owned()),