diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index f24c62df..9702b8a1 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,10 +1,7 @@ use super::{ColumnTrait, IdenStatic, Iterable}; -use crate::TryGetable; +use crate::{TryFromU64, TryGetable}; use sea_query::IntoValueTuple; -use std::{ - fmt::{Debug, Display}, - str::FromStr, -}; +use std::fmt::{Debug, Display}; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { @@ -15,7 +12,7 @@ pub trait PrimaryKeyTrait: IdenStatic + Iterable { + PartialEq + IntoValueTuple + TryGetable - + FromStr; + + TryFromU64; fn auto_increment() -> bool; } diff --git a/src/error.rs b/src/error.rs index 8a695dac..09f80b0a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum DbErr { Conn(String), Exec(String), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index e806fe62..6e657d23 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,5 +1,6 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, PrimaryKeyTrait, Statement, + error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, PrimaryKeyTrait, + Statement, TryFromU64, }; use sea_query::InsertStatement; use std::{future::Future, marker::PhantomData}; @@ -82,20 +83,17 @@ async fn exec_insert( where A: ActiveModelTrait, { - // TODO: Postgres instead use query_one + returning clause + type ValueTypeOf = <<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType; let last_insert_id = match db { #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { let res = conn.query_one(statement).await?.unwrap(); res.try_get("", "last_insert_id").unwrap_or_default() } - _ => db - .execute(statement) - .await? - .last_insert_id() - .to_string() - .parse() - .unwrap_or_default(), + _ => { + let last_insert_id = db.execute(statement).await?.last_insert_id(); + ValueTypeOf::::try_from_u64(last_insert_id)? + } }; Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/query.rs b/src/executor/query.rs index b7e1d12e..5420c4c9 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -405,3 +405,49 @@ impl TryGetable for Option { #[cfg(feature = "with-uuid")] try_getable_all!(uuid::Uuid); + +pub trait TryFromU64: Sized { + fn try_from_u64(n: u64) -> Result; +} + +macro_rules! try_from_u64 { + ( $type: ty ) => { + impl TryFromU64 for $type { + fn try_from_u64(n: u64) -> Result { + use std::convert::TryInto; + n.try_into().map_err(|_| { + DbErr::Exec(format!( + "fail to convert '{}' into '{}'", + n, + stringify!($type) + )) + }) + } + } + }; +} + +try_from_u64!(i8); +try_from_u64!(i16); +try_from_u64!(i32); +try_from_u64!(i64); +try_from_u64!(u8); +try_from_u64!(u16); +try_from_u64!(u32); +try_from_u64!(u64); + +macro_rules! try_from_u64_err { + ( $type: ty ) => { + impl TryFromU64 for $type { + fn try_from_u64(_: u64) -> Result { + Err(DbErr::Exec(format!( + "{} cannot be converted from u64", + stringify!($type) + ))) + } + } + }; +} + +#[cfg(feature = "with-uuid")] +try_from_u64_err!(uuid::Uuid); diff --git a/tests/common/bakery_chain/metadata.rs b/tests/common/bakery_chain/metadata.rs new file mode 100644 index 00000000..ecef26c3 --- /dev/null +++ b/tests/common/bakery_chain/metadata.rs @@ -0,0 +1,61 @@ +use sea_orm::entity::prelude::*; +use uuid::Uuid; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "metadata" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub uuid: Uuid, + pub key: String, + pub value: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + Uuid, + Key, + Value, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + Uuid, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = Uuid; + + fn auto_increment() -> bool { + false + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::Uuid => ColumnType::Uuid.def(), + Self::Key => ColumnType::String(None).def(), + Self::Value => ColumnType::String(None).def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + unreachable!() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/common/bakery_chain/mod.rs b/tests/common/bakery_chain/mod.rs index 89028aab..3282766d 100644 --- a/tests/common/bakery_chain/mod.rs +++ b/tests/common/bakery_chain/mod.rs @@ -4,6 +4,7 @@ pub mod cake; pub mod cakes_bakers; pub mod customer; pub mod lineitem; +pub mod metadata; pub mod order; pub use super::baker::Entity as Baker; @@ -12,4 +13,5 @@ pub use super::cake::Entity as Cake; pub use super::cakes_bakers::Entity as CakesBakers; pub use super::customer::Entity as Customer; pub use super::lineitem::Entity as Lineitem; +pub use super::metadata::Entity as Metadata; pub use super::order::Entity as Order; diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index 7528cf6c..74e35b45 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -45,13 +45,14 @@ pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { Database::connect(base_url).await.unwrap() }; - assert!(schema::create_bakery_table(&db).await.is_ok()); - assert!(schema::create_baker_table(&db).await.is_ok()); - assert!(schema::create_customer_table(&db).await.is_ok()); - assert!(schema::create_order_table(&db).await.is_ok()); - assert!(schema::create_cake_table(&db).await.is_ok()); - assert!(schema::create_cakes_bakers_table(&db).await.is_ok()); - assert!(schema::create_lineitem_table(&db).await.is_ok()); + schema::create_bakery_table(&db).await.unwrap(); + schema::create_baker_table(&db).await.unwrap(); + schema::create_customer_table(&db).await.unwrap(); + schema::create_order_table(&db).await.unwrap(); + schema::create_cake_table(&db).await.unwrap(); + schema::create_cakes_bakers_table(&db).await.unwrap(); + schema::create_lineitem_table(&db).await.unwrap(); + schema::create_metadata_table(&db).await.unwrap(); db } diff --git a/tests/common/setup/schema.rs b/tests/common/setup/schema.rs index 4eba40ab..0e124fa2 100644 --- a/tests/common/setup/schema.rs +++ b/tests/common/setup/schema.rs @@ -254,3 +254,20 @@ pub async fn create_cake_table(db: &DbConn) -> Result { create_table(db, &stmt).await } + +pub async fn create_metadata_table(db: &DbConn) -> Result { + let stmt = sea_query::Table::create() + .table(metadata::Entity) + .if_not_exists() + .col( + ColumnDef::new(metadata::Column::Uuid) + .uuid() + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(metadata::Column::Key).string().not_null()) + .col(ColumnDef::new(metadata::Column::Value).string().not_null()) + .to_owned(); + + create_table(db, &stmt).await +} diff --git a/tests/primary_key_tests.rs b/tests/primary_key_tests.rs new file mode 100644 index 00000000..2f74eafc --- /dev/null +++ b/tests/primary_key_tests.rs @@ -0,0 +1,41 @@ +use sea_orm::{entity::prelude::*, DatabaseConnection, Set}; +pub mod common; +pub use common::{bakery_chain::*, setup::*, TestContext}; +use uuid::Uuid; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("bakery_chain_schema_primary_key_tests").await; + + create_metadata(&ctx.db).await?; + + ctx.delete().await; + + Ok(()) +} + +async fn create_metadata(db: &DatabaseConnection) -> Result<(), DbErr> { + let metadata = metadata::ActiveModel { + uuid: Set(Uuid::new_v4()), + key: Set("markup".to_owned()), + value: Set("1.18".to_owned()), + }; + + let res = Metadata::insert(metadata.clone()).exec(db).await; + + if cfg!(feature = "sqlx-postgres") { + assert_eq!(metadata.uuid.unwrap(), res?.last_insert_id); + } else { + assert_eq!( + res.unwrap_err(), + DbErr::Exec("uuid::Uuid cannot be converted from u64".to_owned()) + ); + } + + Ok(()) +}