diff --git a/README.md b/README.md index 524aa896..82696f29 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ let pear = fruit::ActiveModel { }; // insert one -let res: InsertResult = Fruit::insert(pear).exec(db).await?; +let res = Fruit::insert(pear).exec(db).await?; println!("InsertResult: {}", res.last_insert_id); diff --git a/examples/async-std/src/example_cake.rs b/examples/async-std/src/example_cake.rs index 475315e8..114347ee 100644 --- a/examples/async-std/src/example_cake.rs +++ b/examples/async-std/src/example_cake.rs @@ -27,6 +27,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/examples/async-std/src/example_cake_filling.rs b/examples/async-std/src/example_cake_filling.rs index 19de83e4..4fa188bc 100644 --- a/examples/async-std/src/example_cake_filling.rs +++ b/examples/async-std/src/example_cake_filling.rs @@ -28,6 +28,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + fn auto_increment() -> bool { false } diff --git a/examples/async-std/src/example_filling.rs b/examples/async-std/src/example_filling.rs index 925b92fc..2a39a7de 100644 --- a/examples/async-std/src/example_filling.rs +++ b/examples/async-std/src/example_filling.rs @@ -27,6 +27,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/examples/async-std/src/example_fruit.rs b/examples/async-std/src/example_fruit.rs index b875da24..8802e707 100644 --- a/examples/async-std/src/example_fruit.rs +++ b/examples/async-std/src/example_fruit.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/examples/async-std/src/operation.rs b/examples/async-std/src/operation.rs index b1273e10..3fa5cc85 100644 --- a/examples/async-std/src/operation.rs +++ b/examples/async-std/src/operation.rs @@ -20,7 +20,7 @@ pub async fn insert_and_update(db: &DbConn) -> Result<(), DbErr> { name: Set("pear".to_owned()), ..Default::default() }; - let res: InsertResult = Fruit::insert(pear).exec(db).await?; + let res = Fruit::insert(pear).exec(db).await?; println!(); println!("Inserted: last_insert_id = {}\n", res.last_insert_id); diff --git a/examples/tokio/src/cake.rs b/examples/tokio/src/cake.rs index 0b1a4439..21b83331 100644 --- a/examples/tokio/src/cake.rs +++ b/examples/tokio/src/cake.rs @@ -27,6 +27,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/sea-orm-codegen/src/entity/base_entity.rs b/sea-orm-codegen/src/entity/base_entity.rs index f1e03b21..7b28f9e4 100644 --- a/sea-orm-codegen/src/entity/base_entity.rs +++ b/sea-orm-codegen/src/entity/base_entity.rs @@ -117,6 +117,31 @@ impl Entity { format_ident!("{}", auto_increment) } + pub fn get_primary_key_rs_type(&self) -> TokenStream { + let types = self + .primary_keys + .iter() + .map(|primary_key| { + self.columns + .iter() + .find(|col| col.name.eq(&primary_key.name)) + .unwrap() + .get_rs_type() + .to_string() + }) + .collect::>(); + if !types.is_empty() { + let value_type = if types.len() > 1 { + vec!["(".to_owned(), types.join(", "), ")".to_owned()] + } else { + types + }; + value_type.join("").parse().unwrap() + } else { + TokenStream::new() + } + } + pub fn get_conjunct_relations_via_snake_case(&self) -> Vec { self.conjunct_relations .iter() @@ -151,7 +176,7 @@ mod tests { columns: vec![ Column { name: "id".to_owned(), - col_type: ColumnType::String(None), + col_type: ColumnType::Integer(None), auto_increment: false, not_null: false, unique: false, @@ -373,6 +398,16 @@ mod tests { ); } + #[test] + fn test_get_primary_key_rs_type() { + let entity = setup(); + + assert_eq!( + entity.get_primary_key_rs_type().to_string(), + entity.columns[0].get_rs_type().to_string() + ); + } + #[test] fn test_get_conjunct_relations_via_snake_case() { let entity = setup(); diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index 207864ba..7accac2d 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -173,8 +173,11 @@ impl EntityWriter { pub fn gen_impl_primary_key(entity: &Entity) -> TokenStream { let primary_key_auto_increment = entity.get_primary_key_auto_increment(); + let value_type = entity.get_primary_key_rs_type(); quote! { impl PrimaryKeyTrait for PrimaryKey { + type ValueType = #value_type; + fn auto_increment() -> bool { #primary_key_auto_increment } diff --git a/sea-orm-codegen/tests/entity/cake.rs b/sea-orm-codegen/tests/entity/cake.rs index 29f55ac6..55fa279f 100644 --- a/sea-orm-codegen/tests/entity/cake.rs +++ b/sea-orm-codegen/tests/entity/cake.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/sea-orm-codegen/tests/entity/cake_filling.rs b/sea-orm-codegen/tests/entity/cake_filling.rs index d0f00560..d100fa8c 100644 --- a/sea-orm-codegen/tests/entity/cake_filling.rs +++ b/sea-orm-codegen/tests/entity/cake_filling.rs @@ -30,6 +30,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + fn auto_increment() -> bool { false } diff --git a/sea-orm-codegen/tests/entity/filling.rs b/sea-orm-codegen/tests/entity/filling.rs index bedb1ab4..73d58152 100644 --- a/sea-orm-codegen/tests/entity/filling.rs +++ b/sea-orm-codegen/tests/entity/filling.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/sea-orm-codegen/tests/entity/fruit.rs b/sea-orm-codegen/tests/entity/fruit.rs index 72c37c1b..12919021 100644 --- a/sea-orm-codegen/tests/entity/fruit.rs +++ b/sea-orm-codegen/tests/entity/fruit.rs @@ -31,6 +31,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/sea-orm-codegen/tests/entity/vendor.rs b/sea-orm-codegen/tests/entity/vendor.rs index 320400f4..ffc211a8 100644 --- a/sea-orm-codegen/tests/entity/vendor.rs +++ b/sea-orm-codegen/tests/entity/vendor.rs @@ -31,6 +31,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index a283ddf3..b3d71f9a 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -102,24 +102,11 @@ impl From for QueryResult { impl From for ExecResult { fn from(result: PgQueryResult) -> ExecResult { ExecResult { - result: ExecResultHolder::SqlxPostgres { - last_insert_id: 0, - rows_affected: result.rows_affected(), - }, + result: ExecResultHolder::SqlxPostgres(result), } } } -pub(crate) fn query_result_into_exec_result(res: QueryResult) -> Result { - let last_insert_id: i32 = res.try_get("", "last_insert_id")?; - Ok(ExecResult { - result: ExecResultHolder::SqlxPostgres { - last_insert_id: last_insert_id as u64, - rows_affected: 0, - }, - }) -} - fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 54bf037b..f823411c 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -221,16 +221,18 @@ where let exec = E::insert(am).exec(db); let res = exec.await?; // TODO: if the entity does not have auto increment primary key, then last_insert_id is a wrong value - if ::auto_increment() && res.last_insert_id != 0 { + // FIXME: Assumed valid last_insert_id is not equals to Default::default() + if ::auto_increment() + && res.last_insert_id != ::ValueType::default() + { let find = E::find_by_id(res.last_insert_id).one(db); let found = find.await; let model: Option = found?; match model { Some(model) => Ok(model.into_active_model()), None => Err(DbErr::Exec(format!( - "Failed to find inserted item: {} {}", + "Failed to find inserted item: {}", E::default().to_string(), - res.last_insert_id ))), } } else { diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index 81a28915..b9c381d4 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,7 +1,18 @@ use super::{ColumnTrait, IdenStatic, Iterable}; +use crate::{TryFromU64, TryGetableMany}; +use sea_query::IntoValueTuple; +use std::fmt::Debug; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { + type ValueType: Sized + + Default + + Debug + + PartialEq + + IntoValueTuple + + TryGetableMany + + TryFromU64; + fn auto_increment() -> bool; } diff --git a/src/error.rs b/src/error.rs index 8a695dac..09f80b0a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum DbErr { Conn(String), Exec(String), diff --git a/src/executor/execute.rs b/src/executor/execute.rs index 00375bb7..46ba2d69 100644 --- a/src/executor/execute.rs +++ b/src/executor/execute.rs @@ -8,10 +8,7 @@ pub(crate) enum ExecResultHolder { #[cfg(feature = "sqlx-mysql")] SqlxMySql(sqlx::mysql::MySqlQueryResult), #[cfg(feature = "sqlx-postgres")] - SqlxPostgres { - last_insert_id: u64, - rows_affected: u64, - }, + SqlxPostgres(sqlx::postgres::PgQueryResult), #[cfg(feature = "sqlx-sqlite")] SqlxSqlite(sqlx::sqlite::SqliteQueryResult), #[cfg(feature = "mock")] @@ -26,7 +23,9 @@ impl ExecResult { #[cfg(feature = "sqlx-mysql")] ExecResultHolder::SqlxMySql(result) => result.last_insert_id(), #[cfg(feature = "sqlx-postgres")] - ExecResultHolder::SqlxPostgres { last_insert_id, .. } => last_insert_id.to_owned(), + ExecResultHolder::SqlxPostgres(_) => { + panic!("Should not retrieve last_insert_id this way") + } #[cfg(feature = "sqlx-sqlite")] ExecResultHolder::SqlxSqlite(result) => { let last_insert_rowid = result.last_insert_rowid(); @@ -46,7 +45,7 @@ impl ExecResult { #[cfg(feature = "sqlx-mysql")] ExecResultHolder::SqlxMySql(result) => result.rows_affected(), #[cfg(feature = "sqlx-postgres")] - ExecResultHolder::SqlxPostgres { rows_affected, .. } => rows_affected.to_owned(), + ExecResultHolder::SqlxPostgres(result) => result.rows_affected(), #[cfg(feature = "sqlx-sqlite")] ExecResultHolder::SqlxSqlite(result) => result.rows_affected(), #[cfg(feature = "mock")] diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 9414fea6..90065675 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,15 +1,25 @@ -use crate::{error::*, ActiveModelTrait, DatabaseConnection, Insert, Statement}; +use crate::{ + error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, PrimaryKeyTrait, + Statement, TryFromU64, +}; use sea_query::InsertStatement; -use std::future::Future; +use std::{future::Future, marker::PhantomData}; #[derive(Clone, Debug)] -pub struct Inserter { +pub struct Inserter +where + A: ActiveModelTrait, +{ query: InsertStatement, + model: PhantomData, } -#[derive(Clone, Debug)] -pub struct InsertResult { - pub last_insert_id: u64, +#[derive(Debug)] +pub struct InsertResult +where + A: ActiveModelTrait, +{ + pub last_insert_id: <<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType, } impl Insert @@ -17,54 +27,79 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec( + pub fn exec<'a>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a DatabaseConnection, + ) -> impl Future, DbErr>> + 'a + where + A: 'a, + { // so that self is dropped before entering await let mut query = self.query; #[cfg(feature = "sqlx-postgres")] if let DatabaseConnection::SqlxPostgresPoolConnection(_) = db { - use crate::{EntityTrait, Iterable}; - use sea_query::{Alias, Expr, Query}; - for key in ::PrimaryKey::iter() { + use crate::{sea_query::Query, Iterable}; + if ::PrimaryKey::iter().count() > 0 { query.returning( Query::select() - .expr_as(Expr::col(key), Alias::new("last_insert_id")) - .to_owned(), + .columns(::PrimaryKey::iter()) + .take(), ); } } - Inserter::new(query).exec(db) + Inserter::::new(query).exec(db) } } -impl Inserter { +impl Inserter +where + A: ActiveModelTrait, +{ pub fn new(query: InsertStatement) -> Self { - Self { query } + Self { + query, + model: PhantomData, + } } - pub fn exec( + pub fn exec<'a>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a DatabaseConnection, + ) -> impl Future, DbErr>> + 'a + where + A: 'a, + { let builder = db.get_database_backend(); exec_insert(builder.build(&self.query), db) } } // Only Statement impl Send -async fn exec_insert(statement: Statement, db: &DatabaseConnection) -> Result { - // TODO: Postgres instead use query_one + returning clause - let result = match db { +async fn exec_insert( + statement: Statement, + db: &DatabaseConnection, +) -> Result, DbErr> +where + A: ActiveModelTrait, +{ + type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; + type ValueTypeOf = as PrimaryKeyTrait>::ValueType; + let last_insert_id = match db { #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + use crate::{sea_query::Iden, Iterable}; + let cols = PrimaryKey::::iter() + .map(|col| col.to_string()) + .collect::>(); let res = conn.query_one(statement).await?.unwrap(); - crate::query_result_into_exec_result(res)? + res.try_get_many("", cols.as_ref()).unwrap_or_default() + } + _ => { + let last_insert_id = db.execute(statement).await?.last_insert_id(); + ValueTypeOf::::try_from_u64(last_insert_id) + .ok() + .unwrap_or_default() } - _ => db.execute(statement).await?, }; - Ok(InsertResult { - last_insert_id: result.last_insert_id(), - }) + Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/query.rs b/src/executor/query.rs index b935bfc3..1c2ce4a1 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -17,6 +17,10 @@ pub(crate) enum QueryResultRow { Mock(crate::MockRow), } +pub trait TryGetable: Sized { + fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result; +} + pub enum TryGetError { DbErr(DbErr), Null, @@ -31,12 +35,6 @@ impl From for DbErr { } } -pub trait TryGetable { - fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result - where - Self: Sized; -} - // QueryResult // impl QueryResult { @@ -46,6 +44,13 @@ impl QueryResult { { Ok(T::try_get(self, pre, col)?) } + + pub fn try_get_many(&self, pre: &str, cols: &[String]) -> Result + where + T: TryGetableMany, + { + Ok(T::try_get_many(self, pre, cols)?) + } } impl fmt::Debug for QueryResultRow { @@ -283,3 +288,135 @@ impl TryGetable for Decimal { #[cfg(feature = "with-uuid")] try_getable_all!(uuid::Uuid); + +// TryGetableMany // + +pub trait TryGetableMany: Sized { + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result; +} + +impl TryGetableMany for T +where + T: TryGetable, +{ + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(1, cols)?; + T::try_get(res, pre, &cols[0]) + } +} + +impl TryGetableMany for (T, T) +where + T: TryGetable, +{ + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(2, cols)?; + Ok(( + T::try_get(res, pre, &cols[0])?, + T::try_get(res, pre, &cols[1])?, + )) + } +} + +impl TryGetableMany for (T, T, T) +where + T: TryGetable, +{ + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(3, cols)?; + Ok(( + T::try_get(res, pre, &cols[0])?, + T::try_get(res, pre, &cols[1])?, + T::try_get(res, pre, &cols[2])?, + )) + } +} + +fn try_get_many_with_slice_len_of(len: usize, cols: &[String]) -> Result<(), TryGetError> { + if cols.len() < len { + Err(TryGetError::DbErr(DbErr::Query(format!( + "Expect {} column names supplied but got slice of length {}", + len, + cols.len() + )))) + } else { + Ok(()) + } +} + +// TryFromU64 // + +pub trait TryFromU64: Sized { + fn try_from_u64(n: u64) -> Result; +} + +macro_rules! try_from_u64_err { + ( $type: ty ) => { + impl TryFromU64 for $type { + fn try_from_u64(_: u64) -> Result { + Err(DbErr::Exec(format!( + "{} cannot be converted from u64", + stringify!($type) + ))) + } + } + }; +} + +macro_rules! try_from_u64_tuple { + ( $type: ty ) => { + try_from_u64_err!(($type, $type)); + try_from_u64_err!(($type, $type, $type)); + }; +} + +macro_rules! try_from_u64_numeric { + ( $type: ty ) => { + impl TryFromU64 for $type { + fn try_from_u64(n: u64) -> Result { + use std::convert::TryInto; + n.try_into().map_err(|_| { + DbErr::Exec(format!( + "fail to convert '{}' into '{}'", + n, + stringify!($type) + )) + }) + } + } + try_from_u64_tuple!($type); + }; +} + +try_from_u64_numeric!(i8); +try_from_u64_numeric!(i16); +try_from_u64_numeric!(i32); +try_from_u64_numeric!(i64); +try_from_u64_numeric!(u8); +try_from_u64_numeric!(u16); +try_from_u64_numeric!(u32); +try_from_u64_numeric!(u64); + +macro_rules! try_from_u64_string { + ( $type: ty ) => { + impl TryFromU64 for $type { + fn try_from_u64(n: u64) -> Result { + Ok(n.to_string()) + } + } + try_from_u64_tuple!($type); + }; +} + +try_from_u64_string!(String); + +macro_rules! try_from_u64_dummy { + ( $type: ty ) => { + try_from_u64_err!($type); + try_from_u64_err!(($type, $type)); + try_from_u64_err!(($type, $type, $type)); + }; +} + +#[cfg(feature = "with-uuid")] +try_from_u64_dummy!(uuid::Uuid); diff --git a/src/lib.rs b/src/lib.rs index 7a20d851..441df8bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ //! }; //! //! // insert one -//! let res: InsertResult = Fruit::insert(pear).exec(db).await?; +//! let res = Fruit::insert(pear).exec(db).await?; //! //! println!("InsertResult: {}", res.last_insert_id); //! # Ok(()) diff --git a/src/tests_cfg/cake.rs b/src/tests_cfg/cake.rs index 32cc22fd..49ada3aa 100644 --- a/src/tests_cfg/cake.rs +++ b/src/tests_cfg/cake.rs @@ -28,6 +28,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/src/tests_cfg/cake_filling.rs b/src/tests_cfg/cake_filling.rs index b1151ee4..ed6d0af0 100644 --- a/src/tests_cfg/cake_filling.rs +++ b/src/tests_cfg/cake_filling.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + fn auto_increment() -> bool { false } diff --git a/src/tests_cfg/cake_filling_price.rs b/src/tests_cfg/cake_filling_price.rs index c0bcbea1..e820ae0f 100644 --- a/src/tests_cfg/cake_filling_price.rs +++ b/src/tests_cfg/cake_filling_price.rs @@ -35,6 +35,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + fn auto_increment() -> bool { false } diff --git a/src/tests_cfg/filling.rs b/src/tests_cfg/filling.rs index b439af7b..5e691980 100644 --- a/src/tests_cfg/filling.rs +++ b/src/tests_cfg/filling.rs @@ -41,6 +41,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/src/tests_cfg/fruit.rs b/src/tests_cfg/fruit.rs index 0511ae58..b0b8bc79 100644 --- a/src/tests_cfg/fruit.rs +++ b/src/tests_cfg/fruit.rs @@ -30,6 +30,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/baker.rs b/tests/common/bakery_chain/baker.rs index d111a311..9ab45905 100644 --- a/tests/common/bakery_chain/baker.rs +++ b/tests/common/bakery_chain/baker.rs @@ -31,6 +31,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/bakery.rs b/tests/common/bakery_chain/bakery.rs index a020cfce..d03298b3 100644 --- a/tests/common/bakery_chain/bakery.rs +++ b/tests/common/bakery_chain/bakery.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/cake.rs b/tests/common/bakery_chain/cake.rs index 8eaf8b9e..cb1895a6 100644 --- a/tests/common/bakery_chain/cake.rs +++ b/tests/common/bakery_chain/cake.rs @@ -35,6 +35,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/cakes_bakers.rs b/tests/common/bakery_chain/cakes_bakers.rs index e2067b59..b0f52ec8 100644 --- a/tests/common/bakery_chain/cakes_bakers.rs +++ b/tests/common/bakery_chain/cakes_bakers.rs @@ -28,6 +28,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = (i32, i32); + fn auto_increment() -> bool { false } diff --git a/tests/common/bakery_chain/customer.rs b/tests/common/bakery_chain/customer.rs index 7e4d9e0c..2cf1b9ec 100644 --- a/tests/common/bakery_chain/customer.rs +++ b/tests/common/bakery_chain/customer.rs @@ -29,6 +29,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/lineitem.rs b/tests/common/bakery_chain/lineitem.rs index 26ec828e..c89b44b8 100644 --- a/tests/common/bakery_chain/lineitem.rs +++ b/tests/common/bakery_chain/lineitem.rs @@ -33,6 +33,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/bakery_chain/metadata.rs b/tests/common/bakery_chain/metadata.rs new file mode 100644 index 00000000..ecef26c3 --- /dev/null +++ b/tests/common/bakery_chain/metadata.rs @@ -0,0 +1,61 @@ +use sea_orm::entity::prelude::*; +use uuid::Uuid; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "metadata" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub uuid: Uuid, + pub key: String, + pub value: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + Uuid, + Key, + Value, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + Uuid, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = Uuid; + + fn auto_increment() -> bool { + false + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::Uuid => ColumnType::Uuid.def(), + Self::Key => ColumnType::String(None).def(), + Self::Value => ColumnType::String(None).def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + unreachable!() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/common/bakery_chain/mod.rs b/tests/common/bakery_chain/mod.rs index 89028aab..3282766d 100644 --- a/tests/common/bakery_chain/mod.rs +++ b/tests/common/bakery_chain/mod.rs @@ -4,6 +4,7 @@ pub mod cake; pub mod cakes_bakers; pub mod customer; pub mod lineitem; +pub mod metadata; pub mod order; pub use super::baker::Entity as Baker; @@ -12,4 +13,5 @@ pub use super::cake::Entity as Cake; pub use super::cakes_bakers::Entity as CakesBakers; pub use super::customer::Entity as Customer; pub use super::lineitem::Entity as Lineitem; +pub use super::metadata::Entity as Metadata; pub use super::order::Entity as Order; diff --git a/tests/common/bakery_chain/order.rs b/tests/common/bakery_chain/order.rs index 6fb8d212..f9561b92 100644 --- a/tests/common/bakery_chain/order.rs +++ b/tests/common/bakery_chain/order.rs @@ -33,6 +33,8 @@ pub enum PrimaryKey { } impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + fn auto_increment() -> bool { true } diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index 7528cf6c..74e35b45 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -45,13 +45,14 @@ pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { Database::connect(base_url).await.unwrap() }; - assert!(schema::create_bakery_table(&db).await.is_ok()); - assert!(schema::create_baker_table(&db).await.is_ok()); - assert!(schema::create_customer_table(&db).await.is_ok()); - assert!(schema::create_order_table(&db).await.is_ok()); - assert!(schema::create_cake_table(&db).await.is_ok()); - assert!(schema::create_cakes_bakers_table(&db).await.is_ok()); - assert!(schema::create_lineitem_table(&db).await.is_ok()); + schema::create_bakery_table(&db).await.unwrap(); + schema::create_baker_table(&db).await.unwrap(); + schema::create_customer_table(&db).await.unwrap(); + schema::create_order_table(&db).await.unwrap(); + schema::create_cake_table(&db).await.unwrap(); + schema::create_cakes_bakers_table(&db).await.unwrap(); + schema::create_lineitem_table(&db).await.unwrap(); + schema::create_metadata_table(&db).await.unwrap(); db } diff --git a/tests/common/setup/schema.rs b/tests/common/setup/schema.rs index c9523572..2dc5f239 100644 --- a/tests/common/setup/schema.rs +++ b/tests/common/setup/schema.rs @@ -271,3 +271,20 @@ pub async fn create_cake_table(db: &DbConn) -> Result { create_table(db, &stmt, Cake).await } + +pub async fn create_metadata_table(db: &DbConn) -> Result { + let stmt = sea_query::Table::create() + .table(metadata::Entity) + .if_not_exists() + .col( + ColumnDef::new(metadata::Column::Uuid) + .uuid() + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(metadata::Column::Key).string().not_null()) + .col(ColumnDef::new(metadata::Column::Value).string().not_null()) + .to_owned(); + + create_table(db, &stmt).await +} diff --git a/tests/crud/create_baker.rs b/tests/crud/create_baker.rs index 8653b692..085b4005 100644 --- a/tests/crud/create_baker.rs +++ b/tests/crud/create_baker.rs @@ -7,7 +7,7 @@ pub async fn test_create_baker(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -30,7 +30,7 @@ pub async fn test_create_baker(db: &DbConn) { bakery_id: Set(Some(bakery_insert_res.last_insert_id as i32)), ..Default::default() }; - let res: InsertResult = Baker::insert(baker_bob) + let res = Baker::insert(baker_bob) .exec(db) .await .expect("could not insert baker"); diff --git a/tests/crud/create_cake.rs b/tests/crud/create_cake.rs index 07f9c4b3..eea74350 100644 --- a/tests/crud/create_cake.rs +++ b/tests/crud/create_cake.rs @@ -8,7 +8,7 @@ pub async fn test_create_cake(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -23,7 +23,7 @@ pub async fn test_create_cake(db: &DbConn) { bakery_id: Set(Some(bakery_insert_res.last_insert_id as i32)), ..Default::default() }; - let baker_insert_res: InsertResult = Baker::insert(baker_bob) + let baker_insert_res = Baker::insert(baker_bob) .exec(db) .await .expect("could not insert baker"); @@ -38,7 +38,7 @@ pub async fn test_create_cake(db: &DbConn) { ..Default::default() }; - let cake_insert_res: InsertResult = Cake::insert(mud_cake) + let cake_insert_res = Cake::insert(mud_cake) .exec(db) .await .expect("could not insert cake"); @@ -53,10 +53,18 @@ pub async fn test_create_cake(db: &DbConn) { baker_id: Set(baker_insert_res.last_insert_id as i32), ..Default::default() }; - let _cake_baker_res: InsertResult = CakesBakers::insert(cake_baker) + let cake_baker_res = CakesBakers::insert(cake_baker.clone()) .exec(db) .await .expect("could not insert cake_baker"); + assert_eq!( + cake_baker_res.last_insert_id, + if cfg!(feature = "sqlx-postgres") { + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) + } else { + Default::default() + } + ); assert!(cake.is_some()); let cake_model = cake.unwrap(); diff --git a/tests/crud/create_lineitem.rs b/tests/crud/create_lineitem.rs index c82958f9..9acba961 100644 --- a/tests/crud/create_lineitem.rs +++ b/tests/crud/create_lineitem.rs @@ -10,7 +10,7 @@ pub async fn test_create_lineitem(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -26,7 +26,7 @@ pub async fn test_create_lineitem(db: &DbConn) { bakery_id: Set(Some(bakery_insert_res.last_insert_id as i32)), ..Default::default() }; - let baker_insert_res: InsertResult = Baker::insert(baker_bob) + let baker_insert_res = Baker::insert(baker_bob) .exec(db) .await .expect("could not insert baker"); @@ -41,7 +41,7 @@ pub async fn test_create_lineitem(db: &DbConn) { ..Default::default() }; - let cake_insert_res: InsertResult = Cake::insert(mud_cake) + let cake_insert_res = Cake::insert(mud_cake) .exec(db) .await .expect("could not insert cake"); @@ -52,10 +52,18 @@ pub async fn test_create_lineitem(db: &DbConn) { baker_id: Set(baker_insert_res.last_insert_id as i32), ..Default::default() }; - let _cake_baker_res: InsertResult = CakesBakers::insert(cake_baker) + let cake_baker_res = CakesBakers::insert(cake_baker.clone()) .exec(db) .await .expect("could not insert cake_baker"); + assert_eq!( + cake_baker_res.last_insert_id, + if cfg!(feature = "sqlx-postgres") { + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) + } else { + Default::default() + } + ); // Customer let customer_kate = customer::ActiveModel { @@ -63,7 +71,7 @@ pub async fn test_create_lineitem(db: &DbConn) { notes: Set(Some("Loves cheese cake".to_owned())), ..Default::default() }; - let customer_insert_res: InsertResult = Customer::insert(customer_kate) + let customer_insert_res = Customer::insert(customer_kate) .exec(db) .await .expect("could not insert customer"); @@ -76,7 +84,7 @@ pub async fn test_create_lineitem(db: &DbConn) { placed_at: Set(Utc::now().naive_utc()), ..Default::default() }; - let order_insert_res: InsertResult = Order::insert(order_1) + let order_insert_res = Order::insert(order_1) .exec(db) .await .expect("could not insert order"); @@ -89,7 +97,7 @@ pub async fn test_create_lineitem(db: &DbConn) { quantity: Set(1), ..Default::default() }; - let lineitem_insert_res: InsertResult = Lineitem::insert(lineitem_1) + let lineitem_insert_res = Lineitem::insert(lineitem_1) .exec(db) .await .expect("could not insert lineitem"); @@ -105,7 +113,7 @@ pub async fn test_create_lineitem(db: &DbConn) { assert_eq!(lineitem_model.price, dec!(7.55)); - let cake: Option = Cake::find_by_id(lineitem_model.cake_id as u64) + let cake: Option = Cake::find_by_id(lineitem_model.cake_id) .one(db) .await .expect("could not find cake"); diff --git a/tests/crud/create_order.rs b/tests/crud/create_order.rs index 46ebcf09..8d10cfcd 100644 --- a/tests/crud/create_order.rs +++ b/tests/crud/create_order.rs @@ -10,7 +10,7 @@ pub async fn test_create_order(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -26,7 +26,7 @@ pub async fn test_create_order(db: &DbConn) { bakery_id: Set(Some(bakery_insert_res.last_insert_id as i32)), ..Default::default() }; - let baker_insert_res: InsertResult = Baker::insert(baker_bob) + let baker_insert_res = Baker::insert(baker_bob) .exec(db) .await .expect("could not insert baker"); @@ -41,7 +41,7 @@ pub async fn test_create_order(db: &DbConn) { ..Default::default() }; - let cake_insert_res: InsertResult = Cake::insert(mud_cake) + let cake_insert_res = Cake::insert(mud_cake) .exec(db) .await .expect("could not insert cake"); @@ -52,10 +52,18 @@ pub async fn test_create_order(db: &DbConn) { baker_id: Set(baker_insert_res.last_insert_id as i32), ..Default::default() }; - let _cake_baker_res: InsertResult = CakesBakers::insert(cake_baker) + let cake_baker_res = CakesBakers::insert(cake_baker.clone()) .exec(db) .await .expect("could not insert cake_baker"); + assert_eq!( + cake_baker_res.last_insert_id, + if cfg!(feature = "sqlx-postgres") { + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) + } else { + Default::default() + } + ); // Customer let customer_kate = customer::ActiveModel { @@ -63,7 +71,7 @@ pub async fn test_create_order(db: &DbConn) { notes: Set(Some("Loves cheese cake".to_owned())), ..Default::default() }; - let customer_insert_res: InsertResult = Customer::insert(customer_kate) + let customer_insert_res = Customer::insert(customer_kate) .exec(db) .await .expect("could not insert customer"); @@ -76,7 +84,7 @@ pub async fn test_create_order(db: &DbConn) { placed_at: Set(Utc::now().naive_utc()), ..Default::default() }; - let order_insert_res: InsertResult = Order::insert(order_1) + let order_insert_res = Order::insert(order_1) .exec(db) .await .expect("could not insert order"); @@ -89,7 +97,7 @@ pub async fn test_create_order(db: &DbConn) { quantity: Set(2), ..Default::default() }; - let _lineitem_insert_res: InsertResult = Lineitem::insert(lineitem_1) + let _lineitem_insert_res = Lineitem::insert(lineitem_1) .exec(db) .await .expect("could not insert lineitem"); @@ -103,7 +111,7 @@ pub async fn test_create_order(db: &DbConn) { let order_model = order.unwrap(); assert_eq!(order_model.total, dec!(15.10)); - let customer: Option = Customer::find_by_id(order_model.customer_id as u64) + let customer: Option = Customer::find_by_id(order_model.customer_id) .one(db) .await .expect("could not find customer"); @@ -111,7 +119,7 @@ pub async fn test_create_order(db: &DbConn) { let customer_model = customer.unwrap(); assert_eq!(customer_model.name, "Kate"); - let bakery: Option = Bakery::find_by_id(order_model.bakery_id as i64) + let bakery: Option = Bakery::find_by_id(order_model.bakery_id) .one(db) .await .expect("could not find bakery"); diff --git a/tests/crud/deletes.rs b/tests/crud/deletes.rs index ebff4b5f..4c34d36b 100644 --- a/tests/crud/deletes.rs +++ b/tests/crud/deletes.rs @@ -10,7 +10,7 @@ pub async fn test_delete_cake(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); diff --git a/tests/crud/mod.rs b/tests/crud/mod.rs index 457f639d..7e024055 100644 --- a/tests/crud/mod.rs +++ b/tests/crud/mod.rs @@ -1,4 +1,4 @@ -use sea_orm::{entity::*, DbConn, InsertResult}; +use sea_orm::{entity::*, DbConn}; pub use super::common::bakery_chain::*; @@ -15,7 +15,7 @@ pub async fn test_create_bakery(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let res: InsertResult = Bakery::insert(seaside_bakery) + let res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -37,7 +37,7 @@ pub async fn test_create_customer(db: &DbConn) { notes: Set(Some("Loves cheese cake".to_owned())), ..Default::default() }; - let res: InsertResult = Customer::insert(customer_kate) + let res = Customer::insert(customer_kate) .exec(db) .await .expect("could not insert customer"); diff --git a/tests/crud/updates.rs b/tests/crud/updates.rs index 505e4837..172c9c21 100644 --- a/tests/crud/updates.rs +++ b/tests/crud/updates.rs @@ -8,7 +8,7 @@ pub async fn test_update_cake(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -22,7 +22,7 @@ pub async fn test_update_cake(db: &DbConn) { ..Default::default() }; - let cake_insert_res: InsertResult = Cake::insert(mud_cake) + let cake_insert_res = Cake::insert(mud_cake) .exec(db) .await .expect("could not insert cake"); @@ -62,7 +62,7 @@ pub async fn test_update_bakery(db: &DbConn) { profit_margin: Set(10.4), ..Default::default() }; - let bakery_insert_res: InsertResult = Bakery::insert(seaside_bakery) + let bakery_insert_res = Bakery::insert(seaside_bakery) .exec(db) .await .expect("could not insert bakery"); @@ -130,11 +130,10 @@ pub async fn test_update_deleted_customer(db: &DbConn) { assert_eq!(Customer::find().count(db).await.unwrap(), init_n_customers); - let customer: Option = - Customer::find_by_id(customer_id.clone().unwrap() as i64) - .one(db) - .await - .expect("could not find customer"); + let customer: Option = Customer::find_by_id(customer_id.clone().unwrap()) + .one(db) + .await + .expect("could not find customer"); assert_eq!(customer, None); } diff --git a/tests/primary_key_tests.rs b/tests/primary_key_tests.rs new file mode 100644 index 00000000..74dd46e5 --- /dev/null +++ b/tests/primary_key_tests.rs @@ -0,0 +1,41 @@ +use sea_orm::{entity::prelude::*, DatabaseConnection, Set}; +pub mod common; +pub use common::{bakery_chain::*, setup::*, TestContext}; +use uuid::Uuid; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("bakery_chain_schema_primary_key_tests").await; + + create_metadata(&ctx.db).await?; + + ctx.delete().await; + + Ok(()) +} + +async fn create_metadata(db: &DatabaseConnection) -> Result<(), DbErr> { + let metadata = metadata::ActiveModel { + uuid: Set(Uuid::new_v4()), + key: Set("markup".to_owned()), + value: Set("1.18".to_owned()), + }; + + let res = Metadata::insert(metadata.clone()).exec(db).await?; + + assert_eq!( + res.last_insert_id, + if cfg!(feature = "sqlx-postgres") { + metadata.uuid.unwrap() + } else { + Default::default() + } + ); + + Ok(()) +} diff --git a/tests/sequential_op_tests.rs b/tests/sequential_op_tests.rs index a854c768..6c490762 100644 --- a/tests/sequential_op_tests.rs +++ b/tests/sequential_op_tests.rs @@ -67,7 +67,7 @@ async fn init_setup(db: &DatabaseConnection) { ..Default::default() }; - let cake_insert_res: InsertResult = Cake::insert(mud_cake) + let cake_insert_res = Cake::insert(mud_cake) .exec(db) .await .expect("could not insert cake"); @@ -78,10 +78,18 @@ async fn init_setup(db: &DatabaseConnection) { ..Default::default() }; - let _cake_baker_res: InsertResult = CakesBakers::insert(cake_baker) + let cake_baker_res = CakesBakers::insert(cake_baker.clone()) .exec(db) .await .expect("could not insert cake_baker"); + assert_eq!( + cake_baker_res.last_insert_id, + if cfg!(feature = "sqlx-postgres") { + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) + } else { + Default::default() + } + ); let customer_kate = customer::ActiveModel { name: Set("Kate".to_owned()), @@ -183,7 +191,7 @@ async fn find_baker_least_sales(db: &DatabaseConnection) -> Option results.sort_by(|a, b| b.cakes_sold.cmp(&a.cakes_sold)); - Baker::find_by_id(results.last().unwrap().id as i64) + Baker::find_by_id(results.last().unwrap().id) .one(db) .await .unwrap() @@ -200,7 +208,7 @@ async fn create_cake(db: &DatabaseConnection, baker: baker::Model) -> Option Option