diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 9006e385..f952f721 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -114,10 +114,11 @@ pub trait ActiveModelTrait: Clone + Debug { let found = ::find_by_id(res.last_insert_id) .one(db) .await?; - match found { - Some(model) => Ok(model.into_active_model()), - None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), - } + let am = match found { + Some(model) => model.into_active_model(), + None => return Err(DbErr::Exec("Failed to find inserted item".to_owned())), + }; + ActiveModelBehavior::after_save(am, true) } async fn update<'a, C>(self, db: &'a C) -> Result diff --git a/tests/common/bakery_chain/cake.rs b/tests/common/bakery_chain/cake.rs index 4730394b..0ecaf5de 100644 --- a/tests/common/bakery_chain/cake.rs +++ b/tests/common/bakery_chain/cake.rs @@ -49,4 +49,54 @@ impl Related for Entity { } } -impl ActiveModelBehavior for ActiveModel {} +impl ActiveModelBehavior for ActiveModel { + fn new() -> Self { + use sea_orm::Set; + Self { + serial: Set(Uuid::new_v4()), + ..ActiveModelTrait::default() + } + } + + fn before_save(self, insert: bool) -> Result { + use rust_decimal_macros::dec; + if self.price.as_ref() == &dec!(0) { + Err(DbErr::Custom(format!( + "[before_save] Invalid Price, insert: {}", + insert + ))) + } else { + Ok(self) + } + } + + fn after_save(self, insert: bool) -> Result { + use rust_decimal_macros::dec; + if self.price.as_ref() < &dec!(0) { + Err(DbErr::Custom(format!( + "[after_save] Invalid Price, insert: {}", + insert + ))) + } else { + Ok(self) + } + } + + fn before_delete(self) -> Result { + if self.name.as_ref().contains("(err_on_before_delete)") { + Err(DbErr::Custom( + "[before_delete] Cannot be deleted".to_owned(), + )) + } else { + Ok(self) + } + } + + fn after_delete(self) -> Result { + if self.name.as_ref().contains("(err_on_after_delete)") { + Err(DbErr::Custom("[after_delete] Cannot be deleted".to_owned())) + } else { + Ok(self) + } + } +} diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index 6a713e4d..7845843e 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -1,9 +1,9 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; +use pretty_assertions::assert_eq; pub use sea_orm::entity::*; -pub use sea_orm::{ConnectionTrait, QueryFilter}; -use sea_orm::{DatabaseTransaction, DbErr}; +pub use sea_orm::*; #[sea_orm_macros::test] #[cfg(any( @@ -105,6 +105,342 @@ fn _transaction_with_reference<'a>( }) } +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_begin_out_of_scope() -> Result<(), DbErr> { + let ctx = TestContext::new("transaction_begin_out_of_scope_test").await; + create_tables(&ctx.db).await?; + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + { + // Transaction begin in this scope + let txn = ctx.db.begin().await?; + + bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 1); + + bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 2); + + // The scope ended and transaction is dropped without commit + } + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + ctx.delete().await; + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_begin_commit() -> Result<(), DbErr> { + let ctx = TestContext::new("transaction_begin_commit_test").await; + create_tables(&ctx.db).await?; + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + { + // Transaction begin in this scope + let txn = ctx.db.begin().await?; + + bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 1); + + bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 2); + + // Commit changes before the end of scope + txn.commit().await?; + } + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 2); + + ctx.delete().await; + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_begin_rollback() -> Result<(), DbErr> { + let ctx = TestContext::new("transaction_begin_rollback_test").await; + create_tables(&ctx.db).await?; + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + { + // Transaction begin in this scope + let txn = ctx.db.begin().await?; + + bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 1); + + bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(bakery::Entity::find().all(&txn).await?.len(), 2); + + // Rollback changes before the end of scope + txn.rollback().await?; + } + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + ctx.delete().await; + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_closure_commit() -> Result<(), DbErr> { + let ctx = TestContext::new("transaction_closure_commit_test").await; + create_tables(&ctx.db).await?; + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + let res = ctx + .db + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); + + bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); + + Ok(()) + }) + }) + .await; + + assert!(res.is_ok()); + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 2); + + ctx.delete().await; + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_closure_rollback() -> Result<(), DbErr> { + let ctx = TestContext::new("transaction_closure_rollback_test").await; + create_tables(&ctx.db).await?; + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + let res = ctx + .db + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); + + bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); + + bakery::ActiveModel { + id: Set(1), + name: Set("Duplicated primary key".to_owned()), + profit_margin: Set(20.0), + ..Default::default() + } + .insert(txn) + .await?; // Throw error and rollback + + // This line won't be reached + assert!(false); + + Ok(()) + }) + }) + .await; + + assert!(res.is_err()); + + assert_eq!(bakery::Entity::find().all(&ctx.db).await?.len(), 0); + + ctx.delete().await; + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_with_active_model_behaviour() -> Result<(), DbErr> { + use rust_decimal_macros::dec; + let ctx = TestContext::new("transaction_with_active_model_behaviour_test").await; + create_tables(&ctx.db).await?; + + if let Ok(txn) = ctx.db.begin().await { + assert_eq!( + cake::ActiveModel { + name: Set("Cake with invalid price".to_owned()), + price: Set(dec!(0)), + gluten_free: Set(false), + ..Default::default() + } + .save(&txn) + .await, + Err(DbErr::Custom( + "[before_save] Invalid Price, insert: true".to_owned() + )) + ); + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 0); + + assert_eq!( + cake::ActiveModel { + name: Set("Cake with invalid price".to_owned()), + price: Set(dec!(-10)), + gluten_free: Set(false), + ..Default::default() + } + .save(&txn) + .await, + Err(DbErr::Custom( + "[after_save] Invalid Price, insert: true".to_owned() + )) + ); + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 1); + + let readonly_cake_1 = cake::ActiveModel { + name: Set("Readonly cake (err_on_before_delete)".to_owned()), + price: Set(dec!(10)), + gluten_free: Set(true), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 2); + + assert_eq!( + readonly_cake_1.delete(&txn).await.err(), + Some(DbErr::Custom( + "[before_delete] Cannot be deleted".to_owned() + )) + ); + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 2); + + let readonly_cake_2 = cake::ActiveModel { + name: Set("Readonly cake (err_on_after_delete)".to_owned()), + price: Set(dec!(10)), + gluten_free: Set(true), + ..Default::default() + } + .save(&txn) + .await?; + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 3); + + assert_eq!( + readonly_cake_2.delete(&txn).await.err(), + Some(DbErr::Custom("[after_delete] Cannot be deleted".to_owned())) + ); + + assert_eq!(cake::Entity::find().all(&txn).await?.len(), 2); + } + + assert_eq!(cake::Entity::find().all(&ctx.db).await?.len(), 0); + + ctx.delete().await; + Ok(()) +} + #[sea_orm_macros::test] #[cfg(any( feature = "sqlx-mysql",