From f79a4171503dbcd6f86187b6263e7c345bd173be Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 27 Sep 2021 18:01:38 +0800 Subject: [PATCH 01/65] Remove `ValueType: Default` --- src/entity/active_model.rs | 25 +++++++++++-------------- src/entity/base_entity.rs | 8 ++++++-- src/entity/prelude.rs | 5 +++-- src/entity/primary_key.rs | 20 +++++++++++--------- src/executor/insert.rs | 29 ++++++++++++++++++++--------- src/query/insert.rs | 23 +++++++++++++++++++++-- 6 files changed, 72 insertions(+), 38 deletions(-) diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index cfcb0bbd..c69e5b01 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,6 +1,6 @@ use crate::{ error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, - PrimaryKeyTrait, Value, + PrimaryKeyValue, Value, }; use async_trait::async_trait; use std::fmt::Debug; @@ -70,23 +70,18 @@ pub trait ActiveModelTrait: Clone + Debug { async fn insert(self, db: &DatabaseConnection) -> Result where ::Model: IntoActiveModel, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { let am = self; let exec = ::insert(am).exec(db); let res = exec.await?; - // Assume valid last_insert_id is not equals to Default::default() - if res.last_insert_id - != <::PrimaryKey as PrimaryKeyTrait>::ValueType::default() - { - let found = ::find_by_id(res.last_insert_id) - .one(db) - .await?; - match found { - Some(model) => Ok(model.into_active_model()), - None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), - } - } else { - Ok(Self::default()) + let found = ::find_by_id(res.last_insert_id) + .one(db) + .await?; + match found { + Some(model) => Ok(model.into_active_model()), + None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), } } @@ -101,6 +96,8 @@ pub trait ActiveModelTrait: Clone + Debug { where Self: ActiveModelBehavior, ::Model: IntoActiveModel, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { let mut am = self; am = ActiveModelBehavior::before_save(am); diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 7ba1e965..0c9813f7 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -1,7 +1,7 @@ use crate::{ ActiveModelTrait, ColumnTrait, Delete, DeleteMany, DeleteOne, FromQueryResult, Insert, - ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, Related, RelationBuilder, - RelationTrait, RelationType, Select, Update, UpdateMany, UpdateOne, + ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, PrimaryKeyValue, QueryFilter, Related, + RelationBuilder, RelationTrait, RelationType, Select, Update, UpdateMany, UpdateOne, }; use sea_query::{Alias, Iden, IntoIden, IntoTableRef, IntoValueTuple, TableRef}; pub use sea_strum::IntoEnumIterator as Iterable; @@ -299,6 +299,8 @@ pub trait EntityTrait: EntityName { fn insert(model: A) -> Insert where A: ActiveModelTrait, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { Insert::one(model) } @@ -352,6 +354,8 @@ pub trait EntityTrait: EntityName { where A: ActiveModelTrait, I: IntoIterator, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { Insert::many(models) } diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 8d87a4b2..be630567 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -2,8 +2,9 @@ pub use crate::{ error::*, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, DeriveEntityModel, DeriveModel, DerivePrimaryKey, DeriveRelation, EntityName, EntityTrait, - EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, - PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, + EnumIter, ForeignKeyAction, Iden, IdenStatic, IntoActiveModel, Linked, ModelTrait, + PrimaryKeyToColumn, PrimaryKeyTrait, PrimaryKeyValue, QueryFilter, QueryResult, Related, + RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "with-json")] diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index 463f1482..530eba30 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,18 +1,11 @@ use super::{ColumnTrait, IdenStatic, Iterable}; -use crate::{TryFromU64, TryGetableMany}; +use crate::{ActiveModelTrait, EntityTrait, TryFromU64, TryGetableMany}; use sea_query::IntoValueTuple; use std::fmt::Debug; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { - type ValueType: Sized - + Send - + Default - + Debug - + PartialEq - + IntoValueTuple - + TryGetableMany - + TryFromU64; + type ValueType: Sized + Send + Debug + PartialEq + IntoValueTuple + TryGetableMany + TryFromU64; fn auto_increment() -> bool; } @@ -26,3 +19,12 @@ pub trait PrimaryKeyToColumn { where Self: Sized; } + +pub trait PrimaryKeyValue +where + E: EntityTrait, +{ + fn get_primary_key_value(active_model: A) -> ::ValueType + where + A: ActiveModelTrait; +} diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 1f4936ba..cb7e5555 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -5,11 +5,12 @@ use crate::{ use sea_query::InsertStatement; use std::{future::Future, marker::PhantomData}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Inserter where A: ActiveModelTrait, { + primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, query: InsertStatement, model: PhantomData, } @@ -48,7 +49,7 @@ where ); } } - Inserter::::new(query).exec(db) + Inserter::::new(self.primary_key, query).exec(db) // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -57,8 +58,12 @@ impl Inserter where A: ActiveModelTrait, { - pub fn new(query: InsertStatement) -> Self { + pub fn new( + primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, + query: InsertStatement, + ) -> Self { Self { + primary_key, query, model: PhantomData, } @@ -72,12 +77,13 @@ where A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(self.primary_key, builder.build(&self.query), db) } } // Only Statement impl Send async fn exec_insert( + primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, statement: Statement, db: &DatabaseConnection, ) -> Result, DbErr> @@ -86,7 +92,7 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id = match db { + let last_insert_id_opt = match db { #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { use crate::{sea_query::Iden, Iterable}; @@ -94,14 +100,19 @@ where .map(|col| col.to_string()) .collect::>(); let res = conn.query_one(statement).await?.unwrap(); - res.try_get_many("", cols.as_ref()).unwrap_or_default() + Some(res.try_get_many("", cols.as_ref()).unwrap_or_default()) } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); - ValueTypeOf::::try_from_u64(last_insert_id) - .ok() - .unwrap_or_default() + ValueTypeOf::::try_from_u64(last_insert_id).ok() } }; + let last_insert_id = match last_insert_id_opt { + Some(last_insert_id) => last_insert_id, + None => match primary_key { + Some(primary_key) => primary_key, + None => return Err(DbErr::Exec("Fail to unpack last_insert_id".to_owned())), + }, + }; Ok(InsertResult { last_insert_id }) } diff --git a/src/query/insert.rs b/src/query/insert.rs index a65071e1..418d2b70 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,14 +1,18 @@ -use crate::{ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, QueryTrait}; +use crate::{ + ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, + PrimaryKeyValue, QueryTrait, +}; use core::marker::PhantomData; use sea_query::InsertStatement; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Insert where A: ActiveModelTrait, { pub(crate) query: InsertStatement, pub(crate) columns: Vec, + pub(crate) primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, pub(crate) model: PhantomData, } @@ -31,6 +35,7 @@ where .into_table(A::Entity::default().table_ref()) .to_owned(), columns: Vec::new(), + primary_key: None, model: PhantomData, } } @@ -68,6 +73,8 @@ where pub fn one(m: M) -> Insert where M: IntoActiveModel, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { Self::new().add(m) } @@ -97,6 +104,8 @@ where where M: IntoActiveModel, I: IntoIterator, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { Self::new().add_many(models) } @@ -105,8 +114,16 @@ where pub fn add(mut self, m: M) -> Self where M: IntoActiveModel, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { let mut am: A = m.into_active_model(); + self.primary_key = + if !<::PrimaryKey as PrimaryKeyTrait>::auto_increment() { + Some(<::PrimaryKey as PrimaryKeyValue>::get_primary_key_value::(am.clone())) + } else { + None + }; let mut columns = Vec::new(); let mut values = Vec::new(); let columns_empty = self.columns.is_empty(); @@ -132,6 +149,8 @@ where where M: IntoActiveModel, I: IntoIterator, + <::Entity as EntityTrait>::PrimaryKey: + PrimaryKeyValue<::Entity>, { for model in models.into_iter() { self = self.add(model); From 4f090d192bf7a26024c60792cd0b9bfdddbda97b Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 27 Sep 2021 18:02:12 +0800 Subject: [PATCH 02/65] Update `DerivePrimaryKey` --- sea-orm-macros/src/derives/primary_key.rs | 30 +++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/sea-orm-macros/src/derives/primary_key.rs b/sea-orm-macros/src/derives/primary_key.rs index e677e2e6..ce7ca840 100644 --- a/sea-orm-macros/src/derives/primary_key.rs +++ b/sea-orm-macros/src/derives/primary_key.rs @@ -1,7 +1,7 @@ use heck::SnakeCase; use proc_macro2::{Ident, TokenStream}; -use quote::{quote, quote_spanned}; -use syn::{Data, DataEnum, Fields, Variant}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{punctuated::Punctuated, token::Comma, Data, DataEnum, Fields, Variant}; pub fn expand_derive_primary_key(ident: Ident, data: Data) -> syn::Result { let variants = match data { @@ -30,6 +30,21 @@ pub fn expand_derive_primary_key(ident: Ident, data: Data) -> syn::Result = + variants.iter().fold(Punctuated::new(), |mut acc, v| { + let variant = &v.ident; + acc.push( + quote! { active_model.take(#ident::#variant.into_column()).unwrap().unwrap() }, + ); + acc + }); + let mut primary_key_value = primary_key_value.to_token_stream(); + if variants.len() > 1 { + primary_key_value = quote! { + (#primary_key_value) + }; + } + Ok(quote!( impl sea_orm::Iden for #ident { fn unquoted(&self, s: &mut dyn std::fmt::Write) { @@ -61,5 +76,16 @@ pub fn expand_derive_primary_key(ident: Ident, data: Data) -> syn::Result for #ident { + fn get_primary_key_value( + mut active_model: A, + ) -> <::PrimaryKey as PrimaryKeyTrait>::ValueType + where + A: ActiveModelTrait, + { + #primary_key_value + } + } )) } From 9efaeeba8b15f8f71c2565f55976ea26d84fb651 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 27 Sep 2021 18:02:20 +0800 Subject: [PATCH 03/65] Update test cases --- tests/crud/create_cake.rs | 6 +----- tests/crud/create_lineitem.rs | 6 +----- tests/crud/create_order.rs | 6 +----- tests/sequential_op_tests.rs | 12 ++---------- tests/uuid_tests.rs | 9 +-------- 5 files changed, 6 insertions(+), 33 deletions(-) diff --git a/tests/crud/create_cake.rs b/tests/crud/create_cake.rs index 4fa914a5..df5130aa 100644 --- a/tests/crud/create_cake.rs +++ b/tests/crud/create_cake.rs @@ -58,11 +58,7 @@ pub async fn test_create_cake(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); assert!(cake.is_some()); diff --git a/tests/crud/create_lineitem.rs b/tests/crud/create_lineitem.rs index da82cc82..0ba9c7d3 100644 --- a/tests/crud/create_lineitem.rs +++ b/tests/crud/create_lineitem.rs @@ -57,11 +57,7 @@ pub async fn test_create_lineitem(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); // Customer diff --git a/tests/crud/create_order.rs b/tests/crud/create_order.rs index ba8ff09b..6de3d46f 100644 --- a/tests/crud/create_order.rs +++ b/tests/crud/create_order.rs @@ -57,11 +57,7 @@ pub async fn test_create_order(db: &DbConn) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); // Customer diff --git a/tests/sequential_op_tests.rs b/tests/sequential_op_tests.rs index 28333d84..286e856a 100644 --- a/tests/sequential_op_tests.rs +++ b/tests/sequential_op_tests.rs @@ -84,11 +84,7 @@ async fn init_setup(db: &DatabaseConnection) { .expect("could not insert cake_baker"); assert_eq!( cake_baker_res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) - } else { - Default::default() - } + (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()) ); let customer_kate = customer::ActiveModel { @@ -225,11 +221,7 @@ async fn create_cake(db: &DatabaseConnection, baker: baker::Model) -> Option Result<(), DbErr> { assert_eq!(Metadata::find().one(db).await?, Some(metadata.clone())); - assert_eq!( - res.last_insert_id, - if cfg!(feature = "sqlx-postgres") { - metadata.uuid - } else { - Default::default() - } - ); + assert_eq!(res.last_insert_id, metadata.uuid); Ok(()) } From 9bd537efe347cf3f585f673f2d01221db9cba9af Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 27 Sep 2021 18:10:45 +0800 Subject: [PATCH 04/65] Fixup --- src/executor/insert.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/executor/insert.rs b/src/executor/insert.rs index cb7e5555..aa8905be 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -100,7 +100,7 @@ where .map(|col| col.to_string()) .collect::>(); let res = conn.query_one(statement).await?.unwrap(); - Some(res.try_get_many("", cols.as_ref()).unwrap_or_default()) + res.try_get_many("", cols.as_ref()).ok() } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); From a1a7a98a5c0a62af2521209e38e21fa867f9af30 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 28 Sep 2021 18:21:00 +0800 Subject: [PATCH 05/65] Set SqlxSqlit pool max connection to 1 --- src/driver/sqlx_sqlite.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b02f4408..db3fa4e2 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,5 +1,5 @@ use sqlx::{ - sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, + sqlite::{SqliteArguments, SqlitePoolOptions, SqliteQueryResult, SqliteRow}, Sqlite, SqlitePool, }; @@ -24,7 +24,11 @@ impl SqlxSqliteConnector { } pub async fn connect(string: &str) -> Result { - if let Ok(pool) = SqlitePool::connect(string).await { + if let Ok(pool) = SqlitePoolOptions::new() + .max_connections(1) + .connect(string) + .await + { Ok(DatabaseConnection::SqlxSqlitePoolConnection( SqlxSqlitePoolConnection { pool }, )) From 97b95bf61285d6ba7fef29e6d1d3599a54c22370 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 28 Sep 2021 18:23:42 +0800 Subject: [PATCH 06/65] cargo fmt --- src/executor/insert.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 630ee6a7..4a7b55a2 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,6 +1,6 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, - Statement, TryFromU64, + error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, + PrimaryKeyTrait, Statement, TryFromU64, }; use sea_query::InsertStatement; use std::{future::Future, marker::PhantomData}; From 11781082ba52806d4e33957069d2979bf16de1a4 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 28 Sep 2021 19:00:55 +0800 Subject: [PATCH 07/65] Throw error if none of the db rows are affected --- src/error.rs | 2 ++ src/executor/insert.rs | 4 ++-- src/executor/update.rs | 5 +++++ tests/crud/updates.rs | 13 +++++++++---- tests/uuid_tests.rs | 17 ++++++++++++++++- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/error.rs b/src/error.rs index 09f80b0a..f8aff775 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ pub enum DbErr { Conn(String), Exec(String), Query(String), + RecordNotFound(String), } impl std::error::Error for DbErr {} @@ -13,6 +14,7 @@ impl std::fmt::Display for DbErr { Self::Conn(s) => write!(f, "Connection Error: {}", s), Self::Exec(s) => write!(f, "Execution Error: {}", s), Self::Query(s) => write!(f, "Query Error: {}", s), + Self::RecordNotFound(s) => write!(f, "RecordNotFound Error: {}", s), } } } diff --git a/src/executor/insert.rs b/src/executor/insert.rs index d580f110..a44867f7 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,6 +1,6 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, - Statement, TryFromU64, + error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, + PrimaryKeyTrait, Statement, TryFromU64, }; use sea_query::InsertStatement; use std::{future::Future, marker::PhantomData}; diff --git a/src/executor/update.rs b/src/executor/update.rs index 6c7a9873..c6668d6b 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -73,6 +73,11 @@ where // Only Statement impl Send async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { let result = db.execute(statement).await?; + if result.rows_affected() <= 0 { + return Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned(), + )); + } Ok(UpdateResult { rows_affected: result.rows_affected(), }) diff --git a/tests/crud/updates.rs b/tests/crud/updates.rs index c2048f9b..262031ef 100644 --- a/tests/crud/updates.rs +++ b/tests/crud/updates.rs @@ -1,5 +1,6 @@ pub use super::*; use rust_decimal_macros::dec; +use sea_orm::DbErr; use uuid::Uuid; pub async fn test_update_cake(db: &DbConn) { @@ -119,10 +120,14 @@ pub async fn test_update_deleted_customer(db: &DbConn) { ..Default::default() }; - let _customer_update_res: customer::ActiveModel = customer - .update(db) - .await - .expect("could not update customer"); + let customer_update_res = customer.update(db).await; + + assert_eq!( + customer_update_res, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); assert_eq!(Customer::find().count(db).await.unwrap(), init_n_customers); diff --git a/tests/uuid_tests.rs b/tests/uuid_tests.rs index e58daca4..052f57d7 100644 --- a/tests/uuid_tests.rs +++ b/tests/uuid_tests.rs @@ -1,7 +1,7 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -use sea_orm::{entity::prelude::*, DatabaseConnection, IntoActiveModel}; +use sea_orm::{entity::prelude::*, DatabaseConnection, IntoActiveModel, Set}; #[sea_orm_macros::test] #[cfg(any( @@ -40,5 +40,20 @@ pub async fn create_metadata(db: &DatabaseConnection) -> Result<(), DbErr> { } ); + let update_res = Metadata::update(metadata::ActiveModel { + value: Set("0.22".to_owned()), + ..metadata.clone().into_active_model() + }) + .filter(metadata::Column::Uuid.eq(Uuid::default())) + .exec(db) + .await; + + assert_eq!( + update_res, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + Ok(()) } From 966f7ff9a85fcb133fb863ca1af709da5ed2c7fb Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 28 Sep 2021 19:06:02 +0800 Subject: [PATCH 08/65] Fix clippy warning --- src/executor/update.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/executor/update.rs b/src/executor/update.rs index c6668d6b..7cb60f3e 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -73,7 +73,7 @@ where // Only Statement impl Send async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { let result = db.execute(statement).await?; - if result.rows_affected() <= 0 { + if result.rows_affected() == 0 { return Err(DbErr::RecordNotFound( "None of the database rows are affected".to_owned(), )); From 3123a9d129964058ae81a555ad83cdc6a82eab29 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Thu, 30 Sep 2021 11:40:27 +0800 Subject: [PATCH 09/65] Fix git merge conflict --- src/executor/insert.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 4a7b55a2..ed4da269 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -91,14 +91,13 @@ where { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id_opt = match db { - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + let last_insert_id_opt = match db.get_database_backend() { + DbBackend::Postgres => { use crate::{sea_query::Iden, Iterable}; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) .collect::>(); - let res = conn.query_one(statement).await?.unwrap(); + let res = db.query_one(statement).await?.unwrap(); res.try_get_many("", cols.as_ref()).ok() } _ => { From 5497810afb08b1c6bbd0ef39ac392a94ca06f52a Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Thu, 30 Sep 2021 11:49:27 +0800 Subject: [PATCH 10/65] Remove unnecessary trait bounds --- src/entity/active_model.rs | 7 +------ src/entity/base_entity.rs | 8 +++----- src/query/insert.rs | 8 -------- 3 files changed, 4 insertions(+), 19 deletions(-) diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index c69e5b01..ab076bf0 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,6 +1,5 @@ use crate::{ - error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, - PrimaryKeyValue, Value, + error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, Value, }; use async_trait::async_trait; use std::fmt::Debug; @@ -70,8 +69,6 @@ pub trait ActiveModelTrait: Clone + Debug { async fn insert(self, db: &DatabaseConnection) -> Result where ::Model: IntoActiveModel, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { let am = self; let exec = ::insert(am).exec(db); @@ -96,8 +93,6 @@ pub trait ActiveModelTrait: Clone + Debug { where Self: ActiveModelBehavior, ::Model: IntoActiveModel, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { let mut am = self; am = ActiveModelBehavior::before_save(am); diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 79caaf0b..d691fb3e 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -49,7 +49,9 @@ pub trait EntityTrait: EntityName { type Relation: RelationTrait; - type PrimaryKey: PrimaryKeyTrait + PrimaryKeyToColumn; + type PrimaryKey: PrimaryKeyTrait + + PrimaryKeyToColumn + + PrimaryKeyValue; fn belongs_to(related: R) -> RelationBuilder where @@ -299,8 +301,6 @@ pub trait EntityTrait: EntityName { fn insert(model: A) -> Insert where A: ActiveModelTrait, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { Insert::one(model) } @@ -354,8 +354,6 @@ pub trait EntityTrait: EntityName { where A: ActiveModelTrait, I: IntoIterator, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { Insert::many(models) } diff --git a/src/query/insert.rs b/src/query/insert.rs index 418d2b70..59c06b97 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -73,8 +73,6 @@ where pub fn one(m: M) -> Insert where M: IntoActiveModel, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { Self::new().add(m) } @@ -104,8 +102,6 @@ where where M: IntoActiveModel, I: IntoIterator, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { Self::new().add_many(models) } @@ -114,8 +110,6 @@ where pub fn add(mut self, m: M) -> Self where M: IntoActiveModel, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { let mut am: A = m.into_active_model(); self.primary_key = @@ -149,8 +143,6 @@ where where M: IntoActiveModel, I: IntoIterator, - <::Entity as EntityTrait>::PrimaryKey: - PrimaryKeyValue<::Entity>, { for model in models.into_iter() { self = self.add(model); From f4218dec56b0745b75f88171ee7e0202115e9397 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Thu, 30 Sep 2021 12:47:01 +0800 Subject: [PATCH 11/65] Test mock connection --- src/executor/update.rs | 105 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/src/executor/update.rs b/src/executor/update.rs index 7cb60f3e..bcbafefc 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -82,3 +82,108 @@ async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result Result<(), DbErr> { + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![ + vec![cake::Model { + id: 1, + name: "Cheese Cake".to_owned(), + }], + vec![], + vec![], + ]) + .append_exec_results(vec![ + MockExecResult { + last_insert_id: 0, + rows_affected: 1, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + ]) + .into_connection(); + + let model = cake::Model { + id: 1, + name: "New York Cheese".to_owned(), + }; + + assert_eq!( + cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.into_active_model() + } + .update(&db) + .await?, + cake::Model { + id: 1, + name: "Cheese Cake".to_owned(), + } + .into_active_model() + ); + + let model = cake::Model { + id: 2, + name: "New York Cheese".to_owned(), + }; + + assert_eq!( + cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.clone().into_active_model() + } + .update(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + cake::Entity::update(cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.into_active_model() + }) + .exec(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + db.into_transaction_log(), + vec![ + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 1i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + ] + ); + + Ok(()) + } +} From 602690e9a7c8c281cc33a185234fe1ee09492c47 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 30 Sep 2021 19:16:55 +0800 Subject: [PATCH 12/65] Remove unneeded --- src/executor/update.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/executor/update.rs b/src/executor/update.rs index bcbafefc..2bc5ed80 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -91,14 +91,6 @@ mod tests { #[smol_potat::test] async fn update_record_not_found_1() -> Result<(), DbErr> { let db = MockDatabase::new(DbBackend::Postgres) - .append_query_results(vec![ - vec![cake::Model { - id: 1, - name: "Cheese Cake".to_owned(), - }], - vec![], - vec![], - ]) .append_exec_results(vec![ MockExecResult { last_insert_id: 0, From 91fb97c12ae1e4ce4ebf87977f960198f475e365 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 4 Oct 2021 11:18:42 +0800 Subject: [PATCH 13/65] `Update::many()` will not raise `DbErr::RecordNotFound` error --- src/executor/update.rs | 65 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/src/executor/update.rs b/src/executor/update.rs index 2bc5ed80..b564165c 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -7,9 +7,10 @@ use std::future::Future; #[derive(Clone, Debug)] pub struct Updater { query: UpdateStatement, + check_record_exists: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct UpdateResult { pub rows_affected: u64, } @@ -39,7 +40,15 @@ where impl Updater { pub fn new(query: UpdateStatement) -> Self { - Self { query } + Self { + query, + check_record_exists: false, + } + } + + pub fn check_record_exists(mut self) -> Self { + self.check_record_exists = true; + self } pub fn exec( @@ -47,7 +56,7 @@ impl Updater { db: &DatabaseConnection, ) -> impl Future> + '_ { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db, self.check_record_exists) } } @@ -66,14 +75,18 @@ async fn exec_update_and_return_original( where A: ActiveModelTrait, { - Updater::new(query).exec(db).await?; + Updater::new(query).check_record_exists().exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_update( + statement: Statement, + db: &DatabaseConnection, + check_record_exists: bool, +) -> Result { let result = db.execute(statement).await?; - if result.rows_affected() == 0 { + if check_record_exists && result.rows_affected() == 0 { return Err(DbErr::RecordNotFound( "None of the database rows are affected".to_owned(), )); @@ -87,6 +100,7 @@ async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result Result<(), DbErr> { @@ -104,6 +118,14 @@ mod tests { last_insert_id: 0, rows_affected: 0, }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, ]) .into_connection(); @@ -145,6 +167,18 @@ mod tests { assert_eq!( cake::Entity::update(cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.clone().into_active_model() + }) + .exec(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + Update::one(cake::ActiveModel { name: Set("Cheese Cake".to_owned()), ..model.into_active_model() }) @@ -155,6 +189,15 @@ mod tests { )) ); + assert_eq!( + Update::many(cake::Entity) + .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned())) + .filter(cake::Column::Id.eq(2)) + .exec(&db) + .await, + Ok(UpdateResult { rows_affected: 0 }) + ); + assert_eq!( db.into_transaction_log(), vec![ @@ -173,6 +216,16 @@ mod tests { r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, vec!["Cheese Cake".into(), 2i32.into()] ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), ] ); From 02ebc9745c8b216cc466e1d702b2daa49a3ed321 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Mon, 4 Oct 2021 10:42:37 +0800 Subject: [PATCH 14/65] Transaction 3 --- Cargo.toml | 1 + examples/rocket_example/src/main.rs | 16 +- examples/rocket_example/src/setup.rs | 2 +- src/database/connection.rs | 94 ++++++-- src/database/db_connection.rs | 45 ++++ src/database/db_transaction.rs | 308 +++++++++++++++++++++++++++ src/database/mock.rs | 7 +- src/database/mod.rs | 6 + src/database/stream/mod.rs | 5 + src/database/stream/query.rs | 108 ++++++++++ src/database/stream/transaction.rs | 82 +++++++ src/driver/mock.rs | 21 +- src/driver/mod.rs | 6 +- src/driver/sqlx_mysql.rs | 51 ++++- src/driver/sqlx_postgres.rs | 51 ++++- src/driver/sqlx_sqlite.rs | 51 ++++- src/entity/active_model.rs | 22 +- src/entity/base_entity.rs | 2 +- src/executor/delete.rs | 35 +-- src/executor/insert.rs | 34 +-- src/executor/paginator.rs | 12 +- src/executor/select.rs | 116 +++++++--- src/executor/update.rs | 40 ++-- src/query/mod.rs | 2 +- tests/basic.rs | 2 +- tests/common/setup/mod.rs | 2 +- tests/common/setup/schema.rs | 2 +- tests/query_tests.rs | 2 +- tests/sequential_op_tests.rs | 2 +- tests/stream_tests.rs | 37 ++++ tests/transaction_tests.rs | 90 ++++++++ 31 files changed, 1093 insertions(+), 161 deletions(-) create mode 100644 src/database/db_connection.rs create mode 100644 src/database/db_transaction.rs create mode 100644 src/database/stream/mod.rs create mode 100644 src/database/stream/query.rs create mode 100644 src/database/stream/transaction.rs create mode 100644 tests/stream_tests.rs create mode 100644 tests/transaction_tests.rs diff --git a/Cargo.toml b/Cargo.toml index a4466152..b1fa90c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +ouroboros = "0.11" [dev-dependencies] smol = { version = "^1.2" } diff --git a/examples/rocket_example/src/main.rs b/examples/rocket_example/src/main.rs index 853eaaaa..43f227c6 100644 --- a/examples/rocket_example/src/main.rs +++ b/examples/rocket_example/src/main.rs @@ -42,7 +42,7 @@ async fn create(conn: Connection, post_form: Form) -> Flash, post_form: Form) -> Flash", data = "")] async fn update(conn: Connection, id: i32, post_form: Form) -> Flash { let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .unwrap() .unwrap() @@ -65,7 +65,7 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> title: Set(form.title.to_owned()), text: Set(form.text.to_owned()), } - .save(&conn) + .save(&*conn) .await .expect("could not edit post"); @@ -89,7 +89,7 @@ async fn list( // Setup paginator let paginator = Post::find() .order_by_asc(post::Column::Id) - .paginate(&conn, posts_per_page); + .paginate(&*conn, posts_per_page); let num_pages = paginator.num_pages().await.ok().unwrap(); // Fetch paginated posts @@ -113,7 +113,7 @@ async fn list( #[get("/")] async fn edit(conn: Connection, id: i32) -> Template { let post: Option = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .expect("could not find post"); @@ -128,20 +128,20 @@ async fn edit(conn: Connection, id: i32) -> Template { #[delete("/")] async fn delete(conn: Connection, id: i32) -> Flash { let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(&*conn) .await .unwrap() .unwrap() .into(); - post.delete(&conn).await.unwrap(); + post.delete(&*conn).await.unwrap(); Flash::success(Redirect::to("/"), "Post successfully deleted.") } #[delete("/")] async fn destroy(conn: Connection) -> Result<()> { - Post::delete_many().exec(&conn).await.unwrap(); + Post::delete_many().exec(&*conn).await.unwrap(); Ok(()) } diff --git a/examples/rocket_example/src/setup.rs b/examples/rocket_example/src/setup.rs index 034e8b53..91bbb7b1 100644 --- a/examples/rocket_example/src/setup.rs +++ b/examples/rocket_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{query::*, error::*, sea_query, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/src/database/connection.rs b/src/database/connection.rs index e5ec4e2c..630b8b96 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,4 +1,5 @@ -use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder}; +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; #[cfg_attr(not(feature = "mock"), derive(Clone))] @@ -10,7 +11,7 @@ pub enum DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] - MockDatabaseConnection(crate::MockDatabaseConnection), + MockDatabaseConnection(Arc), Disconnected, } @@ -51,8 +52,11 @@ impl std::fmt::Debug for DatabaseConnection { } } -impl DatabaseConnection { - pub fn get_database_backend(&self) -> DbBackend { +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; + + fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, @@ -66,7 +70,7 @@ impl DatabaseConnection { } } - pub async fn execute(&self, stmt: Statement) -> Result { + async fn execute(&self, stmt: Statement) -> Result { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, @@ -75,12 +79,12 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } - pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, @@ -89,12 +93,12 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } - pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, @@ -103,12 +107,76 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)), + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } + + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + }, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + #[cfg(feature = "mock")] + fn is_mock_connection(&self) -> bool { + match self { + DatabaseConnection::MockDatabaseConnection(_) => true, + _ => false, + } + } +} + +#[cfg(feature = "mock")] +impl DatabaseConnection { pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { match self { DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, @@ -116,12 +184,6 @@ impl DatabaseConnection { } } - #[cfg(not(feature = "mock"))] - pub fn as_mock_connection(&self) -> Option { - None - } - - #[cfg(feature = "mock")] pub fn into_transaction_log(self) -> Vec { let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); mocker.drain_transaction_log() diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs new file mode 100644 index 00000000..569f3896 --- /dev/null +++ b/src/database/db_connection.rs @@ -0,0 +1,45 @@ +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError}; +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use sqlx::pool::PoolConnection; + +pub(crate) enum InnerConnection { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), + #[cfg(feature = "mock")] + Mock(Arc), +} + +#[async_trait::async_trait] +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; + + fn get_database_backend(&self) -> DbBackend; + + async fn execute(&self, stmt: Statement) -> Result; + + async fn query_one(&self, stmt: Statement) -> Result, DbErr>; + + async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>>; + + async fn begin(&self) -> Result; + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send; + + fn is_mock_connection(&self) -> bool { + false + } +} diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs new file mode 100644 index 00000000..b403971e --- /dev/null +++ b/src/database/db_transaction.rs @@ -0,0 +1,308 @@ +use std::{sync::Arc, future::Future, pin::Pin}; +use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print}; +use futures::lock::Mutex; +#[cfg(feature = "sqlx-dep")] +use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, TransactionManager}; + +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, +} + +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { + #[cfg(feature = "sqlx-mysql")] + pub(crate) async fn new_mysql(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await + } + + #[cfg(feature = "sqlx-postgres")] + pub(crate) async fn new_postgres(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await + } + + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await + } + + #[cfg(feature = "mock")] + pub(crate) async fn new_mock(inner: Arc) -> Result { + let backend = inner.get_database_backend(); + Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + } + + async fn build(conn: Arc>, backend: DbBackend) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + // should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(res) + } + + pub(crate) async fn run(self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); + if res.is_ok() { + self.commit().await.map_err(|e| TransactionError::Connection(e))?; + } + else { + self.rollback().await.map_err(|e| TransactionError::Connection(e))?; + } + res + } + + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + else { + //this should never happen + panic!("Dropping a locked Transaction"); + } + } + } +} + +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + + fn get_database_backend(&self) -> DbBackend { + // this way we don't need to lock + self.backend + } + + async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.execute(conn).await + .map(Into::into) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_exec_err) + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await + .map(|row| Some(row.into())) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + }; + #[cfg(feature = "sqlx-dep")] + if let Err(sqlx::Error::RowNotFound) = _res { + Ok(None) + } + else { + _res.map_err(sqlx_error_to_query_err) + } + } + + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_all(conn).await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_query_err) + } + + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) + }) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> + where + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, + { + let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + } +} + +#[derive(Debug)] +pub enum TransactionError +where E: std::error::Error { + Connection(DbErr), + Transaction(E), +} + +impl std::fmt::Display for TransactionError +where E: std::error::Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), + TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for TransactionError +where E: std::error::Error {} diff --git a/src/database/mock.rs b/src/database/mock.rs index ccb34a49..d9e1f9d0 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -4,7 +4,7 @@ use crate::{ Statement, Transaction, }; use sea_query::{Value, ValueType}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { @@ -40,7 +40,7 @@ impl MockDatabase { } pub fn into_connection(self) -> DatabaseConnection { - DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) + DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self))) } pub fn append_exec_results(mut self, mut vec: Vec) -> Self { @@ -100,7 +100,8 @@ impl MockRow { where T: ValueType, { - Ok(self.values.get(col).unwrap().clone().unwrap()) + T::try_from(self.values.get(col).unwrap().clone()) + .map_err(|e| DbErr::Query(e.to_string())) } pub fn into_column_value_tuples(self) -> impl Iterator { diff --git a/src/database/mod.rs b/src/database/mod.rs index f61343c1..ce4127e3 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -3,12 +3,18 @@ mod connection; mod mock; mod statement; mod transaction; +mod db_connection; +mod db_transaction; +mod stream; pub use connection::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; pub use transaction::*; +pub use db_connection::*; +pub use db_transaction::*; +pub use stream::*; use crate::DbErr; diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs new file mode 100644 index 00000000..774cf45f --- /dev/null +++ b/src/database/stream/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod transaction; + +pub use query::*; +pub use transaction::*; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs new file mode 100644 index 00000000..553d9f7b --- /dev/null +++ b/src/database/stream/query.rs @@ -0,0 +1,108 @@ +use std::{pin::Pin, task::Poll, sync::Arc}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, Executor}; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct QueryStream { + stmt: Statement, + conn: InnerConnection, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +#[cfg(feature = "sqlx-mysql")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::MySql(conn)) + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Postgres(conn)) + } +} + +#[cfg(feature = "sqlx-sqlite")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Sqlite(conn)) + } +} + +#[cfg(feature = "mock")] +impl From<(Arc, Statement)> for QueryStream { + fn from((conn, stmt): (Arc, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn)) + } +} + +impl std::fmt::Debug for QueryStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QueryStream") + } +} + +impl QueryStream { + fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { + QueryStreamBuilder { + stmt, + conn, + stream_builder: |conn, stmt| { + match conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }, + }.build() + } +} + +impl Stream for QueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs new file mode 100644 index 00000000..d945f409 --- /dev/null +++ b/src/database/stream/transaction.rs @@ -0,0 +1,82 @@ +use std::{ops::DerefMut, pin::Pin, task::Poll}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::Executor; + +use futures::lock::MutexGuard; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct TransactionStream<'a> { + stmt: Statement, + conn: MutexGuard<'a, InnerConnection>, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +impl<'a> std::fmt::Debug for TransactionStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TransactionStream") + } +} + +impl<'a> TransactionStream<'a> { + pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> { + TransactionStreamAsyncBuilder { + stmt, + conn, + stream_builder: |conn, stmt| Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }), + }.build().await + } +} + +impl<'a> Stream for TransactionStream<'a> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 0e398586..388ac911 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,11 +2,11 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::fmt::Debug; -use std::sync::{ +use std::{fmt::Debug, pin::Pin, sync::{Arc, atomic::{AtomicUsize, Ordering}, Mutex, -}; +}}; +use futures::Stream; #[derive(Debug)] pub struct MockDatabaseConnector; @@ -50,7 +50,7 @@ impl MockDatabaseConnector { macro_rules! connect_mock_db { ( $syntax: expr ) => { Ok(DatabaseConnection::MockDatabaseConnection( - MockDatabaseConnection::new(MockDatabase::new($syntax)), + Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))), )) }; } @@ -86,25 +86,32 @@ impl MockDatabaseConnection { &self.mocker } - pub async fn execute(&self, statement: Statement) -> Result { + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().execute(counter, statement) } - pub async fn query_one(&self, statement: Statement) -> Result, DbErr> { + pub fn query_one(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); let result = self.mocker.lock().unwrap().query(counter, statement)?; Ok(result.into_iter().next()) } - pub async fn query_all(&self, statement: Statement) -> Result, DbErr> { + pub fn query_all(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().query(counter, statement) } + pub fn fetch(&self, statement: &Statement) -> Pin>>> { + match self.query_all(statement.clone()) { + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), + Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), + } + } + pub fn get_database_backend(&self) -> DbBackend { self.mocker.lock().unwrap().get_database_backend() } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 6f6cfb64..33b6c847 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,11 +3,11 @@ mod mock; #[cfg(feature = "sqlx-dep")] mod sqlx_common; #[cfg(feature = "sqlx-mysql")] -mod sqlx_mysql; +pub(crate) mod sqlx_mysql; #[cfg(feature = "sqlx-postgres")] -mod sqlx_postgres; +pub(crate) mod sqlx_postgres; #[cfg(feature = "sqlx-sqlite")] -mod sqlx_sqlite; +pub(crate) mod sqlx_sqlite; #[cfg(feature = "mock")] pub use mock::*; diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index c542a9b4..75e6e5ff 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,12 +1,11 @@ -use sqlx::{ - mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, - MySql, MySqlPool, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxMySqlPoolConnection { impl SqlxMySqlConnector { pub fn accepts(string: &str) -> bool { - DbBackend::MySql.is_prefix_of(string) + string.starts_with("mysql://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxMySqlPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_mysql(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index fb5402eb..c9949375 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,12 +1,11 @@ -use sqlx::{ - postgres::{PgArguments, PgQueryResult, PgRow}, - PgPool, Postgres, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxPostgresPoolConnection { impl SqlxPostgresConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Postgres.is_prefix_of(string) + string.starts_with("postgres://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxPostgresPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_postgres(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b02f4408..bf06a265 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,12 +1,11 @@ -use sqlx::{ - sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, - Sqlite, SqlitePool, -}; +use std::{future::Future, pin::Pin}; + +use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -20,7 +19,7 @@ pub struct SqlxSqlitePoolConnection { impl SqlxSqliteConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Sqlite.is_prefix_of(string) + string.starts_with("sqlite:") } pub async fn connect(string: &str) -> Result { @@ -91,6 +90,44 @@ impl SqlxSqlitePoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_sqlite(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +146,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index cfcb0bbd..32e9d77d 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, + error::*, ConnectionTrait, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, PrimaryKeyTrait, Value, }; use async_trait::async_trait; @@ -67,9 +67,11 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert(self, db: &DatabaseConnection) -> Result + async fn insert<'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, + Self: 'a, { let am = self; let exec = ::insert(am).exec(db); @@ -90,17 +92,22 @@ pub trait ActiveModelTrait: Clone + Debug { } } - async fn update(self, db: &DatabaseConnection) -> Result { + async fn update<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + Self: 'a, + { let exec = Self::Entity::update(self).exec(db); exec.await } /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save(self, db: &DatabaseConnection) -> Result + async fn save<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -122,9 +129,10 @@ pub trait ActiveModelTrait: Clone + Debug { } /// Delete an active model by its primary key - async fn delete(self, db: &DatabaseConnection) -> Result + async fn delete<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_delete(am); diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 764f2524..aef46207 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 807bc544..85b37cb0 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -1,6 +1,4 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DeleteMany, DeleteOne, EntityTrait, Statement, -}; +use crate::{ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, error::*}; use sea_query::DeleteStatement; use std::future::Future; @@ -18,10 +16,11 @@ impl<'a, A: 'a> DeleteOne where A: ActiveModelTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -31,10 +30,11 @@ impl<'a, E> DeleteMany where E: EntityTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,24 +45,27 @@ impl Deleter { Self { query } } - pub fn exec( + pub fn exec<'a, C>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a C, + ) -> impl Future> + '_ + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only( +async fn exec_delete_only<'a, C>( query: DeleteStatement, - db: &DatabaseConnection, -) -> Result { + db: &'a C, +) -> Result +where C: ConnectionTrait<'a> { Deleter::new(query).exec(db).await } // Only Statement impl Send -async fn exec_delete(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index a44867f7..7683f9bb 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,9 +1,6 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, - PrimaryKeyTrait, Statement, TryFromU64, -}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; -use std::{future::Future, marker::PhantomData}; +use std::marker::PhantomData; #[derive(Clone, Debug)] pub struct Inserter @@ -27,11 +24,12 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a>( + pub async fn exec<'a, C>( self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + db: &'a C, + ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: 'a, { // TODO: extract primary key's value from query @@ -47,7 +45,7 @@ where ); } } - Inserter::::new(query).exec(db) + Inserter::::new(query).exec(db).await // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -63,24 +61,26 @@ where } } - pub fn exec<'a>( + pub async fn exec<'a, C>( self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + db: &'a C, + ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(builder.build(&self.query), db).await } } // Only Statement impl Send -async fn exec_insert( +async fn exec_insert<'a, A, C>( statement: Statement, - db: &DatabaseConnection, + db: &C, ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; @@ -93,13 +93,13 @@ where .collect::>(); let res = db.query_one(statement).await?.unwrap(); res.try_get_many("", cols.as_ref()).unwrap_or_default() - } + }, _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); ValueTypeOf::::try_from_u64(last_insert_id) .ok() .unwrap_or_default() - } + }, }; Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 608d9dc1..a4e9bc69 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; +use crate::{ConnectionTrait, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -7,21 +7,23 @@ use std::{marker::PhantomData, pin::Pin}; pub type PinBoxStream<'db, Item> = Pin + 'db>>; #[derive(Clone, Debug)] -pub struct Paginator<'db, S> +pub struct Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { pub(crate) query: SelectStatement, pub(crate) page: usize, pub(crate) page_size: usize, - pub(crate) db: &'db DatabaseConnection, + pub(crate) db: &'db C, pub(crate) selector: PhantomData, } // LINT: warn if paginator is used without an order by clause -impl<'db, S> Paginator<'db, S> +impl<'db, C, S> Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { /// Fetch a specific page; page index starts from zero @@ -155,7 +157,7 @@ where #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; - use crate::tests_cfg::*; + use crate::{ConnectionTrait, tests_cfg::*}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/select.rs b/src/executor/select.rs index bb386722..cff4a348 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,8 +1,8 @@ -use crate::{ - error::*, DatabaseConnection, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, - ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, - SelectTwoMany, Statement, TryGetableMany, -}; +#[cfg(feature = "sqlx-dep")] +use std::pin::Pin; +use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*}; +#[cfg(feature = "sqlx-dep")] +use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; @@ -235,23 +235,35 @@ where Selector::>::with_columns(self.query) } - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &DatabaseConnection, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectModel> { + ) -> Paginator<'a, C, SelectModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -280,29 +292,41 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + pub async fn all<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &DatabaseConnection, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectTwoModel> { + ) -> Paginator<'a, C, SelectTwoModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -331,17 +355,27 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub async fn all<'a, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &C, + ) -> Result)>, DbErr> + where C: ConnectionTrait<'a> { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -376,7 +410,8 @@ where } } - pub async fn one(mut self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -386,7 +421,8 @@ where } } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -396,7 +432,21 @@ where Ok(models) } - pub fn paginate(self, db: &DatabaseConnection, page_size: usize) -> Paginator<'_, S> { + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> + where + C: ConnectionTrait<'a>, + S: 'b, + { + let builder = db.get_database_backend(); + let stream = db.stream(builder.build(&self.query)).await?; + Ok(Box::pin(stream.and_then(|row| { + futures::future::ready(S::from_raw_query_result(row)) + }))) + } + + pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> + where C: ConnectionTrait<'a> { Paginator { query: self.query, page: 0, @@ -606,7 +656,8 @@ where /// ),] /// ); /// ``` - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -645,7 +696,8 @@ where /// ),] /// ); /// ``` - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index 6c7a9873..06cd514e 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,6 +1,4 @@ -use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Statement, UpdateMany, UpdateOne, -}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, error::*}; use sea_query::UpdateStatement; use std::future::Future; @@ -18,9 +16,10 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a DatabaseConnection) -> impl Future> + 'a { + pub async fn exec<'b, C>(self, db: &'b C) -> Result + where C: ConnectionTrait<'b> { // so that self is dropped before entering await - exec_update_and_return_original(self.query, self.model, db) + exec_update_and_return_original(self.query, self.model, db).await } } @@ -28,10 +27,11 @@ impl<'a, E> UpdateMany where E: EntityTrait, { - pub fn exec( + pub fn exec( self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + db: &'a C, + ) -> impl Future> + 'a + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -42,36 +42,40 @@ impl Updater { Self { query } } - pub fn exec( + pub async fn exec<'a, C>( self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + db: &'a C, + ) -> Result + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db).await } } -async fn exec_update_only( +async fn exec_update_only<'a, C>( query: UpdateStatement, - db: &DatabaseConnection, -) -> Result { + db: &'a C, +) -> Result +where C: ConnectionTrait<'a> { Updater::new(query).exec(db).await } -async fn exec_update_and_return_original( +async fn exec_update_and_return_original<'a, A, C>( query: UpdateStatement, model: A, - db: &DatabaseConnection, + db: &'a C, ) -> Result where A: ActiveModelTrait, + C: ConnectionTrait<'a>, { Updater::new(query).exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/src/query/mod.rs b/src/query/mod.rs index 54cc12dd..c7f60049 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -20,4 +20,4 @@ pub use select::*; pub use traits::*; pub use update::*; -pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; +pub use crate::{InsertResult, Statement, UpdateResult, Value, Values, ConnectionTrait}; diff --git a/tests/basic.rs b/tests/basic.rs index a0763d45..ef379779 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,6 +1,6 @@ pub mod common; -pub use sea_orm::{entity::*, error::*, sea_query, tests_cfg::*, Database, DbConn}; +pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Database, DbConn}; // cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic #[sea_orm_macros::test] diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index d982b2b7..9deb903f 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,4 +1,4 @@ -use sea_orm::{Database, DatabaseBackend, DatabaseConnection, Statement}; +use sea_orm::{Database, DatabaseBackend, DatabaseConnection, ConnectionTrait, Statement}; pub mod schema; pub use schema::*; diff --git a/tests/common/setup/schema.rs b/tests/common/setup/schema.rs index b39f1e77..ac34304a 100644 --- a/tests/common/setup/schema.rs +++ b/tests/common/setup/schema.rs @@ -1,6 +1,6 @@ pub use super::super::bakery_chain::*; use pretty_assertions::assert_eq; -use sea_orm::{error::*, sea_query, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; use sea_query::{ Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement, }; diff --git a/tests/query_tests.rs b/tests/query_tests.rs index 2b5a2295..e688b14f 100644 --- a/tests/query_tests.rs +++ b/tests/query_tests.rs @@ -2,7 +2,7 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; pub use sea_orm::entity::*; -pub use sea_orm::QueryFilter; +pub use sea_orm::{QueryFilter, ConnectionTrait}; // Run the test locally: // DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests diff --git a/tests/sequential_op_tests.rs b/tests/sequential_op_tests.rs index 28333d84..47e69ccb 100644 --- a/tests/sequential_op_tests.rs +++ b/tests/sequential_op_tests.rs @@ -179,7 +179,7 @@ async fn find_baker_least_sales(db: &DatabaseConnection) -> Option let mut results: Vec = select .into_model::() - .all(&db) + .all(db) .await .unwrap() .into_iter() diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs new file mode 100644 index 00000000..969b93e1 --- /dev/null +++ b/tests/stream_tests.rs @@ -0,0 +1,37 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr}; +use futures::StreamExt; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn stream() -> Result<(), DbErr> { + let ctx = TestContext::new("stream").await; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&ctx.db) + .await?; + + let result = Bakery::find_by_id(bakery.id.clone().unwrap()) + .stream(&ctx.db) + .await? + .next() + .await + .unwrap()?; + + assert_eq!(result.id, bakery.id.unwrap()); + + ctx.delete().await; + + Ok(()) +} diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs new file mode 100644 index 00000000..539eaefc --- /dev/null +++ b/tests/transaction_tests.rs @@ -0,0 +1,90 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +use sea_orm::{DatabaseTransaction, DbErr}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction() { + let ctx = TestContext::new("transaction_test").await; + + ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + let _ = bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 2); + + Ok(()) + })).await.unwrap(); + + ctx.delete().await; +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_with_reference() { + let ctx = TestContext::new("transaction_with_reference_test").await; + let name1 = "SeaSide Bakery"; + let name2 = "Top Bakery"; + let search_name = "Bakery"; + ctx.db.transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)).await.unwrap(); + + ctx.delete().await; +} + +fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set(name1.to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + let _ = bakery::ActiveModel { + name: Set(name2.to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains(search_name)) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 2); + + Ok(()) + }) +} From 01a5c1c6dd53bc366fefb1d004cbce30214162f1 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 11:42:27 +0800 Subject: [PATCH 15/65] Fix build errors --- src/database/db_connection.rs | 6 +++--- src/executor/insert.rs | 2 +- src/executor/paginator.rs | 2 +- src/executor/select.rs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 569f3896..7f43ae84 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,5 +1,5 @@ -use std::{future::Future, pin::Pin, sync::Arc}; -use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError}; +use std::{future::Future, pin::Pin}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError}; use futures::Stream; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; @@ -12,7 +12,7 @@ pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-sqlite")] Sqlite(PoolConnection), #[cfg(feature = "mock")] - Mock(Arc), + Mock(std::sync::Arc), } #[async_trait::async_trait] diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 7683f9bb..10b5d8d3 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,4 +1,4 @@ -use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; +use crate::{ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; use std::marker::PhantomData; diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index a4e9bc69..0d07591f 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{ConnectionTrait, SelectorTrait, error::*}; +use crate::{ConnectionTrait, DbBackend, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; diff --git a/src/executor/select.rs b/src/executor/select.rs index cff4a348..dc8b7afb 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,6 +1,6 @@ #[cfg(feature = "sqlx-dep")] use std::pin::Pin; -use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*}; +use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, TryGetableMany, error::*}; #[cfg(feature = "sqlx-dep")] use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; From e7b822c65d9ac8041e96cb954e92aeaa3156b2ff Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 11:44:02 +0800 Subject: [PATCH 16/65] cargo fmt --- examples/rocket_example/src/pool.rs | 4 +- examples/rocket_example/src/setup.rs | 2 +- src/database/connection.rs | 35 +++-- src/database/db_connection.rs | 18 ++- src/database/db_transaction.rs | 224 +++++++++++++++++---------- src/database/mock.rs | 3 +- src/database/mod.rs | 12 +- src/database/stream/query.rs | 76 +++++---- src/database/stream/transaction.rs | 91 ++++++----- src/docs.rs | 2 +- src/driver/mock.rs | 23 ++- src/driver/sqlx_mysql.rs | 19 ++- src/driver/sqlx_postgres.rs | 19 ++- src/driver/sqlx_sqlite.rs | 19 ++- src/executor/delete.rs | 44 +++--- src/executor/insert.rs | 24 ++- src/executor/paginator.rs | 4 +- src/executor/select.rs | 110 ++++++++----- src/executor/update.rs | 39 ++--- src/lib.rs | 2 +- src/query/mod.rs | 2 +- tests/common/setup/mod.rs | 2 +- tests/common/setup/schema.rs | 4 +- tests/query_tests.rs | 2 +- tests/stream_tests.rs | 4 +- tests/transaction_tests.rs | 73 +++++---- 26 files changed, 517 insertions(+), 340 deletions(-) diff --git a/examples/rocket_example/src/pool.rs b/examples/rocket_example/src/pool.rs index 7c8e37cd..c4140c1f 100644 --- a/examples/rocket_example/src/pool.rs +++ b/examples/rocket_example/src/pool.rs @@ -16,9 +16,7 @@ impl rocket_db_pools::Pool for RocketDbPool { let config = figment.extract::().unwrap(); let conn = sea_orm::Database::connect(&config.url).await.unwrap(); - Ok(RocketDbPool { - conn, - }) + Ok(RocketDbPool { conn }) } async fn get(&self) -> Result { diff --git a/examples/rocket_example/src/setup.rs b/examples/rocket_example/src/setup.rs index 91bbb7b1..f5b5a99e 100644 --- a/examples/rocket_example/src/setup.rs +++ b/examples/rocket_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{query::*, error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, query::*, sea_query, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/src/database/connection.rs b/src/database/connection.rs index 630b8b96..d2283caf 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,6 +1,9 @@ -use std::{future::Future, pin::Pin, sync::Arc}; -use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; +use crate::{ + error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, + StatementBuilder, TransactionError, +}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use std::{future::Future, pin::Pin, sync::Arc}; #[cfg_attr(not(feature = "mock"), derive(Clone))] pub enum DatabaseConnection { @@ -112,7 +115,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { } } - fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { Box::pin(async move { Ok(match self { #[cfg(feature = "sqlx-mysql")] @@ -122,7 +128,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)), + DatabaseConnection::MockDatabaseConnection(conn) => { + crate::QueryStream::from((Arc::clone(conn), stmt)) + } DatabaseConnection::Disconnected => panic!("Disconnected"), }) }) @@ -137,7 +145,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await, + DatabaseConnection::MockDatabaseConnection(conn) => { + DatabaseTransaction::new_mock(Arc::clone(conn)).await + } DatabaseConnection::Disconnected => panic!("Disconnected"), } } @@ -146,7 +156,10 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, T: Send, E: std::error::Error + Send, { @@ -154,14 +167,18 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.transaction(_callback).await, + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + conn.transaction(_callback).await + } #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { - let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?; + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await - }, + } DatabaseConnection::Disconnected => panic!("Disconnected"), } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 7f43ae84..6040e452 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,8 +1,10 @@ -use std::{future::Future, pin::Pin}; -use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError}; +use crate::{ + DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, +}; use futures::Stream; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; +use std::{future::Future, pin::Pin}; pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-mysql")] @@ -17,7 +19,7 @@ pub(crate) enum InnerConnection { #[async_trait::async_trait] pub trait ConnectionTrait<'a>: Sync { - type Stream: Stream>; + type Stream: Stream>; fn get_database_backend(&self) -> DbBackend; @@ -27,7 +29,10 @@ pub trait ConnectionTrait<'a>: Sync { async fn query_all(&self, stmt: Statement) -> Result, DbErr>; - fn stream(&'a self, stmt: Statement) -> Pin> + 'a>>; + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>>; async fn begin(&self) -> Result; @@ -35,7 +40,10 @@ pub trait ConnectionTrait<'a>: Sync { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, T: Send, E: std::error::Error + Send; diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index b403971e..ae954097 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -1,10 +1,13 @@ -use std::{sync::Arc, future::Future, pin::Pin}; -use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print}; -use futures::lock::Mutex; +use crate::{ + debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, + Statement, TransactionStream, +}; #[cfg(feature = "sqlx-dep")] use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +use futures::lock::Mutex; #[cfg(feature = "sqlx-dep")] use sqlx::{pool::PoolConnection, TransactionManager}; +use std::{future::Future, pin::Pin, sync::Arc}; // a Transaction is just a sugar for a connection where START TRANSACTION has been executed pub struct DatabaseTransaction { @@ -21,27 +24,50 @@ impl std::fmt::Debug for DatabaseTransaction { impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] - pub(crate) async fn new_mysql(inner: PoolConnection) -> Result { - Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await + pub(crate) async fn new_mysql( + inner: PoolConnection, + ) -> Result { + Self::build( + Arc::new(Mutex::new(InnerConnection::MySql(inner))), + DbBackend::MySql, + ) + .await } #[cfg(feature = "sqlx-postgres")] - pub(crate) async fn new_postgres(inner: PoolConnection) -> Result { - Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await + pub(crate) async fn new_postgres( + inner: PoolConnection, + ) -> Result { + Self::build( + Arc::new(Mutex::new(InnerConnection::Postgres(inner))), + DbBackend::Postgres, + ) + .await } #[cfg(feature = "sqlx-sqlite")] - pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result { - Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await + pub(crate) async fn new_sqlite( + inner: PoolConnection, + ) -> Result { + Self::build( + Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), + DbBackend::Sqlite, + ) + .await } #[cfg(feature = "mock")] - pub(crate) async fn new_mock(inner: Arc) -> Result { + pub(crate) async fn new_mock( + inner: Arc, + ) -> Result { let backend = inner.get_database_backend(); Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await } - async fn build(conn: Arc>, backend: DbBackend) -> Result { + async fn build( + conn: Arc>, + backend: DbBackend, + ) -> Result { let res = DatabaseTransaction { conn, backend, @@ -50,35 +76,49 @@ impl DatabaseTransaction { match *res.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(ref mut c) => { - ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } // should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, + InnerConnection::Mock(_) => {} } Ok(res) } pub(crate) async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, T: Send, E: std::error::Error + Send, { - let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); + let res = callback(&self) + .await + .map_err(|e| TransactionError::Transaction(e)); if res.is_ok() { - self.commit().await.map_err(|e| TransactionError::Connection(e))?; - } - else { - self.rollback().await.map_err(|e| TransactionError::Connection(e))?; + self.commit() + .await + .map_err(|e| TransactionError::Connection(e))?; + } else { + self.rollback() + .await + .map_err(|e| TransactionError::Connection(e))?; } res } @@ -88,19 +128,25 @@ impl DatabaseTransaction { match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, + InnerConnection::Mock(_) => {} } Ok(()) } @@ -110,19 +156,25 @@ impl DatabaseTransaction { match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, + InnerConnection::Mock(_) => {} } Ok(()) } @@ -135,21 +187,20 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { ::TransactionManager::start_rollback(c); - }, + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(c) => { ::TransactionManager::start_rollback(c); - }, + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(c) => { ::TransactionManager::start_rollback(c); - }, + } //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, + InnerConnection::Mock(_) => {} } - } - else { + } else { //this should never happen panic!("Dropping a locked Transaction"); } @@ -179,21 +230,18 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.execute(conn).await - .map(Into::into) - }, + query.execute(conn).await.map(Into::into) + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.execute(conn).await - .map(Into::into) - }, + query.execute(conn).await.map(Into::into) + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.execute(conn).await - .map(Into::into) - }, + query.execute(conn).await.map(Into::into) + } #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.execute(stmt), }; @@ -208,29 +256,25 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.fetch_one(conn).await - .map(|row| Some(row.into())) - }, + query.fetch_one(conn).await.map(|row| Some(row.into())) + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.fetch_one(conn).await - .map(|row| Some(row.into())) - }, + query.fetch_one(conn).await.map(|row| Some(row.into())) + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(conn) => { - let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.fetch_one(conn).await - .map(|row| Some(row.into())) - }, + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.query_one(stmt), }; #[cfg(feature = "sqlx-dep")] if let Err(sqlx::Error::RowNotFound) = _res { Ok(None) - } - else { + } else { _res.map_err(sqlx_error_to_query_err) } } @@ -242,21 +286,27 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.fetch_all(conn).await + query + .fetch_all(conn) + .await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - }, + } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.fetch_all(conn).await + query + .fetch_all(conn) + .await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - }, + } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.fetch_all(conn).await + query + .fetch_all(conn) + .await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - }, + } #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.query_all(stmt), }; @@ -264,10 +314,13 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { _res.map_err(sqlx_error_to_query_err) } - fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { - Box::pin(async move { - Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) - }) + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { + Box::pin( + async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, + ) } async fn begin(&self) -> Result { @@ -278,24 +331,34 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, T: Send, E: std::error::Error + Send, { - let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; + let transaction = self + .begin() + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await } } #[derive(Debug)] pub enum TransactionError -where E: std::error::Error { +where + E: std::error::Error, +{ Connection(DbErr), Transaction(E), } impl std::fmt::Display for TransactionError -where E: std::error::Error { +where + E: std::error::Error, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), @@ -304,5 +367,4 @@ where E: std::error::Error { } } -impl std::error::Error for TransactionError -where E: std::error::Error {} +impl std::error::Error for TransactionError where E: std::error::Error {} diff --git a/src/database/mock.rs b/src/database/mock.rs index d9e1f9d0..f42add7a 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -100,8 +100,7 @@ impl MockRow { where T: ValueType, { - T::try_from(self.values.get(col).unwrap().clone()) - .map_err(|e| DbErr::Query(e.to_string())) + T::try_from(self.values.get(col).unwrap().clone()).map_err(|e| DbErr::Query(e.to_string())) } pub fn into_column_value_tuples(self) -> impl Iterator { diff --git a/src/database/mod.rs b/src/database/mod.rs index ce4127e3..369ed539 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,20 +1,20 @@ mod connection; +mod db_connection; +mod db_transaction; #[cfg(feature = "mock")] mod mock; mod statement; -mod transaction; -mod db_connection; -mod db_transaction; mod stream; +mod transaction; pub use connection::*; +pub use db_connection::*; +pub use db_transaction::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; -pub use transaction::*; -pub use db_connection::*; -pub use db_transaction::*; pub use stream::*; +pub use transaction::*; use crate::DbErr; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 553d9f7b..73668da0 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -1,4 +1,4 @@ -use std::{pin::Pin, task::Poll, sync::Arc}; +use std::{pin::Pin, sync::Arc, task::Poll}; use futures::Stream; #[cfg(feature = "sqlx-dep")] @@ -57,52 +57,50 @@ impl QueryStream { QueryStreamBuilder { stmt, conn, - stream_builder: |conn, stmt| { - match conn { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { - let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) - }, - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { - let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) - }, - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) - }, - #[cfg(feature = "mock")] - InnerConnection::Mock(c) => { - c.fetch(stmt) - }, + stream_builder: |conn, stmt| match conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => c.fetch(stmt), }, - }.build() + } + .build() } } impl Stream for QueryStream { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { let this = self.get_mut(); - this.with_stream_mut(|stream| { - stream.as_mut().poll_next(cx) - }) + this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) } } diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index d945f409..651f3d11 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -27,56 +27,65 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> { } impl<'a> TransactionStream<'a> { - pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> { + pub(crate) async fn build( + conn: MutexGuard<'a, InnerConnection>, + stmt: Statement, + ) -> TransactionStream<'a> { TransactionStreamAsyncBuilder { stmt, conn, - stream_builder: |conn, stmt| Box::pin(async move { - match conn.deref_mut() { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { - let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) as Pin>>> - }, - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { - let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) as Pin>>> - }, - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err) - ) as Pin>>> - }, - #[cfg(feature = "mock")] - InnerConnection::Mock(c) => { - c.fetch(stmt) - }, - } - }), - }.build().await + stream_builder: |conn, stmt| { + Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => c.fetch(stmt), + } + }) + }, + } + .build() + .await } } impl<'a> Stream for TransactionStream<'a> { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { let this = self.get_mut(); - this.with_stream_mut(|stream| { - stream.as_mut().poll_next(cx) - }) + this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) } } diff --git a/src/docs.rs b/src/docs.rs index bab054ef..4d1226c3 100644 --- a/src/docs.rs +++ b/src/docs.rs @@ -163,4 +163,4 @@ //! }, //! ) //! } -//! ``` \ No newline at end of file +//! ``` diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 388ac911..96317c6d 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,11 +2,15 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::{fmt::Debug, pin::Pin, sync::{Arc, - atomic::{AtomicUsize, Ordering}, - Mutex, -}}; use futures::Stream; +use std::{ + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, +}; #[derive(Debug)] pub struct MockDatabaseConnector; @@ -49,9 +53,9 @@ impl MockDatabaseConnector { pub async fn connect(string: &str) -> Result { macro_rules! connect_mock_db { ( $syntax: expr ) => { - Ok(DatabaseConnection::MockDatabaseConnection( - Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))), - )) + Ok(DatabaseConnection::MockDatabaseConnection(Arc::new( + MockDatabaseConnection::new(MockDatabase::new($syntax)), + ))) }; } @@ -105,7 +109,10 @@ impl MockDatabaseConnection { self.mocker.lock().unwrap().query(counter, statement) } - pub fn fetch(&self, statement: &Statement) -> Pin>>> { + pub fn fetch( + &self, + statement: &Statement, + ) -> Pin>>> { match self.query_all(statement.clone()) { Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 75e6e5ff..6b6f9507 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,11 +1,17 @@ use std::{future::Future, pin::Pin}; -use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; +use sqlx::{ + mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, + MySql, MySqlPool, +}; sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -115,12 +121,17 @@ impl SqlxMySqlPoolConnection { pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, T: Send, E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; + let transaction = DatabaseTransaction::new_mysql(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index c9949375..13cb51cd 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,11 +1,17 @@ use std::{future::Future, pin::Pin}; -use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; +use sqlx::{ + postgres::{PgArguments, PgQueryResult, PgRow}, + PgPool, Postgres, +}; sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -115,12 +121,17 @@ impl SqlxPostgresPoolConnection { pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, T: Send, E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; + let transaction = DatabaseTransaction::new_postgres(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index bf06a265..0f7548fc 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,11 +1,17 @@ use std::{future::Future, pin::Pin}; -use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; +use sqlx::{ + sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, + Sqlite, SqlitePool, +}; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -115,12 +121,17 @@ impl SqlxSqlitePoolConnection { pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, T: Send, E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; + let transaction = DatabaseTransaction::new_sqlite(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 85b37cb0..2fa5f64a 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -1,4 +1,6 @@ -use crate::{ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, error::*}; +use crate::{ + error::*, ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, +}; use sea_query::DeleteStatement; use std::future::Future; @@ -16,11 +18,10 @@ impl<'a, A: 'a> DeleteOne where A: ActiveModelTrait, { - pub fn exec( - self, - db: &'a C, - ) -> impl Future> + 'a - where C: ConnectionTrait<'a> { + pub fn exec(self, db: &'a C) -> impl Future> + 'a + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -30,11 +31,10 @@ impl<'a, E> DeleteMany where E: EntityTrait, { - pub fn exec( - self, - db: &'a C, - ) -> impl Future> + 'a - where C: ConnectionTrait<'a> { + pub fn exec(self, db: &'a C) -> impl Future> + 'a + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,27 +45,27 @@ impl Deleter { Self { query } } - pub fn exec<'a, C>( - self, - db: &'a C, - ) -> impl Future> + '_ - where C: ConnectionTrait<'a> { + pub fn exec<'a, C>(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only<'a, C>( - query: DeleteStatement, - db: &'a C, -) -> Result -where C: ConnectionTrait<'a> { +async fn exec_delete_only<'a, C>(query: DeleteStatement, db: &'a C) -> Result +where + C: ConnectionTrait<'a>, +{ Deleter::new(query).exec(db).await } // Only Statement impl Send async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result -where C: ConnectionTrait<'a> { +where + C: ConnectionTrait<'a>, +{ let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 10b5d8d3..7d2e3b11 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,4 +1,7 @@ -use crate::{ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; +use crate::{ + error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, + Statement, TryFromU64, +}; use sea_query::InsertStatement; use std::marker::PhantomData; @@ -24,10 +27,7 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub async fn exec<'a, C>( - self, - db: &'a C, - ) -> Result, DbErr> + pub async fn exec<'a, C>(self, db: &'a C) -> Result, DbErr> where C: ConnectionTrait<'a>, A: 'a, @@ -61,10 +61,7 @@ where } } - pub async fn exec<'a, C>( - self, - db: &'a C, - ) -> Result, DbErr> + pub async fn exec<'a, C>(self, db: &'a C) -> Result, DbErr> where C: ConnectionTrait<'a>, A: 'a, @@ -75,10 +72,7 @@ where } // Only Statement impl Send -async fn exec_insert<'a, A, C>( - statement: Statement, - db: &C, -) -> Result, DbErr> +async fn exec_insert<'a, A, C>(statement: Statement, db: &C) -> Result, DbErr> where C: ConnectionTrait<'a>, A: ActiveModelTrait, @@ -93,13 +87,13 @@ where .collect::>(); let res = db.query_one(statement).await?.unwrap(); res.try_get_many("", cols.as_ref()).unwrap_or_default() - }, + } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); ValueTypeOf::::try_from_u64(last_insert_id) .ok() .unwrap_or_default() - }, + } }; Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 0d07591f..28f8574b 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{ConnectionTrait, DbBackend, SelectorTrait, error::*}; +use crate::{error::*, ConnectionTrait, DbBackend, SelectorTrait}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -157,7 +157,7 @@ where #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; - use crate::{ConnectionTrait, tests_cfg::*}; + use crate::{tests_cfg::*, ConnectionTrait}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/select.rs b/src/executor/select.rs index dc8b7afb..f4ff69d4 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,10 +1,14 @@ -#[cfg(feature = "sqlx-dep")] -use std::pin::Pin; -use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, TryGetableMany, error::*}; +use crate::{ + error::*, ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, + ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, + SelectTwoMany, Statement, TryGetableMany, +}; #[cfg(feature = "sqlx-dep")] use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; +#[cfg(feature = "sqlx-dep")] +use std::pin::Pin; #[derive(Clone, Debug)] pub struct Selector @@ -236,17 +240,24 @@ where } pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.into_model().all(db).await } #[cfg(feature = "sqlx-dep")] - pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> + pub async fn stream<'a: 'b, 'b, C>( + self, + db: &'a C, + ) -> Result> + 'b, DbErr> where C: ConnectionTrait<'a>, { @@ -258,12 +269,16 @@ where db: &'a C, page_size: usize, ) -> Paginator<'a, C, SelectModel> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.into_model().paginate(db, page_size) } pub async fn count<'a, C>(self, db: &'a C) -> Result - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.paginate(db, 1).num_items().await } } @@ -292,24 +307,25 @@ where } } - pub async fn one<'a, C>( - self, - db: &C, - ) -> Result)>, DbErr> - where C: ConnectionTrait<'a> { + pub async fn one<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } - pub async fn all<'a, C>( - self, - db: &C, - ) -> Result)>, DbErr> - where C: ConnectionTrait<'a> { + pub async fn all<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().all(db).await } #[cfg(feature = "sqlx-dep")] - pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + pub async fn stream<'a: 'b, 'b, C>( + self, + db: &'a C, + ) -> Result), DbErr>> + 'b, DbErr> where C: ConnectionTrait<'a>, { @@ -321,12 +337,16 @@ where db: &'a C, page_size: usize, ) -> Paginator<'a, C, SelectTwoModel> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.into_model().paginate(db, page_size) } pub async fn count<'a, C>(self, db: &'a C) -> Result - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { self.paginate(db, 1).num_items().await } } @@ -355,27 +375,28 @@ where } } - pub async fn one<'a, C>( - self, - db: &C, - ) -> Result)>, DbErr> - where C: ConnectionTrait<'a> { + pub async fn one<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } #[cfg(feature = "sqlx-dep")] - pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + pub async fn stream<'a: 'b, 'b, C>( + self, + db: &'a C, + ) -> Result), DbErr>> + 'b, DbErr> where C: ConnectionTrait<'a>, { self.into_model().stream(db).await } - pub async fn all<'a, C>( - self, - db: &C, - ) -> Result)>, DbErr> - where C: ConnectionTrait<'a> { + pub async fn all<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -411,7 +432,9 @@ where } pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -422,7 +445,9 @@ where } pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -433,7 +458,10 @@ where } #[cfg(feature = "sqlx-dep")] - pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> + pub async fn stream<'a: 'b, 'b, C>( + self, + db: &'a C, + ) -> Result> + 'b>>, DbErr> where C: ConnectionTrait<'a>, S: 'b, @@ -446,7 +474,9 @@ where } pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { Paginator { query: self.query, page: 0, @@ -657,7 +687,9 @@ where /// ); /// ``` pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -697,7 +729,9 @@ where /// ); /// ``` pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait<'a> { + where + C: ConnectionTrait<'a>, + { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index 06cd514e..9e36de57 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,4 +1,6 @@ -use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, error::*}; +use crate::{ + error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, +}; use sea_query::UpdateStatement; use std::future::Future; @@ -17,7 +19,9 @@ where A: ActiveModelTrait, { pub async fn exec<'b, C>(self, db: &'b C) -> Result - where C: ConnectionTrait<'b> { + where + C: ConnectionTrait<'b>, + { // so that self is dropped before entering await exec_update_and_return_original(self.query, self.model, db).await } @@ -27,11 +31,10 @@ impl<'a, E> UpdateMany where E: EntityTrait, { - pub fn exec( - self, - db: &'a C, - ) -> impl Future> + 'a - where C: ConnectionTrait<'a> { + pub fn exec(self, db: &'a C) -> impl Future> + 'a + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -42,21 +45,19 @@ impl Updater { Self { query } } - pub async fn exec<'a, C>( - self, - db: &'a C, - ) -> Result - where C: ConnectionTrait<'a> { + pub async fn exec<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); exec_update(builder.build(&self.query), db).await } } -async fn exec_update_only<'a, C>( - query: UpdateStatement, - db: &'a C, -) -> Result -where C: ConnectionTrait<'a> { +async fn exec_update_only<'a, C>(query: UpdateStatement, db: &'a C) -> Result +where + C: ConnectionTrait<'a>, +{ Updater::new(query).exec(db).await } @@ -75,7 +76,9 @@ where // Only Statement impl Send async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result -where C: ConnectionTrait<'a> { +where + C: ConnectionTrait<'a>, +{ let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/src/lib.rs b/src/lib.rs index 910044a5..6ddc442c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -265,6 +265,7 @@ )] mod database; +mod docs; mod driver; pub mod entity; pub mod error; @@ -273,7 +274,6 @@ pub mod query; pub mod schema; #[doc(hidden)] pub mod tests_cfg; -mod docs; mod util; pub use database::*; diff --git a/src/query/mod.rs b/src/query/mod.rs index c7f60049..9a2d0aba 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -20,4 +20,4 @@ pub use select::*; pub use traits::*; pub use update::*; -pub use crate::{InsertResult, Statement, UpdateResult, Value, Values, ConnectionTrait}; +pub use crate::{ConnectionTrait, InsertResult, Statement, UpdateResult, Value, Values}; diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index 9deb903f..b9195edb 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,4 +1,4 @@ -use sea_orm::{Database, DatabaseBackend, DatabaseConnection, ConnectionTrait, Statement}; +use sea_orm::{ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, Statement}; pub mod schema; pub use schema::*; diff --git a/tests/common/setup/schema.rs b/tests/common/setup/schema.rs index ac34304a..77d85308 100644 --- a/tests/common/setup/schema.rs +++ b/tests/common/setup/schema.rs @@ -1,6 +1,8 @@ pub use super::super::bakery_chain::*; use pretty_assertions::assert_eq; -use sea_orm::{error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; +use sea_orm::{ + error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema, +}; use sea_query::{ Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement, }; diff --git a/tests/query_tests.rs b/tests/query_tests.rs index e688b14f..4b117dc6 100644 --- a/tests/query_tests.rs +++ b/tests/query_tests.rs @@ -2,7 +2,7 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; pub use sea_orm::entity::*; -pub use sea_orm::{QueryFilter, ConnectionTrait}; +pub use sea_orm::{ConnectionTrait, QueryFilter}; // Run the test locally: // DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index 969b93e1..d30063e5 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -1,9 +1,9 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -pub use sea_orm::entity::*; -pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr}; use futures::StreamExt; +pub use sea_orm::entity::*; +pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; #[sea_orm_macros::test] #[cfg(any( diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index 539eaefc..61d194c4 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -1,9 +1,9 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -use sea_orm::{DatabaseTransaction, DbErr}; pub use sea_orm::entity::*; -pub use sea_orm::{QueryFilter, ConnectionTrait}; +pub use sea_orm::{ConnectionTrait, QueryFilter}; +use sea_orm::{DatabaseTransaction, DbErr}; #[sea_orm_macros::test] #[cfg(any( @@ -14,32 +14,37 @@ pub use sea_orm::{QueryFilter, ConnectionTrait}; pub async fn transaction() { let ctx = TestContext::new("transaction_test").await; - ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("SeaSide Bakery".to_owned()), - profit_margin: Set(10.4), - ..Default::default() - } - .save(txn) - .await?; + ctx.db + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; - let _ = bakery::ActiveModel { - name: Set("Top Bakery".to_owned()), - profit_margin: Set(15.0), - ..Default::default() - } - .save(txn) - .await?; + let _ = bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 2); + assert_eq!(bakeries.len(), 2); - Ok(()) - })).await.unwrap(); + Ok(()) + }) + }) + .await + .unwrap(); ctx.delete().await; } @@ -55,28 +60,36 @@ pub async fn transaction_with_reference() { let name1 = "SeaSide Bakery"; let name2 = "Top Bakery"; let search_name = "Bakery"; - ctx.db.transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)).await.unwrap(); + ctx.db + .transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)) + .await + .unwrap(); ctx.delete().await; } -fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { +fn _transaction_with_reference<'a>( + txn: &'a DatabaseTransaction, + name1: &'a str, + name2: &'a str, + search_name: &'a str, +) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let _ = bakery::ActiveModel { name: Set(name1.to_owned()), profit_margin: Set(10.4), ..Default::default() } - .save(txn) - .await?; + .save(txn) + .await?; let _ = bakery::ActiveModel { name: Set(name2.to_owned()), profit_margin: Set(15.0), ..Default::default() } - .save(txn) - .await?; + .save(txn) + .await?; let bakeries = Bakery::find() .filter(bakery::Column::Name.contains(search_name)) From ab8fec21b315346bbeb1d9fc619d85192f399770 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 12:09:50 +0800 Subject: [PATCH 17/65] Fix build errors --- examples/actix4_example/src/setup.rs | 2 +- examples/actix_example/src/setup.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/actix4_example/src/setup.rs b/examples/actix4_example/src/setup.rs index 034e8b53..04677af4 100644 --- a/examples/actix4_example/src/setup.rs +++ b/examples/actix4_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/examples/actix_example/src/setup.rs b/examples/actix_example/src/setup.rs index 034e8b53..04677af4 100644 --- a/examples/actix_example/src/setup.rs +++ b/examples/actix_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); From af93ea44ad970cd04476730b80cb613f7471ca8c Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 12:10:06 +0800 Subject: [PATCH 18/65] Fix clippy warnings --- src/executor/query.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/executor/query.rs b/src/executor/query.rs index 0248fa5c..e6c2d124 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -126,12 +126,12 @@ macro_rules! try_getable_unsigned { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -142,13 +142,13 @@ macro_rules! try_getable_unsigned { #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), @@ -162,12 +162,12 @@ macro_rules! try_getable_mysql { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -181,7 +181,7 @@ macro_rules! try_getable_mysql { } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), @@ -195,7 +195,7 @@ macro_rules! try_getable_postgres { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(_) => { @@ -204,7 +204,7 @@ macro_rules! try_getable_postgres { #[cfg(feature = "sqlx-postgres")] QueryResultRow::SqlxPostgres(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -214,7 +214,7 @@ macro_rules! try_getable_postgres { } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), From 35b8eb9a4df9c284f03625c707c677f978b881fb Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 4 Oct 2021 12:17:31 +0800 Subject: [PATCH 19/65] `ActiveValue::take()` & `ActiveValue::into_value()` without `unwrap()` --- src/entity/active_model.rs | 12 ++++++------ src/query/insert.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index cfcb0bbd..6f4b86f6 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -209,23 +209,23 @@ where matches!(self.state, ActiveValueState::Unset) } - pub fn take(&mut self) -> V { + pub fn take(&mut self) -> Option { self.state = ActiveValueState::Unset; - self.value.take().unwrap() + self.value.take() } pub fn unwrap(self) -> V { self.value.unwrap() } - pub fn into_value(self) -> Value { - self.value.unwrap().into() + pub fn into_value(self) -> Option { + self.value.map(Into::into) } pub fn into_wrapped_value(self) -> ActiveValue { match self.state { - ActiveValueState::Set => ActiveValue::set(self.into_value()), - ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value()), + ActiveValueState::Set => ActiveValue::set(self.into_value().unwrap()), + ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value().unwrap()), ActiveValueState::Unset => ActiveValue::unset(), } } diff --git a/src/query/insert.rs b/src/query/insert.rs index a65071e1..f2d60dc8 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -120,7 +120,7 @@ where } if av_has_val { columns.push(col); - values.push(av.into_value()); + values.push(av.into_value().unwrap()); } } self.query.columns(columns); From af6665fe4f9e8fa6d7a9d4ad85907d59f27ddb56 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 12:12:58 +0800 Subject: [PATCH 20/65] Fix clippy warnings --- src/database/connection.rs | 5 ++++- src/database/stream/query.rs | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/database/connection.rs b/src/database/connection.rs index d2283caf..2c65fa28 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -3,7 +3,10 @@ use crate::{ StatementBuilder, TransactionError, }; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; -use std::{future::Future, pin::Pin, sync::Arc}; +use std::{future::Future, pin::Pin}; + +#[cfg(feature = "mock")] +use std::sync::Arc; #[cfg_attr(not(feature = "mock"), derive(Clone))] pub enum DatabaseConnection { diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 73668da0..8383659a 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -1,4 +1,7 @@ -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{pin::Pin, task::Poll}; + +#[cfg(feature = "mock")] +use std::sync::Arc; use futures::Stream; #[cfg(feature = "sqlx-dep")] From b74b1f343f6e8090029db381018aba30f71e764f Mon Sep 17 00:00:00 2001 From: "baoyachi. Aka Rust Hairy crabs" Date: Mon, 4 Oct 2021 12:57:34 +0800 Subject: [PATCH 21/65] Add debug_query and debug_query_stmt macro (#189) --- src/query/mod.rs | 2 + src/query/util.rs | 112 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 src/query/util.rs diff --git a/src/query/mod.rs b/src/query/mod.rs index 54cc12dd..5d2be142 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -8,6 +8,7 @@ mod json; mod select; mod traits; mod update; +mod util; pub use combine::{SelectA, SelectB}; pub use delete::*; @@ -19,5 +20,6 @@ pub use json::*; pub use select::*; pub use traits::*; pub use update::*; +pub use util::*; pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; diff --git a/src/query/util.rs b/src/query/util.rs new file mode 100644 index 00000000..7e771f7c --- /dev/null +++ b/src/query/util.rs @@ -0,0 +1,112 @@ +use crate::{DatabaseConnection, DbBackend, QueryTrait, Statement}; + +#[derive(Debug)] +pub struct DebugQuery<'a, Q, T> { + pub query: &'a Q, + pub value: T, +} + +macro_rules! debug_query_build { + ($impl_obj:ty, $db_expr:expr) => { + impl<'a, Q> DebugQuery<'a, Q, $impl_obj> + where + Q: QueryTrait, + { + pub fn build(&self) -> Statement { + let func = $db_expr; + let db_backend = func(self); + self.query.build(db_backend) + } + } + }; +} + +debug_query_build!(DbBackend, |x: &DebugQuery<_, DbBackend>| x.value); +debug_query_build!(&DbBackend, |x: &DebugQuery<_, &DbBackend>| *x.value); +debug_query_build!( + DatabaseConnection, + |x: &DebugQuery<_, DatabaseConnection>| x.value.get_database_backend() +); +debug_query_build!( + &DatabaseConnection, + |x: &DebugQuery<_, &DatabaseConnection>| x.value.get_database_backend() +); + +/// Helper to get a `Statement` from an object that impl `QueryTrait`. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "mock")] +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; +/// # +/// # let conn = MockDatabase::new(DbBackend::Postgres) +/// # .into_connection(); +/// # +/// use sea_orm::{entity::*, query::*, tests_cfg::cake, debug_query_stmt}; +/// +/// let c = cake::Entity::insert( +/// cake::ActiveModel { +/// id: ActiveValue::set(1), +/// name: ActiveValue::set("Apple Pie".to_owned()), +/// }); +/// +/// let raw_sql = debug_query_stmt!(&c, &conn).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, conn).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, DbBackend::MySql).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, &DbBackend::MySql).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// ``` +#[macro_export] +macro_rules! debug_query_stmt { + ($query:expr,$value:expr) => { + $crate::DebugQuery { + query: $query, + value: $value, + } + .build(); + }; +} + +/// Helper to get a raw SQL string from an object that impl `QueryTrait`. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "mock")] +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; +/// # +/// # let conn = MockDatabase::new(DbBackend::Postgres) +/// # .into_connection(); +/// # +/// use sea_orm::{entity::*, query::*, tests_cfg::cake,debug_query}; +/// +/// let c = cake::Entity::insert( +/// cake::ActiveModel { +/// id: ActiveValue::set(1), +/// name: ActiveValue::set("Apple Pie".to_owned()), +/// }); +/// +/// let raw_sql = debug_query!(&c, &conn); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query!(&c, conn); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query!(&c, DbBackend::Sqlite); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// ``` +#[macro_export] +macro_rules! debug_query { + ($query:expr,$value:expr) => { + $crate::debug_query_stmt!($query, $value).to_string(); + }; +} From 4fd5d56dbf1bc95b6219219231e27098bfed9d9e Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 13:13:36 +0800 Subject: [PATCH 22/65] cargo +nightly fmt --- src/docs.rs | 2 +- src/entity/model.rs | 10 ++++------ src/executor/query.rs | 2 +- src/executor/select.rs | 9 ++++----- src/lib.rs | 2 +- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/docs.rs b/src/docs.rs index bab054ef..4d1226c3 100644 --- a/src/docs.rs +++ b/src/docs.rs @@ -163,4 +163,4 @@ //! }, //! ) //! } -//! ``` \ No newline at end of file +//! ``` diff --git a/src/entity/model.rs b/src/entity/model.rs index 318b8e70..cd96b22e 100644 --- a/src/entity/model.rs +++ b/src/entity/model.rs @@ -69,12 +69,10 @@ pub trait FromQueryResult: Sized { /// /// assert_eq!( /// res, - /// vec![ - /// SelectResult { - /// name: "Chocolate Forest".to_owned(), - /// num_of_cakes: 2, - /// }, - /// ] + /// vec![SelectResult { + /// name: "Chocolate Forest".to_owned(), + /// num_of_cakes: 2, + /// },] /// ); /// # /// # Ok(()) diff --git a/src/executor/query.rs b/src/executor/query.rs index e6c2d124..a164e911 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -326,7 +326,7 @@ pub trait TryGetableMany: Sized { /// # ]]) /// # .into_connection(); /// # - /// use sea_orm::{entity::*, query::*, tests_cfg::cake, EnumIter, DeriveIden, TryGetableMany}; + /// use sea_orm::{entity::*, query::*, tests_cfg::cake, DeriveIden, EnumIter, TryGetableMany}; /// /// #[derive(EnumIter, DeriveIden)] /// enum ResultCol { diff --git a/src/executor/select.rs b/src/executor/select.rs index bb386722..0db698f0 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -207,10 +207,7 @@ where /// .all(&db) /// .await?; /// - /// assert_eq!( - /// res, - /// vec![("Chocolate Forest".to_owned(), 2i64)] - /// ); + /// assert_eq!(res, vec![("Chocolate Forest".to_owned(), 2i64)]); /// # /// # Ok(()) /// # }); @@ -222,7 +219,9 @@ where /// vec![ /// r#"SELECT "cake"."name" AS "cake_name", COUNT("cake"."id") AS "num_of_cakes""#, /// r#"FROM "cake" GROUP BY "cake"."name""#, - /// ].join(" ").as_str(), + /// ] + /// .join(" ") + /// .as_str(), /// vec![] /// )] /// ); diff --git a/src/lib.rs b/src/lib.rs index 910044a5..6ddc442c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -265,6 +265,7 @@ )] mod database; +mod docs; mod driver; pub mod entity; pub mod error; @@ -273,7 +274,6 @@ pub mod query; pub mod schema; #[doc(hidden)] pub mod tests_cfg; -mod docs; mod util; pub use database::*; From 6b98a6f3955155ab062b739ddda86f35c8d3cf83 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 20:40:27 +0800 Subject: [PATCH 23/65] Move code --- src/database/connection.rs | 236 ++------------------ src/database/db_connection.rs | 239 +++++++++++++++++++-- src/database/db_transaction.rs | 370 -------------------------------- src/database/mock.rs | 44 +++- src/database/mod.rs | 2 - src/database/transaction.rs | 380 ++++++++++++++++++++++++++++++--- 6 files changed, 634 insertions(+), 637 deletions(-) delete mode 100644 src/database/db_transaction.rs diff --git a/src/database/connection.rs b/src/database/connection.rs index 2c65fa28..d90c72a9 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,250 +1,40 @@ use crate::{ - error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, - StatementBuilder, TransactionError, + DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, }; -use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use futures::Stream; use std::{future::Future, pin::Pin}; -#[cfg(feature = "mock")] -use std::sync::Arc; - -#[cfg_attr(not(feature = "mock"), derive(Clone))] -pub enum DatabaseConnection { - #[cfg(feature = "sqlx-mysql")] - SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), - #[cfg(feature = "sqlx-postgres")] - SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), - #[cfg(feature = "sqlx-sqlite")] - SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), - #[cfg(feature = "mock")] - MockDatabaseConnection(Arc), - Disconnected, -} - -pub type DbConn = DatabaseConnection; - -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum DatabaseBackend { - MySql, - Postgres, - Sqlite, -} - -pub type DbBackend = DatabaseBackend; - -impl Default for DatabaseConnection { - fn default() -> Self { - Self::Disconnected - } -} - -impl std::fmt::Debug for DatabaseConnection { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{}", - match self { - #[cfg(feature = "sqlx-mysql")] - Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", - #[cfg(feature = "sqlx-postgres")] - Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", - #[cfg(feature = "sqlx-sqlite")] - Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", - #[cfg(feature = "mock")] - Self::MockDatabaseConnection(_) => "MockDatabaseConnection", - Self::Disconnected => "Disconnected", - } - ) - } -} - #[async_trait::async_trait] -impl<'a> ConnectionTrait<'a> for DatabaseConnection { - type Stream = crate::QueryStream; +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; - fn get_database_backend(&self) -> DbBackend { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + fn get_database_backend(&self) -> DbBackend; - async fn execute(&self, stmt: Statement) -> Result { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn execute(&self, stmt: Statement) -> Result; - async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn query_one(&self, stmt: Statement) -> Result, DbErr>; - async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } + async fn query_all(&self, stmt: Statement) -> Result, DbErr>; fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>> { - Box::pin(async move { - Ok(match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - crate::QueryStream::from((Arc::clone(conn), stmt)) - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - }) - }) - } + ) -> Pin> + 'a>>; - async fn begin(&self) -> Result { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - DatabaseTransaction::new_mock(Arc::clone(conn)).await - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + async fn begin(&self) -> Result; /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> + async fn transaction(&self, callback: F) -> Result> where F: for<'c> FnOnce( &'c DatabaseTransaction, ) -> Pin> + Send + 'c>> + Send, T: Send, - E: std::error::Error + Send, - { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => { - conn.transaction(_callback).await - } - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => { - let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) - .await - .map_err(|e| TransactionError::Connection(e))?; - transaction.run(_callback).await - } - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + E: std::error::Error + Send; - #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { - match self { - DatabaseConnection::MockDatabaseConnection(_) => true, - _ => false, - } - } -} - -#[cfg(feature = "mock")] -impl DatabaseConnection { - pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { - match self { - DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, - _ => panic!("not mock connection"), - } - } - - pub fn into_transaction_log(self) -> Vec { - let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); - mocker.drain_transaction_log() - } -} - -impl DbBackend { - pub fn is_prefix_of(self, base_url: &str) -> bool { - match self { - Self::Postgres => { - base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") - } - Self::MySql => base_url.starts_with("mysql://"), - Self::Sqlite => base_url.starts_with("sqlite:"), - } - } - - pub fn build(&self, statement: &S) -> Statement - where - S: StatementBuilder, - { - statement.build(self) - } - - pub fn get_query_builder(&self) -> Box { - match self { - Self::MySql => Box::new(MysqlQueryBuilder), - Self::Postgres => Box::new(PostgresQueryBuilder), - Self::Sqlite => Box::new(SqliteQueryBuilder), - } - } -} - -#[cfg(test)] -mod tests { - use crate::DatabaseConnection; - - #[test] - fn assert_database_connection_traits() { - fn assert_send_sync() {} - - assert_send_sync::(); + false } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 6040e452..60d70ac7 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,10 +1,39 @@ use crate::{ - DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, + error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, + StatementBuilder, TransactionError, }; -use futures::Stream; +use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use std::{future::Future, pin::Pin}; + #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; -use std::{future::Future, pin::Pin}; + +#[cfg(feature = "mock")] +use std::sync::Arc; + +#[cfg_attr(not(feature = "mock"), derive(Clone))] +pub enum DatabaseConnection { + #[cfg(feature = "sqlx-mysql")] + SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), + #[cfg(feature = "sqlx-sqlite")] + SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), + #[cfg(feature = "mock")] + MockDatabaseConnection(Arc), + Disconnected, +} + +pub type DbConn = DatabaseConnection; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum DatabaseBackend { + MySql, + Postgres, + Sqlite, +} + +pub type DbBackend = DatabaseBackend; pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-mysql")] @@ -17,37 +46,219 @@ pub(crate) enum InnerConnection { Mock(std::sync::Arc), } +impl Default for DatabaseConnection { + fn default() -> Self { + Self::Disconnected + } +} + +impl std::fmt::Debug for DatabaseConnection { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "sqlx-mysql")] + Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", + #[cfg(feature = "sqlx-postgres")] + Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", + #[cfg(feature = "sqlx-sqlite")] + Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", + #[cfg(feature = "mock")] + Self::MockDatabaseConnection(_) => "MockDatabaseConnection", + Self::Disconnected => "Disconnected", + } + ) + } +} + #[async_trait::async_trait] -pub trait ConnectionTrait<'a>: Sync { - type Stream: Stream>; +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; - fn get_database_backend(&self) -> DbBackend; + fn get_database_backend(&self) -> DbBackend { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } - async fn execute(&self, stmt: Statement) -> Result; + async fn execute(&self, stmt: Statement) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } - async fn query_one(&self, stmt: Statement) -> Result, DbErr>; + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } - async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + crate::QueryStream::from((Arc::clone(conn), stmt)) + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } - async fn begin(&self) -> Result; + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + DatabaseTransaction::new_mock(Arc::clone(conn)).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, callback: F) -> Result> + async fn transaction(&self, _callback: F) -> Result> where F: for<'c> FnOnce( &'c DatabaseTransaction, ) -> Pin> + Send + 'c>> + Send, T: Send, - E: std::error::Error + Send; + E: std::error::Error + Send, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + conn.transaction(_callback).await + } + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { - false + match self { + DatabaseConnection::MockDatabaseConnection(_) => true, + _ => false, + } + } +} + +#[cfg(feature = "mock")] +impl DatabaseConnection { + pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { + match self { + DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, + _ => panic!("not mock connection"), + } + } + + pub fn into_transaction_log(self) -> Vec { + let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); + mocker.drain_transaction_log() + } +} + +impl DbBackend { + pub fn is_prefix_of(self, base_url: &str) -> bool { + match self { + Self::Postgres => { + base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") + } + Self::MySql => base_url.starts_with("mysql://"), + Self::Sqlite => base_url.starts_with("sqlite:"), + } + } + + pub fn build(&self, statement: &S) -> Statement + where + S: StatementBuilder, + { + statement.build(self) + } + + pub fn get_query_builder(&self) -> Box { + match self { + Self::MySql => Box::new(MysqlQueryBuilder), + Self::Postgres => Box::new(PostgresQueryBuilder), + Self::Sqlite => Box::new(SqliteQueryBuilder), + } + } +} + +#[cfg(test)] +mod tests { + use crate::DatabaseConnection; + + #[test] + fn assert_database_connection_traits() { + fn assert_send_sync() {} + + assert_send_sync::(); } } diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs deleted file mode 100644 index ae954097..00000000 --- a/src/database/db_transaction.rs +++ /dev/null @@ -1,370 +0,0 @@ -use crate::{ - debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, - Statement, TransactionStream, -}; -#[cfg(feature = "sqlx-dep")] -use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; -use futures::lock::Mutex; -#[cfg(feature = "sqlx-dep")] -use sqlx::{pool::PoolConnection, TransactionManager}; -use std::{future::Future, pin::Pin, sync::Arc}; - -// a Transaction is just a sugar for a connection where START TRANSACTION has been executed -pub struct DatabaseTransaction { - conn: Arc>, - backend: DbBackend, - open: bool, -} - -impl std::fmt::Debug for DatabaseTransaction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DatabaseTransaction") - } -} - -impl DatabaseTransaction { - #[cfg(feature = "sqlx-mysql")] - pub(crate) async fn new_mysql( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::MySql(inner))), - DbBackend::MySql, - ) - .await - } - - #[cfg(feature = "sqlx-postgres")] - pub(crate) async fn new_postgres( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::Postgres(inner))), - DbBackend::Postgres, - ) - .await - } - - #[cfg(feature = "sqlx-sqlite")] - pub(crate) async fn new_sqlite( - inner: PoolConnection, - ) -> Result { - Self::build( - Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), - DbBackend::Sqlite, - ) - .await - } - - #[cfg(feature = "mock")] - pub(crate) async fn new_mock( - inner: Arc, - ) -> Result { - let backend = inner.get_database_backend(); - Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await - } - - async fn build( - conn: Arc>, - backend: DbBackend, - ) -> Result { - let res = DatabaseTransaction { - conn, - backend, - open: true, - }; - match *res.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::begin(c) - .await - .map_err(sqlx_error_to_query_err)? - } - // should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(res) - } - - pub(crate) async fn run(self, callback: F) -> Result> - where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, - T: Send, - E: std::error::Error + Send, - { - let res = callback(&self) - .await - .map_err(|e| TransactionError::Transaction(e)); - if res.is_ok() { - self.commit() - .await - .map_err(|e| TransactionError::Connection(e))?; - } else { - self.rollback() - .await - .map_err(|e| TransactionError::Connection(e))?; - } - res - } - - pub async fn commit(mut self) -> Result<(), DbErr> { - self.open = false; - match *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::commit(c) - .await - .map_err(sqlx_error_to_query_err)? - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(()) - } - - pub async fn rollback(mut self) -> Result<(), DbErr> { - self.open = false; - match *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::rollback(c) - .await - .map_err(sqlx_error_to_query_err)? - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - Ok(()) - } - - // the rollback is queued and will be performed on next async operation, like returning the connection to the pool - fn start_rollback(&mut self) { - if self.open { - if let Some(mut conn) = self.conn.try_lock() { - match &mut *conn { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { - ::TransactionManager::start_rollback(c); - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { - ::TransactionManager::start_rollback(c); - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { - ::TransactionManager::start_rollback(c); - } - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} - } - } else { - //this should never happen - panic!("Dropping a locked Transaction"); - } - } - } -} - -impl Drop for DatabaseTransaction { - fn drop(&mut self) { - self.start_rollback(); - } -} - -#[async_trait::async_trait] -impl<'a> ConnectionTrait<'a> for DatabaseTransaction { - type Stream = TransactionStream<'a>; - - fn get_database_backend(&self) -> DbBackend { - // this way we don't need to lock - self.backend - } - - async fn execute(&self, stmt: Statement) -> Result { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.execute(conn).await.map(Into::into) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.execute(stmt), - }; - #[cfg(feature = "sqlx-dep")] - _res.map_err(sqlx_error_to_exec_err) - } - - async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query.fetch_one(conn).await.map(|row| Some(row.into())) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.query_one(stmt), - }; - #[cfg(feature = "sqlx-dep")] - if let Err(sqlx::Error::RowNotFound) = _res { - Ok(None) - } else { - _res.map_err(sqlx_error_to_query_err) - } - } - - async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - debug_print!("{}", stmt); - - let _res = match &mut *self.conn.lock().await { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(conn) => { - let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(conn) => { - let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - query - .fetch_all(conn) - .await - .map(|rows| rows.into_iter().map(|r| r.into()).collect()) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(conn) => return conn.query_all(stmt), - }; - #[cfg(feature = "sqlx-dep")] - _res.map_err(sqlx_error_to_query_err) - } - - fn stream( - &'a self, - stmt: Statement, - ) -> Pin> + 'a>> { - Box::pin( - async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, - ) - } - - async fn begin(&self) -> Result { - DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await - } - - /// Execute the function inside a transaction. - /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> - where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, - T: Send, - E: std::error::Error + Send, - { - let transaction = self - .begin() - .await - .map_err(|e| TransactionError::Connection(e))?; - transaction.run(_callback).await - } -} - -#[derive(Debug)] -pub enum TransactionError -where - E: std::error::Error, -{ - Connection(DbErr), - Transaction(E), -} - -impl std::fmt::Display for TransactionError -where - E: std::error::Error, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), - TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for TransactionError where E: std::error::Error {} diff --git a/src/database/mock.rs b/src/database/mock.rs index f42add7a..f98c4b9d 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -1,9 +1,9 @@ use crate::{ error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden, Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow, - Statement, Transaction, + Statement, }; -use sea_query::{Value, ValueType}; +use sea_query::{Value, ValueType, Values}; use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] @@ -29,6 +29,11 @@ pub trait IntoMockRow { fn into_mock_row(self) -> MockRow; } +#[derive(Debug, Clone, PartialEq)] +pub struct Transaction { + stmts: Vec, +} + impl MockDatabase { pub fn new(db_backend: DbBackend) -> Self { Self { @@ -134,3 +139,38 @@ impl IntoMockRow for BTreeMap<&str, Value> { } } } + +impl Transaction { + pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self + where + I: IntoIterator, + { + Self::one(Statement::from_string_values_tuple( + db_backend, + (sql.to_string(), Values(values.into_iter().collect())), + )) + } + + /// Create a Transaction with one statement + pub fn one(stmt: Statement) -> Self { + Self { stmts: vec![stmt] } + } + + /// Create a Transaction with many statements + pub fn many(stmts: I) -> Self + where + I: IntoIterator, + { + Self { + stmts: stmts.into_iter().collect(), + } + } + + /// Wrap each Statement as a single-statement Transaction + pub fn wrap(stmts: I) -> Vec + where + I: IntoIterator, + { + stmts.into_iter().map(Self::one).collect() + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 369ed539..a1dfea93 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,5 @@ mod connection; mod db_connection; -mod db_transaction; #[cfg(feature = "mock")] mod mock; mod statement; @@ -9,7 +8,6 @@ mod transaction; pub use connection::*; pub use db_connection::*; -pub use db_transaction::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 6bf06491..3757052e 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -1,42 +1,370 @@ -use crate::{DbBackend, Statement}; -use sea_query::{Value, Values}; +use crate::{ + debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, + Statement, TransactionStream, +}; +#[cfg(feature = "sqlx-dep")] +use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +use futures::lock::Mutex; +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, TransactionManager}; +use std::{future::Future, pin::Pin, sync::Arc}; -#[derive(Debug, Clone, PartialEq)] -pub struct Transaction { - stmts: Vec, +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, } -impl Transaction { - pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self - where - I: IntoIterator, - { - Self::one(Statement::from_string_values_tuple( - db_backend, - (sql.to_string(), Values(values.into_iter().collect())), - )) +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { + #[cfg(feature = "sqlx-mysql")] + pub(crate) async fn new_mysql( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::MySql(inner))), + DbBackend::MySql, + ) + .await } - /// Create a Transaction with one statement - pub fn one(stmt: Statement) -> Self { - Self { stmts: vec![stmt] } + #[cfg(feature = "sqlx-postgres")] + pub(crate) async fn new_postgres( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Postgres(inner))), + DbBackend::Postgres, + ) + .await } - /// Create a Transaction with many statements - pub fn many(stmts: I) -> Self + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), + DbBackend::Sqlite, + ) + .await + } + + #[cfg(feature = "mock")] + pub(crate) async fn new_mock( + inner: Arc, + ) -> Result { + let backend = inner.get_database_backend(); + Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + } + + async fn begin( + conn: Arc>, + backend: DbBackend, + ) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + // should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(res) + } + + pub(crate) async fn run(self, callback: F) -> Result> where - I: IntoIterator, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, { - Self { - stmts: stmts.into_iter().collect(), + let res = callback(&self) + .await + .map_err(|e| TransactionError::Transaction(e)); + if res.is_ok() { + self.commit() + .await + .map_err(|e| TransactionError::Connection(e))?; + } else { + self.rollback() + .await + .map_err(|e| TransactionError::Connection(e))?; + } + res + } + + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + } + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {} + } + } else { + //this should never happen + panic!("Dropping a locked Transaction"); + } + } + } +} + +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + + fn get_database_backend(&self) -> DbBackend { + // this way we don't need to lock + self.backend + } + + async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_exec_err) + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + }; + #[cfg(feature = "sqlx-dep")] + if let Err(sqlx::Error::RowNotFound) = _res { + Ok(None) + } else { + _res.map_err(sqlx_error_to_query_err) } } - /// Wrap each Statement as a single-statement Transaction - pub fn wrap(stmts: I) -> Vec + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_query_err) + } + + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { + Box::pin( + async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, + ) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> where - I: IntoIterator, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, + T: Send, + E: std::error::Error + Send, { - stmts.into_iter().map(Self::one).collect() + let transaction = self + .begin() + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await } } + +#[derive(Debug)] +pub enum TransactionError +where + E: std::error::Error, +{ + Connection(DbErr), + Transaction(E), +} + +impl std::fmt::Display for TransactionError +where + E: std::error::Error, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), + TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for TransactionError where E: std::error::Error {} From 632290469b649d40b45760ff00f687a1467c5508 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 21:01:02 +0800 Subject: [PATCH 24/65] Fixup --- src/query/util.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/query/util.rs b/src/query/util.rs index 7e771f7c..545f8376 100644 --- a/src/query/util.rs +++ b/src/query/util.rs @@ -1,4 +1,4 @@ -use crate::{DatabaseConnection, DbBackend, QueryTrait, Statement}; +use crate::{database::*, QueryTrait, Statement}; #[derive(Debug)] pub struct DebugQuery<'a, Q, T> { @@ -38,7 +38,7 @@ debug_query_build!( /// /// ``` /// # #[cfg(feature = "mock")] -/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend}; /// # /// # let conn = MockDatabase::new(DbBackend::Postgres) /// # .into_connection(); @@ -81,7 +81,7 @@ macro_rules! debug_query_stmt { /// /// ``` /// # #[cfg(feature = "mock")] -/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend}; /// # /// # let conn = MockDatabase::new(DbBackend::Postgres) /// # .into_connection(); From df4df87d09643f47e062f54e5ffcfdac7c81d269 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 20:44:29 +0800 Subject: [PATCH 25/65] Rename Transaction -> MockTransaction --- src/database/db_connection.rs | 2 +- src/database/mock.rs | 18 ++++++++--------- src/database/statement.rs | 2 +- src/docs.rs | 10 ++++----- src/driver/mock.rs | 6 +++--- src/entity/base_entity.rs | 38 +++++++++++++++++------------------ src/entity/model.rs | 4 ++-- src/executor/paginator.rs | 12 +++++------ src/executor/query.rs | 6 +++--- src/executor/select.rs | 24 +++++++++++----------- 10 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 60d70ac7..d5612014 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -218,7 +218,7 @@ impl DatabaseConnection { } } - pub fn into_transaction_log(self) -> Vec { + pub fn into_transaction_log(self) -> Vec { let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); mocker.drain_transaction_log() } diff --git a/src/database/mock.rs b/src/database/mock.rs index f98c4b9d..012c941f 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -9,7 +9,7 @@ use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { db_backend: DbBackend, - transaction_log: Vec, + transaction_log: Vec, exec_results: Vec, query_results: Vec>, } @@ -30,7 +30,7 @@ pub trait IntoMockRow { } #[derive(Debug, Clone, PartialEq)] -pub struct Transaction { +pub struct MockTransaction { stmts: Vec, } @@ -67,7 +67,7 @@ impl MockDatabase { impl MockDatabaseTrait for MockDatabase { fn execute(&mut self, counter: usize, statement: Statement) -> Result { - self.transaction_log.push(Transaction::one(statement)); + self.transaction_log.push(MockTransaction::one(statement)); if counter < self.exec_results.len() { Ok(ExecResult { result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), @@ -78,7 +78,7 @@ impl MockDatabaseTrait for MockDatabase { } fn query(&mut self, counter: usize, statement: Statement) -> Result, DbErr> { - self.transaction_log.push(Transaction::one(statement)); + self.transaction_log.push(MockTransaction::one(statement)); if counter < self.query_results.len() { Ok(std::mem::take(&mut self.query_results[counter]) .into_iter() @@ -91,7 +91,7 @@ impl MockDatabaseTrait for MockDatabase { } } - fn drain_transaction_log(&mut self) -> Vec { + fn drain_transaction_log(&mut self) -> Vec { std::mem::take(&mut self.transaction_log) } @@ -140,7 +140,7 @@ impl IntoMockRow for BTreeMap<&str, Value> { } } -impl Transaction { +impl MockTransaction { pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self where I: IntoIterator, @@ -151,12 +151,12 @@ impl Transaction { )) } - /// Create a Transaction with one statement + /// Create a MockTransaction with one statement pub fn one(stmt: Statement) -> Self { Self { stmts: vec![stmt] } } - /// Create a Transaction with many statements + /// Create a MockTransaction with many statements pub fn many(stmts: I) -> Self where I: IntoIterator, @@ -166,7 +166,7 @@ impl Transaction { } } - /// Wrap each Statement as a single-statement Transaction + /// Wrap each Statement as a single-statement MockTransaction pub fn wrap(stmts: I) -> Vec where I: IntoIterator, diff --git a/src/database/statement.rs b/src/database/statement.rs index 12b07487..63a1d57f 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -104,4 +104,4 @@ build_schema_stmt!(sea_query::TableCreateStatement); build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); -build_schema_stmt!(sea_query::TableTruncateStatement); +build_schema_stmt!(sea_query::TableTruncateStatement); \ No newline at end of file diff --git a/src/docs.rs b/src/docs.rs index 4d1226c3..ec8542a6 100644 --- a/src/docs.rs +++ b/src/docs.rs @@ -3,7 +3,7 @@ //! Relying on [SQLx](https://github.com/launchbadge/sqlx), SeaORM is a new library with async support from day 1. //! //! ``` -//! # use sea_orm::{DbConn, error::*, entity::*, query::*, tests_cfg::*, DatabaseConnection, DbBackend, MockDatabase, Transaction, IntoMockRow}; +//! # use sea_orm::{DbConn, error::*, entity::*, query::*, tests_cfg::*, DatabaseConnection, DbBackend, MockDatabase, MockTransaction, IntoMockRow}; //! # let db = MockDatabase::new(DbBackend::Postgres) //! # .append_query_results(vec![ //! # vec![cake::Model { @@ -40,12 +40,12 @@ //! # assert_eq!( //! # db.into_transaction_log(), //! # vec![ -//! # Transaction::from_sql_and_values( +//! # MockTransaction::from_sql_and_values( //! # DbBackend::Postgres, //! # r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, //! # vec![] //! # ), -//! # Transaction::from_sql_and_values( +//! # MockTransaction::from_sql_and_values( //! # DbBackend::Postgres, //! # r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, //! # vec![] @@ -88,7 +88,7 @@ //! Use mock connections to write unit tests for your logic. //! //! ``` -//! # use sea_orm::{error::*, entity::*, query::*, tests_cfg::*, DbConn, MockDatabase, Transaction, DbBackend}; +//! # use sea_orm::{error::*, entity::*, query::*, tests_cfg::*, DbConn, MockDatabase, MockTransaction, DbBackend}; //! # async fn function(db: DbConn) -> Result<(), DbErr> { //! // Setup mock connection //! let db = MockDatabase::new(DbBackend::Postgres) @@ -115,7 +115,7 @@ //! assert_eq!( //! db.into_transaction_log(), //! vec![ -//! Transaction::from_sql_and_values( +//! MockTransaction::from_sql_and_values( //! DbBackend::Postgres, //! r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, //! vec![1u64.into()] diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 96317c6d..b932eafd 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -1,6 +1,6 @@ use crate::{ - debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, - Statement, Transaction, + debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, + MockTransaction, QueryResult, Statement, }; use futures::Stream; use std::{ @@ -26,7 +26,7 @@ pub trait MockDatabaseTrait: Send + Debug { fn query(&mut self, counter: usize, stmt: Statement) -> Result, DbErr>; - fn drain_transaction_log(&mut self) -> Vec; + fn drain_transaction_log(&mut self) -> Vec; fn get_database_backend(&self) -> DbBackend; } diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index aef46207..2a40c686 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -82,7 +82,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![ @@ -137,12 +137,12 @@ pub trait EntityTrait: EntityName { /// assert_eq!( /// db.into_transaction_log(), /// vec![ - /// Transaction::from_sql_and_values( + /// MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, /// vec![1u64.into()] /// ), - /// Transaction::from_sql_and_values( + /// MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, /// vec![] @@ -160,7 +160,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![ @@ -190,7 +190,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1"#, /// vec![11i32.into()] @@ -200,7 +200,7 @@ pub trait EntityTrait: EntityName { /// Find by composite key /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![ @@ -230,7 +230,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// [ /// r#"SELECT "cake_filling"."cake_id", "cake_filling"."filling_id" FROM "cake_filling""#, @@ -262,7 +262,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -292,7 +292,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1) RETURNING "id""#, vec!["Apple Pie".into()] /// )]); /// ``` @@ -309,7 +309,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -343,7 +343,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"INSERT INTO "cake" ("name") VALUES ($1), ($2) RETURNING "id""#, /// vec!["Apple Pie".into(), "Orange Scone".into()] /// )]); @@ -364,7 +364,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -398,7 +398,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"UPDATE "fruit" SET "name" = $1 WHERE "fruit"."id" = $2 AND "fruit"."name" LIKE $3"#, /// vec!["Orange".into(), 1i32.into(), "%orange%".into()] /// )]); @@ -418,7 +418,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -446,7 +446,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"UPDATE "fruit" SET "cake_id" = $1 WHERE "fruit"."name" LIKE $2"#, vec![Value::Int(None), "%Apple%".into()] /// )]); /// ``` @@ -462,7 +462,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -491,7 +491,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"DELETE FROM "fruit" WHERE "fruit"."id" = $1"#, vec![3i32.into()] /// )]); /// ``` @@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ @@ -537,7 +537,7 @@ pub trait EntityTrait: EntityName { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"DELETE FROM "fruit" WHERE "fruit"."name" LIKE $1"#, vec!["%Apple%".into()] /// )]); /// ``` diff --git a/src/entity/model.rs b/src/entity/model.rs index 318b8e70..f8225ab8 100644 --- a/src/entity/model.rs +++ b/src/entity/model.rs @@ -38,7 +38,7 @@ pub trait FromQueryResult: Sized { /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -81,7 +81,7 @@ pub trait FromQueryResult: Sized { /// # }); /// # assert_eq!( /// # db.into_transaction_log(), - /// # vec![Transaction::from_sql_and_values( + /// # vec![MockTransaction::from_sql_and_values( /// # DbBackend::Postgres, /// # r#"SELECT "name", COUNT(*) AS "num_of_cakes" FROM "cake" GROUP BY("name")"#, /// # vec![] diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 28f8574b..712c0afa 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -158,7 +158,7 @@ where mod tests { use crate::entity::prelude::*; use crate::{tests_cfg::*, ConnectionTrait}; - use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; + use crate::{DatabaseConnection, DbBackend, MockDatabase, MockTransaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; @@ -228,7 +228,7 @@ mod tests { query_builder.build(select.offset(4).limit(2)), ]; - assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + assert_eq!(db.into_transaction_log(), MockTransaction::wrap(stmts)); Ok(()) } @@ -262,7 +262,7 @@ mod tests { query_builder.build(select.offset(4).limit(2)), ]; - assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + assert_eq!(db.into_transaction_log(), MockTransaction::wrap(stmts)); Ok(()) } @@ -294,7 +294,7 @@ mod tests { let query_builder = db.get_database_backend(); let stmts = vec![query_builder.build(&select)]; - assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + assert_eq!(db.into_transaction_log(), MockTransaction::wrap(stmts)); Ok(()) } @@ -345,7 +345,7 @@ mod tests { query_builder.build(select.offset(4).limit(2)), ]; - assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + assert_eq!(db.into_transaction_log(), MockTransaction::wrap(stmts)); Ok(()) } @@ -377,7 +377,7 @@ mod tests { query_builder.build(select.offset(4).limit(2)), ]; - assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + assert_eq!(db.into_transaction_log(), MockTransaction::wrap(stmts)); Ok(()) } } diff --git a/src/executor/query.rs b/src/executor/query.rs index 0248fa5c..10748036 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -311,7 +311,7 @@ pub trait TryGetableMany: Sized { /// ``` /// # #[cfg(all(feature = "mock", feature = "macros"))] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -326,7 +326,7 @@ pub trait TryGetableMany: Sized { /// # ]]) /// # .into_connection(); /// # - /// use sea_orm::{entity::*, query::*, tests_cfg::cake, EnumIter, DeriveIden, TryGetableMany}; + /// use sea_orm::{entity::*, query::*, tests_cfg::cake, DeriveIden, EnumIter, TryGetableMany}; /// /// #[derive(EnumIter, DeriveIden)] /// enum ResultCol { @@ -358,7 +358,7 @@ pub trait TryGetableMany: Sized { /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."name", count("cake"."id") AS "num_of_cakes" FROM "cake""#, /// vec![] diff --git a/src/executor/select.rs b/src/executor/select.rs index f4ff69d4..68984955 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -132,7 +132,7 @@ where /// ``` /// # #[cfg(all(feature = "mock", feature = "macros"))] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -171,7 +171,7 @@ where /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."name" AS "cake_name" FROM "cake""#, /// vec![] @@ -181,7 +181,7 @@ where /// /// ``` /// # #[cfg(all(feature = "mock", feature = "macros"))] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -221,7 +221,7 @@ where /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// vec![ /// r#"SELECT "cake"."name" AS "cake_name", COUNT("cake"."id") AS "num_of_cakes""#, @@ -521,7 +521,7 @@ where /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -575,7 +575,7 @@ where /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."name", count("cake"."id") AS "num_of_cakes" FROM "cake""#, /// vec![] @@ -594,7 +594,7 @@ where /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_query_results(vec![vec![ @@ -642,7 +642,7 @@ where /// assert_eq!( /// db.into_transaction_log(), /// vec![ - /// Transaction::from_sql_and_values( + /// MockTransaction::from_sql_and_values( /// DbBackend::Postgres, r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, vec![] /// ), /// ]); @@ -657,7 +657,7 @@ where /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres).into_connection(); /// # @@ -679,7 +679,7 @@ where /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "id" = $1"#, /// vec![1.into()] @@ -699,7 +699,7 @@ where /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; + /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockTransaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres).into_connection(); /// # @@ -721,7 +721,7 @@ where /// /// assert_eq!( /// db.into_transaction_log(), - /// vec![Transaction::from_sql_and_values( + /// vec![MockTransaction::from_sql_and_values( /// DbBackend::Postgres, /// r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, /// vec![] From 8d0ba28b7eb312b5077f41ffd801ced2b272ba8d Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 4 Oct 2021 22:40:17 +0800 Subject: [PATCH 26/65] Fix clippy warning --- src/database/db_connection.rs | 7 ++----- src/database/transaction.rs | 8 ++++---- src/driver/mock.rs | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index d5612014..fff654a3 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -193,7 +193,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::MockDatabaseConnection(conn) => { let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) .await - .map_err(|e| TransactionError::Connection(e))?; + .map_err(TransactionError::Connection)?; transaction.run(_callback).await } DatabaseConnection::Disconnected => panic!("Disconnected"), @@ -202,10 +202,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { #[cfg(feature = "mock")] fn is_mock_connection(&self) -> bool { - match self { - DatabaseConnection::MockDatabaseConnection(_) => true, - _ => false, - } + matches!(self, DatabaseConnection::MockDatabaseConnection(_)) } } diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 3757052e..7230488b 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -110,15 +110,15 @@ impl DatabaseTransaction { { let res = callback(&self) .await - .map_err(|e| TransactionError::Transaction(e)); + .map_err(TransactionError::Transaction); if res.is_ok() { self.commit() .await - .map_err(|e| TransactionError::Connection(e))?; + .map_err(TransactionError::Connection)?; } else { self.rollback() .await - .map_err(|e| TransactionError::Connection(e))?; + .map_err(TransactionError::Connection)?; } res } @@ -341,7 +341,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let transaction = self .begin() .await - .map_err(|e| TransactionError::Connection(e))?; + .map_err(TransactionError::Connection)?; transaction.run(_callback).await } } diff --git a/src/driver/mock.rs b/src/driver/mock.rs index b932eafd..ad6b2298 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -114,7 +114,7 @@ impl MockDatabaseConnection { statement: &Statement, ) -> Pin>>> { match self.query_all(statement.clone()) { - Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))), Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), } } From 19a572b72195c54814ae0755e09876b65089f3f7 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 4 Oct 2021 23:30:20 +0800 Subject: [PATCH 27/65] Escape rust keywords with `r#` raw identifier --- sea-orm-macros/src/derives/active_model.rs | 10 +-- sea-orm-macros/src/derives/entity_model.rs | 12 ++-- sea-orm-macros/src/derives/model.rs | 13 ++-- sea-orm-macros/src/util.rs | 40 ++++++++++++ src/tests_cfg/mod.rs | 2 + src/tests_cfg/rust_keyword.rs | 71 ++++++++++++++++++++++ 6 files changed, 134 insertions(+), 14 deletions(-) create mode 100644 src/tests_cfg/rust_keyword.rs diff --git a/sea-orm-macros/src/derives/active_model.rs b/sea-orm-macros/src/derives/active_model.rs index 2227f09b..85bdcb69 100644 --- a/sea-orm-macros/src/derives/active_model.rs +++ b/sea-orm-macros/src/derives/active_model.rs @@ -1,4 +1,4 @@ -use crate::util::field_not_ignored; +use crate::util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}; use heck::CamelCase; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, quote_spanned}; @@ -29,10 +29,10 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result) -> syn::Result { // if #[sea_orm(table_name = "foo", schema_name = "bar")] specified, create Entity struct let mut table_name = None; @@ -60,8 +60,10 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res if let Fields::Named(fields) = item_struct.fields { for field in fields.named { if let Some(ident) = &field.ident { - let mut field_name = - Ident::new(&ident.to_string().to_case(Case::Pascal), Span::call_site()); + let mut field_name = Ident::new( + &trim_starting_raw_identifier(&ident).to_case(Case::Pascal), + Span::call_site(), + ); let mut nullable = false; let mut default_value = None; @@ -168,6 +170,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res field_name = enum_name; } + field_name = Ident::new(&escape_rust_keyword(field_name), Span::call_site()); + if ignore { continue; } else { diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index a43b487f..29a597b9 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -1,4 +1,7 @@ -use crate::{attributes::derive_attr, util::field_not_ignored}; +use crate::{ + attributes::derive_attr, + util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}, +}; use heck::CamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; @@ -43,10 +46,10 @@ impl DeriveModel { let column_idents = fields .iter() .map(|field| { - let mut ident = format_ident!( - "{}", - field.ident.as_ref().unwrap().to_string().to_camel_case() - ); + let ident = field.ident.as_ref().unwrap().to_string(); + let ident = trim_starting_raw_identifier(ident).to_camel_case(); + let ident = escape_rust_keyword(ident); + let mut ident = format_ident!("{}", &ident); for attr in field.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { diff --git a/sea-orm-macros/src/util.rs b/sea-orm-macros/src/util.rs index 7dda1087..8929e9e8 100644 --- a/sea-orm-macros/src/util.rs +++ b/sea-orm-macros/src/util.rs @@ -24,3 +24,43 @@ pub(crate) fn field_not_ignored(field: &Field) -> bool { } true } + +pub(crate) fn trim_starting_raw_identifier(string: T) -> String +where + T: ToString, +{ + string + .to_string() + .trim_start_matches(RAW_IDENTIFIER) + .to_string() +} + +pub(crate) fn escape_rust_keyword(string: T) -> String +where + T: ToString, +{ + let string = string.to_string(); + if is_rust_keyword(&string) { + format!("r#{}", string) + } else { + string + } +} + +pub(crate) fn is_rust_keyword(string: T) -> bool +where + T: ToString, +{ + let string = string.to_string(); + RUST_KEYWORDS.iter().any(|s| s.eq(&string)) +} + +pub(crate) const RAW_IDENTIFIER: &str = "r#"; + +pub(crate) const RUST_KEYWORDS: [&str; 52] = [ + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "Self", "self", "static", "struct", "super", "trait", "true", "type", "union", + "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "macro", + "override", "priv", "try", "typeof", "unsized", "virtual", "yield", +]; diff --git a/src/tests_cfg/mod.rs b/src/tests_cfg/mod.rs index 6bc86aed..d6c80b36 100644 --- a/src/tests_cfg/mod.rs +++ b/src/tests_cfg/mod.rs @@ -7,6 +7,7 @@ pub mod cake_filling_price; pub mod entity_linked; pub mod filling; pub mod fruit; +pub mod rust_keyword; pub mod vendor; pub use cake::Entity as Cake; @@ -15,4 +16,5 @@ pub use cake_filling::Entity as CakeFilling; pub use cake_filling_price::Entity as CakeFillingPrice; pub use filling::Entity as Filling; pub use fruit::Entity as Fruit; +pub use rust_keyword::Entity as RustKeyword; pub use vendor::Entity as Vendor; diff --git a/src/tests_cfg/rust_keyword.rs b/src/tests_cfg/rust_keyword.rs new file mode 100644 index 00000000..90671a34 --- /dev/null +++ b/src/tests_cfg/rust_keyword.rs @@ -0,0 +1,71 @@ +use crate as sea_orm; +use crate::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "rust_keyword")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#as: i32, + pub r#async: i32, + pub r#await: i32, + pub r#break: i32, + pub r#const: i32, + pub r#continue: i32, + pub r#dyn: i32, + pub r#else: i32, + pub r#enum: i32, + pub r#extern: i32, + pub r#false: i32, + pub r#fn: i32, + pub r#for: i32, + pub r#if: i32, + pub r#impl: i32, + pub r#in: i32, + pub r#let: i32, + pub r#loop: i32, + pub r#match: i32, + pub r#mod: i32, + pub r#move: i32, + pub r#mut: i32, + pub r#pub: i32, + pub r#ref: i32, + pub r#return: i32, + pub r#static: i32, + pub r#struct: i32, + pub r#trait: i32, + pub r#true: i32, + pub r#type: i32, + pub r#union: i32, + pub r#unsafe: i32, + pub r#use: i32, + pub r#where: i32, + pub r#while: i32, + pub r#abstract: i32, + pub r#become: i32, + pub r#box: i32, + pub r#do: i32, + pub r#final: i32, + pub r#macro: i32, + pub r#override: i32, + pub r#priv: i32, + pub r#try: i32, + pub r#typeof: i32, + pub r#unsized: i32, + pub r#virtual: i32, + pub r#yield: i32, +} + +#[derive(Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + unreachable!() + } +} + +impl ActiveModelBehavior for ActiveModel {} From f3f24320e9a8ffdbb166d5f5b92c4da513bcfd8f Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 5 Oct 2021 10:23:45 +0800 Subject: [PATCH 28/65] Add test cases --- src/tests_cfg/rust_keyword.rs | 64 +++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/tests_cfg/rust_keyword.rs b/src/tests_cfg/rust_keyword.rs index 90671a34..299e9c6f 100644 --- a/src/tests_cfg/rust_keyword.rs +++ b/src/tests_cfg/rust_keyword.rs @@ -9,6 +9,7 @@ pub struct Model { pub testing: i32, pub rust: i32, pub keywords: i32, + pub r#raw_identifier: i32, pub r#as: i32, pub r#async: i32, pub r#await: i32, @@ -69,3 +70,66 @@ impl RelationTrait for Relation { } impl ActiveModelBehavior for ActiveModel {} + +#[cfg(test)] +mod tests { + use crate::tests_cfg::*; + use sea_query::Iden; + + #[test] + fn test_columns() { + assert_eq!(rust_keyword::Column::Id.to_string(), "id".to_owned()); + assert_eq!(rust_keyword::Column::Testing.to_string(), "testing".to_owned()); + assert_eq!(rust_keyword::Column::Rust.to_string(), "rust".to_owned()); + assert_eq!(rust_keyword::Column::Keywords.to_string(), "keywords".to_owned()); + assert_eq!(rust_keyword::Column::RawIdentifier.to_string(), "raw_identifier".to_owned()); + assert_eq!(rust_keyword::Column::As.to_string(), "as".to_owned()); + assert_eq!(rust_keyword::Column::Async.to_string(), "async".to_owned()); + assert_eq!(rust_keyword::Column::Await.to_string(), "await".to_owned()); + assert_eq!(rust_keyword::Column::Break.to_string(), "break".to_owned()); + assert_eq!(rust_keyword::Column::Const.to_string(), "const".to_owned()); + assert_eq!(rust_keyword::Column::Continue.to_string(), "continue".to_owned()); + assert_eq!(rust_keyword::Column::Dyn.to_string(), "dyn".to_owned()); + assert_eq!(rust_keyword::Column::Else.to_string(), "else".to_owned()); + assert_eq!(rust_keyword::Column::Enum.to_string(), "enum".to_owned()); + assert_eq!(rust_keyword::Column::Extern.to_string(), "extern".to_owned()); + assert_eq!(rust_keyword::Column::False.to_string(), "false".to_owned()); + assert_eq!(rust_keyword::Column::Fn.to_string(), "fn".to_owned()); + assert_eq!(rust_keyword::Column::For.to_string(), "for".to_owned()); + assert_eq!(rust_keyword::Column::If.to_string(), "if".to_owned()); + assert_eq!(rust_keyword::Column::Impl.to_string(), "impl".to_owned()); + assert_eq!(rust_keyword::Column::In.to_string(), "in".to_owned()); + assert_eq!(rust_keyword::Column::Let.to_string(), "let".to_owned()); + assert_eq!(rust_keyword::Column::Loop.to_string(), "loop".to_owned()); + assert_eq!(rust_keyword::Column::Match.to_string(), "match".to_owned()); + assert_eq!(rust_keyword::Column::Mod.to_string(), "mod".to_owned()); + assert_eq!(rust_keyword::Column::Move.to_string(), "move".to_owned()); + assert_eq!(rust_keyword::Column::Mut.to_string(), "mut".to_owned()); + assert_eq!(rust_keyword::Column::Pub.to_string(), "pub".to_owned()); + assert_eq!(rust_keyword::Column::Ref.to_string(), "ref".to_owned()); + assert_eq!(rust_keyword::Column::Return.to_string(), "return".to_owned()); + assert_eq!(rust_keyword::Column::Static.to_string(), "static".to_owned()); + assert_eq!(rust_keyword::Column::Struct.to_string(), "struct".to_owned()); + assert_eq!(rust_keyword::Column::Trait.to_string(), "trait".to_owned()); + assert_eq!(rust_keyword::Column::True.to_string(), "true".to_owned()); + assert_eq!(rust_keyword::Column::Type.to_string(), "type".to_owned()); + assert_eq!(rust_keyword::Column::Union.to_string(), "union".to_owned()); + assert_eq!(rust_keyword::Column::Unsafe.to_string(), "unsafe".to_owned()); + assert_eq!(rust_keyword::Column::Use.to_string(), "use".to_owned()); + assert_eq!(rust_keyword::Column::Where.to_string(), "where".to_owned()); + assert_eq!(rust_keyword::Column::While.to_string(), "while".to_owned()); + assert_eq!(rust_keyword::Column::Abstract.to_string(), "abstract".to_owned()); + assert_eq!(rust_keyword::Column::Become.to_string(), "become".to_owned()); + assert_eq!(rust_keyword::Column::Box.to_string(), "box".to_owned()); + assert_eq!(rust_keyword::Column::Do.to_string(), "do".to_owned()); + assert_eq!(rust_keyword::Column::Final.to_string(), "final".to_owned()); + assert_eq!(rust_keyword::Column::Macro.to_string(), "macro".to_owned()); + assert_eq!(rust_keyword::Column::Override.to_string(), "override".to_owned()); + assert_eq!(rust_keyword::Column::Priv.to_string(), "priv".to_owned()); + assert_eq!(rust_keyword::Column::Try.to_string(), "try".to_owned()); + assert_eq!(rust_keyword::Column::Typeof.to_string(), "typeof".to_owned()); + assert_eq!(rust_keyword::Column::Unsized.to_string(), "unsized".to_owned()); + assert_eq!(rust_keyword::Column::Virtual.to_string(), "virtual".to_owned()); + assert_eq!(rust_keyword::Column::Yield.to_string(), "yield".to_owned()); + } +} From 7779ac886eac1609fc1d44a96dea1f00919e2e77 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 5 Oct 2021 10:49:06 +0800 Subject: [PATCH 29/65] Escape rust keyword on codegen --- sea-orm-codegen/src/entity/column.rs | 5 +- sea-orm-codegen/src/entity/writer.rs | 58 ++++++++- sea-orm-codegen/src/lib.rs | 1 + sea-orm-codegen/src/util.rs | 27 +++++ sea-orm-codegen/tests/compact/rust_keyword.rs | 28 +++++ .../tests/expanded/rust_keyword.rs | 73 +++++++++++ src/tests_cfg/rust_keyword.rs | 114 +++++++++--------- 7 files changed, 246 insertions(+), 60 deletions(-) create mode 100644 sea-orm-codegen/src/util.rs create mode 100644 sea-orm-codegen/tests/compact/rust_keyword.rs create mode 100644 sea-orm-codegen/tests/expanded/rust_keyword.rs diff --git a/sea-orm-codegen/src/entity/column.rs b/sea-orm-codegen/src/entity/column.rs index 532f2e91..d3d47cb1 100644 --- a/sea-orm-codegen/src/entity/column.rs +++ b/sea-orm-codegen/src/entity/column.rs @@ -1,3 +1,4 @@ +use crate::util::escape_rust_keyword; use heck::{CamelCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; @@ -14,11 +15,11 @@ pub struct Column { impl Column { pub fn get_name_snake_case(&self) -> Ident { - format_ident!("{}", self.name.to_snake_case()) + format_ident!("{}", escape_rust_keyword(self.name.to_snake_case())) } pub fn get_name_camel_case(&self) -> Ident { - format_ident!("{}", self.name.to_camel_case()) + format_ident!("{}", escape_rust_keyword(self.name.to_camel_case())) } pub fn get_rs_type(&self) -> TokenStream { diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index 59f54537..5d064f88 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -597,18 +597,71 @@ mod tests { name: "id".to_owned(), }], }, + Entity { + table_name: "rust_keyword".to_owned(), + columns: vec![ + Column { + name: "id".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: true, + not_null: true, + unique: false, + }, + Column { + name: "testing".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "rust".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "keywords".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "type".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "typeof".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + ], + relations: vec![], + conjunct_relations: vec![], + primary_keys: vec![PrimaryKey { + name: "id".to_owned(), + }], + }, ] } #[test] fn test_gen_expanded_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 5] = [ + const ENTITY_FILES: [&str; 6] = [ include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake_filling.rs"), include_str!("../../tests/expanded/filling.rs"), include_str!("../../tests/expanded/fruit.rs"), include_str!("../../tests/expanded/vendor.rs"), + include_str!("../../tests/expanded/rust_keyword.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); @@ -642,12 +695,13 @@ mod tests { #[test] fn test_gen_compact_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 5] = [ + const ENTITY_FILES: [&str; 6] = [ include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake_filling.rs"), include_str!("../../tests/compact/filling.rs"), include_str!("../../tests/compact/fruit.rs"), include_str!("../../tests/compact/vendor.rs"), + include_str!("../../tests/compact/rust_keyword.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); diff --git a/sea-orm-codegen/src/lib.rs b/sea-orm-codegen/src/lib.rs index 07e167bc..5e637de1 100644 --- a/sea-orm-codegen/src/lib.rs +++ b/sea-orm-codegen/src/lib.rs @@ -1,5 +1,6 @@ mod entity; mod error; +mod util; pub use entity::*; pub use error::*; diff --git a/sea-orm-codegen/src/util.rs b/sea-orm-codegen/src/util.rs new file mode 100644 index 00000000..e752215b --- /dev/null +++ b/sea-orm-codegen/src/util.rs @@ -0,0 +1,27 @@ +pub(crate) fn escape_rust_keyword(string: T) -> String +where + T: ToString, +{ + let string = string.to_string(); + if is_rust_keyword(&string) { + format!("r#{}", string) + } else { + string + } +} + +pub(crate) fn is_rust_keyword(string: T) -> bool +where + T: ToString, +{ + let string = string.to_string(); + RUST_KEYWORDS.iter().any(|s| s.eq(&string)) +} + +pub(crate) const RUST_KEYWORDS: [&str; 52] = [ + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "Self", "self", "static", "struct", "super", "trait", "true", "type", "union", + "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "macro", + "override", "priv", "try", "typeof", "unsized", "virtual", "yield", +]; diff --git a/sea-orm-codegen/tests/compact/rust_keyword.rs b/sea-orm-codegen/tests/compact/rust_keyword.rs new file mode 100644 index 00000000..9e51bafd --- /dev/null +++ b/sea-orm-codegen/tests/compact/rust_keyword.rs @@ -0,0 +1,28 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "rust_keyword")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#type: i32, + pub r#typeof: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + _ => panic!("No RelationDef"), + } + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/expanded/rust_keyword.rs b/sea-orm-codegen/tests/expanded/rust_keyword.rs new file mode 100644 index 00000000..5c24d71c --- /dev/null +++ b/sea-orm-codegen/tests/expanded/rust_keyword.rs @@ -0,0 +1,73 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "rust_keyword" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#type: i32, + pub r#typeof: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + Id, + Testing, + Rust, + Keywords, + Type, + Typeof, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + Id, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + + fn auto_increment() -> bool { + true + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::Id => ColumnType::Integer.def(), + Self::Testing => ColumnType::Integer.def(), + Self::Rust => ColumnType::Integer.def(), + Self::Keywords => ColumnType::Integer.def(), + Self::Type => ColumnType::Integer.def(), + Self::Typeof => ColumnType::Integer.def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + _ => panic!("No RelationDef"), + } + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/src/tests_cfg/rust_keyword.rs b/src/tests_cfg/rust_keyword.rs index 299e9c6f..30052db6 100644 --- a/src/tests_cfg/rust_keyword.rs +++ b/src/tests_cfg/rust_keyword.rs @@ -60,12 +60,14 @@ pub struct Model { pub r#yield: i32, } -#[derive(Debug, EnumIter)] +#[derive(Copy, Clone, Debug, EnumIter)] pub enum Relation {} impl RelationTrait for Relation { fn def(&self) -> RelationDef { - unreachable!() + match self { + _ => panic!("No RelationDef"), + } } } @@ -73,63 +75,63 @@ impl ActiveModelBehavior for ActiveModel {} #[cfg(test)] mod tests { - use crate::tests_cfg::*; + use crate::tests_cfg::rust_keyword::*; use sea_query::Iden; #[test] fn test_columns() { - assert_eq!(rust_keyword::Column::Id.to_string(), "id".to_owned()); - assert_eq!(rust_keyword::Column::Testing.to_string(), "testing".to_owned()); - assert_eq!(rust_keyword::Column::Rust.to_string(), "rust".to_owned()); - assert_eq!(rust_keyword::Column::Keywords.to_string(), "keywords".to_owned()); - assert_eq!(rust_keyword::Column::RawIdentifier.to_string(), "raw_identifier".to_owned()); - assert_eq!(rust_keyword::Column::As.to_string(), "as".to_owned()); - assert_eq!(rust_keyword::Column::Async.to_string(), "async".to_owned()); - assert_eq!(rust_keyword::Column::Await.to_string(), "await".to_owned()); - assert_eq!(rust_keyword::Column::Break.to_string(), "break".to_owned()); - assert_eq!(rust_keyword::Column::Const.to_string(), "const".to_owned()); - assert_eq!(rust_keyword::Column::Continue.to_string(), "continue".to_owned()); - assert_eq!(rust_keyword::Column::Dyn.to_string(), "dyn".to_owned()); - assert_eq!(rust_keyword::Column::Else.to_string(), "else".to_owned()); - assert_eq!(rust_keyword::Column::Enum.to_string(), "enum".to_owned()); - assert_eq!(rust_keyword::Column::Extern.to_string(), "extern".to_owned()); - assert_eq!(rust_keyword::Column::False.to_string(), "false".to_owned()); - assert_eq!(rust_keyword::Column::Fn.to_string(), "fn".to_owned()); - assert_eq!(rust_keyword::Column::For.to_string(), "for".to_owned()); - assert_eq!(rust_keyword::Column::If.to_string(), "if".to_owned()); - assert_eq!(rust_keyword::Column::Impl.to_string(), "impl".to_owned()); - assert_eq!(rust_keyword::Column::In.to_string(), "in".to_owned()); - assert_eq!(rust_keyword::Column::Let.to_string(), "let".to_owned()); - assert_eq!(rust_keyword::Column::Loop.to_string(), "loop".to_owned()); - assert_eq!(rust_keyword::Column::Match.to_string(), "match".to_owned()); - assert_eq!(rust_keyword::Column::Mod.to_string(), "mod".to_owned()); - assert_eq!(rust_keyword::Column::Move.to_string(), "move".to_owned()); - assert_eq!(rust_keyword::Column::Mut.to_string(), "mut".to_owned()); - assert_eq!(rust_keyword::Column::Pub.to_string(), "pub".to_owned()); - assert_eq!(rust_keyword::Column::Ref.to_string(), "ref".to_owned()); - assert_eq!(rust_keyword::Column::Return.to_string(), "return".to_owned()); - assert_eq!(rust_keyword::Column::Static.to_string(), "static".to_owned()); - assert_eq!(rust_keyword::Column::Struct.to_string(), "struct".to_owned()); - assert_eq!(rust_keyword::Column::Trait.to_string(), "trait".to_owned()); - assert_eq!(rust_keyword::Column::True.to_string(), "true".to_owned()); - assert_eq!(rust_keyword::Column::Type.to_string(), "type".to_owned()); - assert_eq!(rust_keyword::Column::Union.to_string(), "union".to_owned()); - assert_eq!(rust_keyword::Column::Unsafe.to_string(), "unsafe".to_owned()); - assert_eq!(rust_keyword::Column::Use.to_string(), "use".to_owned()); - assert_eq!(rust_keyword::Column::Where.to_string(), "where".to_owned()); - assert_eq!(rust_keyword::Column::While.to_string(), "while".to_owned()); - assert_eq!(rust_keyword::Column::Abstract.to_string(), "abstract".to_owned()); - assert_eq!(rust_keyword::Column::Become.to_string(), "become".to_owned()); - assert_eq!(rust_keyword::Column::Box.to_string(), "box".to_owned()); - assert_eq!(rust_keyword::Column::Do.to_string(), "do".to_owned()); - assert_eq!(rust_keyword::Column::Final.to_string(), "final".to_owned()); - assert_eq!(rust_keyword::Column::Macro.to_string(), "macro".to_owned()); - assert_eq!(rust_keyword::Column::Override.to_string(), "override".to_owned()); - assert_eq!(rust_keyword::Column::Priv.to_string(), "priv".to_owned()); - assert_eq!(rust_keyword::Column::Try.to_string(), "try".to_owned()); - assert_eq!(rust_keyword::Column::Typeof.to_string(), "typeof".to_owned()); - assert_eq!(rust_keyword::Column::Unsized.to_string(), "unsized".to_owned()); - assert_eq!(rust_keyword::Column::Virtual.to_string(), "virtual".to_owned()); - assert_eq!(rust_keyword::Column::Yield.to_string(), "yield".to_owned()); + assert_eq!(Column::Id.to_string().as_str(), "id"); + assert_eq!(Column::Testing.to_string().as_str(), "testing"); + assert_eq!(Column::Rust.to_string().as_str(), "rust"); + assert_eq!(Column::Keywords.to_string().as_str(), "keywords"); + assert_eq!(Column::RawIdentifier.to_string().as_str(), "raw_identifier"); + assert_eq!(Column::As.to_string().as_str(), "as"); + assert_eq!(Column::Async.to_string().as_str(), "async"); + assert_eq!(Column::Await.to_string().as_str(), "await"); + assert_eq!(Column::Break.to_string().as_str(), "break"); + assert_eq!(Column::Const.to_string().as_str(), "const"); + assert_eq!(Column::Continue.to_string().as_str(), "continue"); + assert_eq!(Column::Dyn.to_string().as_str(), "dyn"); + assert_eq!(Column::Else.to_string().as_str(), "else"); + assert_eq!(Column::Enum.to_string().as_str(), "enum"); + assert_eq!(Column::Extern.to_string().as_str(), "extern"); + assert_eq!(Column::False.to_string().as_str(), "false"); + assert_eq!(Column::Fn.to_string().as_str(), "fn"); + assert_eq!(Column::For.to_string().as_str(), "for"); + assert_eq!(Column::If.to_string().as_str(), "if"); + assert_eq!(Column::Impl.to_string().as_str(), "impl"); + assert_eq!(Column::In.to_string().as_str(), "in"); + assert_eq!(Column::Let.to_string().as_str(), "let"); + assert_eq!(Column::Loop.to_string().as_str(), "loop"); + assert_eq!(Column::Match.to_string().as_str(), "match"); + assert_eq!(Column::Mod.to_string().as_str(), "mod"); + assert_eq!(Column::Move.to_string().as_str(), "move"); + assert_eq!(Column::Mut.to_string().as_str(), "mut"); + assert_eq!(Column::Pub.to_string().as_str(), "pub"); + assert_eq!(Column::Ref.to_string().as_str(), "ref"); + assert_eq!(Column::Return.to_string().as_str(), "return"); + assert_eq!(Column::Static.to_string().as_str(), "static"); + assert_eq!(Column::Struct.to_string().as_str(), "struct"); + assert_eq!(Column::Trait.to_string().as_str(), "trait"); + assert_eq!(Column::True.to_string().as_str(), "true"); + assert_eq!(Column::Type.to_string().as_str(), "type"); + assert_eq!(Column::Union.to_string().as_str(), "union"); + assert_eq!(Column::Unsafe.to_string().as_str(), "unsafe"); + assert_eq!(Column::Use.to_string().as_str(), "use"); + assert_eq!(Column::Where.to_string().as_str(), "where"); + assert_eq!(Column::While.to_string().as_str(), "while"); + assert_eq!(Column::Abstract.to_string().as_str(), "abstract"); + assert_eq!(Column::Become.to_string().as_str(), "become"); + assert_eq!(Column::Box.to_string().as_str(), "box"); + assert_eq!(Column::Do.to_string().as_str(), "do"); + assert_eq!(Column::Final.to_string().as_str(), "final"); + assert_eq!(Column::Macro.to_string().as_str(), "macro"); + assert_eq!(Column::Override.to_string().as_str(), "override"); + assert_eq!(Column::Priv.to_string().as_str(), "priv"); + assert_eq!(Column::Try.to_string().as_str(), "try"); + assert_eq!(Column::Typeof.to_string().as_str(), "typeof"); + assert_eq!(Column::Unsized.to_string().as_str(), "unsized"); + assert_eq!(Column::Virtual.to_string().as_str(), "virtual"); + assert_eq!(Column::Yield.to_string().as_str(), "yield"); } } From c7532bcc08c26a644816e81e77f2ef9befa3d514 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 5 Oct 2021 19:21:05 +0800 Subject: [PATCH 30/65] Basic MockTransaction implementation TODO: nested transaction --- src/database/mock.rs | 113 +++++++++++++++++++++++++++++++++++- src/database/statement.rs | 2 +- src/database/transaction.rs | 33 +++++------ src/driver/mock.rs | 24 +++++++- tests/stream_tests.rs | 3 +- 5 files changed, 150 insertions(+), 25 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index 012c941f..d1bb65a4 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -9,6 +9,7 @@ use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { db_backend: DbBackend, + transaction: Option, transaction_log: Vec, exec_results: Vec, query_results: Vec>, @@ -29,6 +30,12 @@ pub trait IntoMockRow { fn into_mock_row(self) -> MockRow; } +#[derive(Debug)] +pub struct OpenTransaction { + stmts: Vec, + transaction_depth: usize, +} + #[derive(Debug, Clone, PartialEq)] pub struct MockTransaction { stmts: Vec, @@ -38,6 +45,7 @@ impl MockDatabase { pub fn new(db_backend: DbBackend) -> Self { Self { db_backend, + transaction: None, transaction_log: Vec::new(), exec_results: Vec::new(), query_results: Vec::new(), @@ -67,7 +75,11 @@ impl MockDatabase { impl MockDatabaseTrait for MockDatabase { fn execute(&mut self, counter: usize, statement: Statement) -> Result { - self.transaction_log.push(MockTransaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(MockTransaction::one(statement)); + } if counter < self.exec_results.len() { Ok(ExecResult { result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), @@ -78,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase { } fn query(&mut self, counter: usize, statement: Statement) -> Result, DbErr> { - self.transaction_log.push(MockTransaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(MockTransaction::one(statement)); + } if counter < self.query_results.len() { Ok(std::mem::take(&mut self.query_results[counter]) .into_iter() @@ -91,6 +107,32 @@ impl MockDatabaseTrait for MockDatabase { } } + fn begin(&mut self) { + if self.transaction.is_some() { + panic!("There is uncommitted transaction"); + } else { + self.transaction = Some(OpenTransaction::init()); + } + } + + fn commit(&mut self) { + if self.transaction.is_some() { + let transaction = self.transaction.take().unwrap(); + self.transaction_log + .push(transaction.into_mock_transaction()); + } else { + panic!("There is no open transaction to commit"); + } + } + + fn rollback(&mut self) { + if self.transaction.is_some() { + self.transaction = None; + } else { + panic!("There is no open transaction to rollback"); + } + } + fn drain_transaction_log(&mut self) -> Vec { std::mem::take(&mut self.transaction_log) } @@ -174,3 +216,70 @@ impl MockTransaction { stmts.into_iter().map(Self::one).collect() } } + +impl OpenTransaction { + fn init() -> Self { + Self { + stmts: Vec::new(), + transaction_depth: 0, + } + } + + fn push(&mut self, stmt: Statement) { + self.stmts.push(stmt); + } + + fn into_mock_transaction(self) -> MockTransaction { + MockTransaction { stmts: self.stmts } + } +} + +#[cfg(test)] +#[cfg(feature = "mock")] +mod tests { + use crate::{ + entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, MockTransaction, + Statement, + }; + + #[smol_potat::test] + async fn test_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _1 = cake::Entity::find().one(txn).await; + let _2 = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + let _ = cake::Entity::find().all(&db).await; + + assert_eq!( + db.into_transaction_log(), + vec![ + MockTransaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + ]), + MockTransaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + ] + ); + } +} diff --git a/src/database/statement.rs b/src/database/statement.rs index 63a1d57f..12b07487 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -104,4 +104,4 @@ build_schema_stmt!(sea_query::TableCreateStatement); build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); -build_schema_stmt!(sea_query::TableTruncateStatement); \ No newline at end of file +build_schema_stmt!(sea_query::TableTruncateStatement); diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 7230488b..d7bbc058 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -92,9 +92,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - // should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.begin(); + } } Ok(res) } @@ -108,13 +109,9 @@ impl DatabaseTransaction { T: Send, E: std::error::Error + Send, { - let res = callback(&self) - .await - .map_err(TransactionError::Transaction); + let res = callback(&self).await.map_err(TransactionError::Transaction); if res.is_ok() { - self.commit() - .await - .map_err(TransactionError::Connection)?; + self.commit().await.map_err(TransactionError::Connection)?; } else { self.rollback() .await @@ -144,9 +141,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.commit(); + } } Ok(()) } @@ -172,9 +170,10 @@ impl DatabaseTransaction { .await .map_err(sqlx_error_to_query_err)? } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(ref mut c) => { + c.rollback(); + } } Ok(()) } @@ -196,9 +195,10 @@ impl DatabaseTransaction { InnerConnection::Sqlite(c) => { ::TransactionManager::start_rollback(c); } - //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {} + InnerConnection::Mock(c) => { + c.rollback(); + } } } else { //this should never happen @@ -338,10 +338,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { T: Send, E: std::error::Error + Send, { - let transaction = self - .begin() - .await - .map_err(TransactionError::Connection)?; + let transaction = self.begin().await.map_err(TransactionError::Connection)?; transaction.run(_callback).await } } diff --git a/src/driver/mock.rs b/src/driver/mock.rs index ad6b2298..d3c36262 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -26,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug { fn query(&mut self, counter: usize, stmt: Statement) -> Result, DbErr>; + fn begin(&mut self); + + fn commit(&mut self); + + fn rollback(&mut self); + fn drain_transaction_log(&mut self) -> Vec; fn get_database_backend(&self) -> DbBackend; @@ -86,10 +92,14 @@ impl MockDatabaseConnection { } } - pub fn get_mocker_mutex(&self) -> &Mutex> { + pub(crate) fn get_mocker_mutex(&self) -> &Mutex> { &self.mocker } + pub fn get_database_backend(&self) -> DbBackend { + self.mocker.lock().unwrap().get_database_backend() + } + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); @@ -119,7 +129,15 @@ impl MockDatabaseConnection { } } - pub fn get_database_backend(&self) -> DbBackend { - self.mocker.lock().unwrap().get_database_backend() + pub fn begin(&self) { + self.mocker.lock().unwrap().begin() + } + + pub fn commit(&self) { + self.mocker.lock().unwrap().commit() + } + + pub fn rollback(&self) { + self.mocker.lock().unwrap().rollback() } } diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index d30063e5..560fc01e 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -1,7 +1,6 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; -use futures::StreamExt; pub use sea_orm::entity::*; pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; @@ -12,6 +11,8 @@ pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter}; feature = "sqlx-postgres" ))] pub async fn stream() -> Result<(), DbErr> { + use futures::StreamExt; + let ctx = TestContext::new("stream").await; let bakery = bakery::ActiveModel { From 8990261d703cbeecb6e1292efd7e58f50d060eaf Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 6 Oct 2021 12:28:46 +0800 Subject: [PATCH 31/65] Bind null custom types --- Cargo.toml | 2 +- tests/common/bakery_chain/metadata.rs | 4 ++-- tests/common/setup/schema.rs | 4 ++-- tests/parallel_tests.rs | 12 ++++++------ tests/uuid_tests.rs | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a4466152..b9675c38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.2.4", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.16.5", features = ["thread-safe"] } +sea-query = { version = "^0.17.0", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/tests/common/bakery_chain/metadata.rs b/tests/common/bakery_chain/metadata.rs index de513a22..2c297cd3 100644 --- a/tests/common/bakery_chain/metadata.rs +++ b/tests/common/bakery_chain/metadata.rs @@ -10,8 +10,8 @@ pub struct Model { pub key: String, pub value: String, pub bytes: Vec, - pub date: Date, - pub time: Time, + pub date: Option, + pub time: Option( - mut active_model: A, - ) -> <::PrimaryKey as PrimaryKeyTrait>::ValueType - where - A: ActiveModelTrait, - { - #primary_key_value - } - } )) } diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index ab076bf0..0243cc7d 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -2,6 +2,7 @@ use crate::{ error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, Value, }; use async_trait::async_trait; +use sea_query::ValueTuple; use std::fmt::Debug; #[derive(Clone, Debug, Default)] @@ -9,7 +10,8 @@ pub struct ActiveValue where V: Into, { - value: Option, + // Don't want to call ActiveValue::unwrap() and cause panic + pub(self) value: Option, state: ActiveValueState, } @@ -66,6 +68,41 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; + fn get_primary_key_value(&self) -> Option { + let mut cols = ::PrimaryKey::iter(); + macro_rules! next { + () => { + if let Some(col) = cols.next() { + if let Some(val) = self.get(col.into_column()).value { + val + } else { + return None; + } + } else { + return None; + } + }; + } + match ::PrimaryKey::iter().count() { + 1 => { + let s1 = next!(); + Some(ValueTuple::One(s1)) + } + 2 => { + let s1 = next!(); + let s2 = next!(); + Some(ValueTuple::Two(s1, s2)) + } + 3 => { + let s1 = next!(); + let s2 = next!(); + let s3 = next!(); + Some(ValueTuple::Three(s1, s2, s3)) + } + _ => panic!("The arity cannot be larger than 3"), + } + } + async fn insert(self, db: &DatabaseConnection) -> Result where ::Model: IntoActiveModel, diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index d691fb3e..764f2524 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -1,7 +1,7 @@ use crate::{ ActiveModelTrait, ColumnTrait, Delete, DeleteMany, DeleteOne, FromQueryResult, Insert, - ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, PrimaryKeyValue, QueryFilter, Related, - RelationBuilder, RelationTrait, RelationType, Select, Update, UpdateMany, UpdateOne, + ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, Related, RelationBuilder, + RelationTrait, RelationType, Select, Update, UpdateMany, UpdateOne, }; use sea_query::{Alias, Iden, IntoIden, IntoTableRef, IntoValueTuple, TableRef}; pub use sea_strum::IntoEnumIterator as Iterable; @@ -49,9 +49,7 @@ pub trait EntityTrait: EntityName { type Relation: RelationTrait; - type PrimaryKey: PrimaryKeyTrait - + PrimaryKeyToColumn - + PrimaryKeyValue; + type PrimaryKey: PrimaryKeyTrait + PrimaryKeyToColumn; fn belongs_to(related: R) -> RelationBuilder where diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 1fecfa96..fd61613b 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,8 +1,8 @@ pub use crate::{ error::*, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, EntityName, EntityTrait, EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, - PrimaryKeyToColumn, PrimaryKeyTrait, PrimaryKeyValue, QueryFilter, QueryResult, Related, - RelationDef, RelationTrait, Select, Value, + PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, + RelationTrait, Select, Value, }; #[cfg(feature = "macros")] diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index 530eba30..a5e4cde0 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,11 +1,18 @@ use super::{ColumnTrait, IdenStatic, Iterable}; -use crate::{ActiveModelTrait, EntityTrait, TryFromU64, TryGetableMany}; -use sea_query::IntoValueTuple; +use crate::{TryFromU64, TryGetableMany}; +use sea_query::{FromValueTuple, IntoValueTuple}; use std::fmt::Debug; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { - type ValueType: Sized + Send + Debug + PartialEq + IntoValueTuple + TryGetableMany + TryFromU64; + type ValueType: Sized + + Send + + Debug + + PartialEq + + IntoValueTuple + + FromValueTuple + + TryGetableMany + + TryFromU64; fn auto_increment() -> bool; } @@ -19,12 +26,3 @@ pub trait PrimaryKeyToColumn { where Self: Sized; } - -pub trait PrimaryKeyValue -where - E: EntityTrait, -{ - fn get_primary_key_value(active_model: A) -> ::ValueType - where - A: ActiveModelTrait; -} diff --git a/src/executor/insert.rs b/src/executor/insert.rs index ed4da269..02d02c0b 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -2,7 +2,7 @@ use crate::{ error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, }; -use sea_query::InsertStatement; +use sea_query::{FromValueTuple, InsertStatement, ValueTuple}; use std::{future::Future, marker::PhantomData}; #[derive(Debug)] @@ -10,7 +10,7 @@ pub struct Inserter where A: ActiveModelTrait, { - primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, + primary_key: Option, query: InsertStatement, model: PhantomData, } @@ -35,7 +35,6 @@ where where A: 'a, { - // TODO: extract primary key's value from query // so that self is dropped before entering await let mut query = self.query; if db.get_database_backend() == DbBackend::Postgres { @@ -49,7 +48,6 @@ where } } Inserter::::new(self.primary_key, query).exec(db) - // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -57,10 +55,7 @@ impl Inserter where A: ActiveModelTrait, { - pub fn new( - primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, - query: InsertStatement, - ) -> Self { + pub fn new(primary_key: Option, query: InsertStatement) -> Self { Self { primary_key, query, @@ -82,7 +77,7 @@ where // Only Statement impl Send async fn exec_insert( - primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, + primary_key: Option, statement: Statement, db: &DatabaseConnection, ) -> Result, DbErr> @@ -108,7 +103,7 @@ where let last_insert_id = match last_insert_id_opt { Some(last_insert_id) => last_insert_id, None => match primary_key { - Some(primary_key) => primary_key, + Some(value_tuple) => FromValueTuple::from_value_tuple(value_tuple), None => return Err(DbErr::Exec("Fail to unpack last_insert_id".to_owned())), }, }; diff --git a/src/query/insert.rs b/src/query/insert.rs index 59c06b97..615ce06e 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,9 +1,9 @@ use crate::{ ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, - PrimaryKeyValue, QueryTrait, + QueryTrait, }; use core::marker::PhantomData; -use sea_query::InsertStatement; +use sea_query::{InsertStatement, ValueTuple}; #[derive(Debug)] pub struct Insert @@ -12,7 +12,7 @@ where { pub(crate) query: InsertStatement, pub(crate) columns: Vec, - pub(crate) primary_key: Option<<<::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>, + pub(crate) primary_key: Option, pub(crate) model: PhantomData, } @@ -114,7 +114,7 @@ where let mut am: A = m.into_active_model(); self.primary_key = if !<::PrimaryKey as PrimaryKeyTrait>::auto_increment() { - Some(<::PrimaryKey as PrimaryKeyValue>::get_primary_key_value::(am.clone())) + am.get_primary_key_value() } else { None }; From 23215c8dd53000b3d7eca75be1bf842ab7d1842a Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Fri, 8 Oct 2021 18:22:25 +0800 Subject: [PATCH 46/65] fix clippy warnings --- src/entity/active_model.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 0243cc7d..5b4289f9 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -68,6 +68,7 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; + #[allow(clippy::question_mark)] fn get_primary_key_value(&self) -> Option { let mut cols = ::PrimaryKey::iter(); macro_rules! next { From c24d7704d9ae76485ec9312373284982510e140e Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 14:38:09 +0800 Subject: [PATCH 47/65] Move examples #232 --- .github/workflows/rust.yml | 24 ++++++++++++++++++- examples/{async-std => basic}/Cargo.toml | 0 examples/{async-std => basic}/Readme.md | 0 examples/{async-std => basic}/bakery.sql | 0 examples/{async-std => basic}/import.sh | 0 examples/{async-std => basic}/src/entities.rs | 0 .../{async-std => basic}/src/example_cake.rs | 0 .../src/example_cake_filling.rs | 0 .../src/example_filling.rs | 0 .../{async-std => basic}/src/example_fruit.rs | 0 examples/{async-std => basic}/src/main.rs | 0 .../{async-std => basic}/src/operation.rs | 0 examples/{async-std => basic}/src/select.rs | 0 {examples/tokio => issues/86}/Cargo.toml | 0 {examples/tokio => issues/86}/src/cake.rs | 0 {examples/tokio => issues/86}/src/main.rs | 0 16 files changed, 23 insertions(+), 1 deletion(-) rename examples/{async-std => basic}/Cargo.toml (100%) rename examples/{async-std => basic}/Readme.md (100%) rename examples/{async-std => basic}/bakery.sql (100%) rename examples/{async-std => basic}/import.sh (100%) rename examples/{async-std => basic}/src/entities.rs (100%) rename examples/{async-std => basic}/src/example_cake.rs (100%) rename examples/{async-std => basic}/src/example_cake_filling.rs (100%) rename examples/{async-std => basic}/src/example_filling.rs (100%) rename examples/{async-std => basic}/src/example_fruit.rs (100%) rename examples/{async-std => basic}/src/main.rs (100%) rename examples/{async-std => basic}/src/operation.rs (100%) rename examples/{async-std => basic}/src/select.rs (100%) rename {examples/tokio => issues/86}/Cargo.toml (100%) rename {examples/tokio => issues/86}/src/cake.rs (100%) rename {examples/tokio => issues/86}/src/main.rs (100%) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5a5e6e91..ec5ea734 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -171,7 +171,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - path: [async-std, tokio, actix_example, actix4_example, rocket_example] + path: [basic, actix_example, actix4_example, rocket_example] steps: - uses: actions/checkout@v2 @@ -187,6 +187,28 @@ jobs: args: > --manifest-path examples/${{ matrix.path }}/Cargo.toml + issues: + name: Issues + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + path: [86] + steps: + - uses: actions/checkout@v2 + + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - uses: actions-rs/cargo@v1 + with: + command: build + args: > + --manifest-path issues/${{ matrix.path }}/Cargo.toml + sqlite: name: SQLite runs-on: ubuntu-20.04 diff --git a/examples/async-std/Cargo.toml b/examples/basic/Cargo.toml similarity index 100% rename from examples/async-std/Cargo.toml rename to examples/basic/Cargo.toml diff --git a/examples/async-std/Readme.md b/examples/basic/Readme.md similarity index 100% rename from examples/async-std/Readme.md rename to examples/basic/Readme.md diff --git a/examples/async-std/bakery.sql b/examples/basic/bakery.sql similarity index 100% rename from examples/async-std/bakery.sql rename to examples/basic/bakery.sql diff --git a/examples/async-std/import.sh b/examples/basic/import.sh similarity index 100% rename from examples/async-std/import.sh rename to examples/basic/import.sh diff --git a/examples/async-std/src/entities.rs b/examples/basic/src/entities.rs similarity index 100% rename from examples/async-std/src/entities.rs rename to examples/basic/src/entities.rs diff --git a/examples/async-std/src/example_cake.rs b/examples/basic/src/example_cake.rs similarity index 100% rename from examples/async-std/src/example_cake.rs rename to examples/basic/src/example_cake.rs diff --git a/examples/async-std/src/example_cake_filling.rs b/examples/basic/src/example_cake_filling.rs similarity index 100% rename from examples/async-std/src/example_cake_filling.rs rename to examples/basic/src/example_cake_filling.rs diff --git a/examples/async-std/src/example_filling.rs b/examples/basic/src/example_filling.rs similarity index 100% rename from examples/async-std/src/example_filling.rs rename to examples/basic/src/example_filling.rs diff --git a/examples/async-std/src/example_fruit.rs b/examples/basic/src/example_fruit.rs similarity index 100% rename from examples/async-std/src/example_fruit.rs rename to examples/basic/src/example_fruit.rs diff --git a/examples/async-std/src/main.rs b/examples/basic/src/main.rs similarity index 100% rename from examples/async-std/src/main.rs rename to examples/basic/src/main.rs diff --git a/examples/async-std/src/operation.rs b/examples/basic/src/operation.rs similarity index 100% rename from examples/async-std/src/operation.rs rename to examples/basic/src/operation.rs diff --git a/examples/async-std/src/select.rs b/examples/basic/src/select.rs similarity index 100% rename from examples/async-std/src/select.rs rename to examples/basic/src/select.rs diff --git a/examples/tokio/Cargo.toml b/issues/86/Cargo.toml similarity index 100% rename from examples/tokio/Cargo.toml rename to issues/86/Cargo.toml diff --git a/examples/tokio/src/cake.rs b/issues/86/src/cake.rs similarity index 100% rename from examples/tokio/src/cake.rs rename to issues/86/src/cake.rs diff --git a/examples/tokio/src/main.rs b/issues/86/src/main.rs similarity index 100% rename from examples/tokio/src/main.rs rename to issues/86/src/main.rs From 12a8e5c8e90a0d503844d2b658cf42e72a32328c Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 14:40:12 +0800 Subject: [PATCH 48/65] Readme --- README.md | 2 +- src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b4525b31..d6bfcdff 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) -[![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) +[![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic) [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) diff --git a/src/lib.rs b/src/lib.rs index 6ddc442c..1b78cf58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ //! SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. //! //! [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) -//! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) +//! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic) //! [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) //! [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) //! [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) From 018f7dd19fdae11f8f560a5ea96588802528c1d2 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 21:14:08 +0800 Subject: [PATCH 49/65] Streaming for MockConnection --- src/database/mock.rs | 31 +++++++++++++++++++++++++++++++ src/executor/select.rs | 6 ------ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index b6e96a79..d0fce2cd 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -315,4 +315,35 @@ mod tests { assert_eq!(db.into_transaction_log(), vec![]); } + + #[smol_potat::test] + async fn test_stream_1() -> Result<(), DbErr> { + use futures::TryStreamExt; + + let apple = fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }; + + let orange = fruit::Model { + id: 2, + name: "orange".to_owned(), + cake_id: None, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![apple.clone(), orange.clone()]]) + .into_connection(); + + let mut stream = fruit::Entity::find().stream(&db).await?; + + assert_eq!(stream.try_next().await?, Some(apple)); + + assert_eq!(stream.try_next().await?, Some(orange)); + + assert_eq!(stream.try_next().await?, None); + + Ok(()) + } } diff --git a/src/executor/select.rs b/src/executor/select.rs index 1c1825cc..a95fe463 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -3,11 +3,9 @@ use crate::{ ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, TryGetableMany, }; -#[cfg(feature = "sqlx-dep")] use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; -#[cfg(feature = "sqlx-dep")] use std::pin::Pin; #[derive(Clone, Debug)] @@ -252,7 +250,6 @@ where self.into_model().all(db).await } - #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, @@ -320,7 +317,6 @@ where self.into_model().all(db).await } - #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, @@ -381,7 +377,6 @@ where self.into_model().one(db).await } - #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, @@ -456,7 +451,6 @@ where Ok(models) } - #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, From 91efa1fae29bd661d8bbf1b347ecfa42a7f9ff9f Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 21:45:25 +0800 Subject: [PATCH 50/65] Test streaming in transaction --- src/database/mock.rs | 59 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index d0fce2cd..8bde4725 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -241,6 +241,17 @@ mod tests { Transaction, TransactionError, }; + #[derive(Debug, PartialEq)] + pub struct MyErr(String); + + impl std::error::Error for MyErr {} + + impl std::fmt::Display for MyErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0.as_str()) + } + } + #[smol_potat::test] async fn test_transaction_1() { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); @@ -286,17 +297,6 @@ mod tests { async fn test_transaction_2() { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); - #[derive(Debug, PartialEq)] - pub struct MyErr(String); - - impl std::error::Error for MyErr {} - - impl std::fmt::Display for MyErr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0.as_str()) - } - } - let result = db .transaction::<_, (), MyErr>(|txn| { Box::pin(async move { @@ -346,4 +346,41 @@ mod tests { Ok(()) } + + #[smol_potat::test] + async fn test_stream_in_transaction() -> Result<(), DbErr> { + use futures::TryStreamExt; + + let apple = fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }; + + let orange = fruit::Model { + id: 2, + name: "orange".to_owned(), + cake_id: None, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![apple.clone(), orange.clone()]]) + .into_connection(); + + let txn = db.begin().await?; + + let mut stream = fruit::Entity::find().stream(&txn).await?; + + assert_eq!(stream.try_next().await?, Some(apple)); + + assert_eq!(stream.try_next().await?, Some(orange)); + + assert_eq!(stream.try_next().await?, None); + + std::mem::drop(stream); + + txn.commit().await?; + + Ok(()) + } } From 7c8e766e8b34022a6e97b5953dcb4bcfbdc4390b Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 23:01:06 +0800 Subject: [PATCH 51/65] sea-orm-codegen 0.2.6 --- sea-orm-codegen/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sea-orm-codegen/Cargo.toml b/sea-orm-codegen/Cargo.toml index 0e8fa624..9013cea4 100644 --- a/sea-orm-codegen/Cargo.toml +++ b/sea-orm-codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sea-orm-codegen" -version = "0.2.5" +version = "0.2.6" authors = ["Billy Chan "] edition = "2018" description = "Code Generator for SeaORM" From 0eee2206ba14b4a6d4db1a4ac9fec956046e81bc Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 23:01:46 +0800 Subject: [PATCH 52/65] sea-orm-cli 0.2.6 --- sea-orm-cli/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sea-orm-cli/Cargo.toml b/sea-orm-cli/Cargo.toml index 3c05d08b..07db2e4b 100644 --- a/sea-orm-cli/Cargo.toml +++ b/sea-orm-cli/Cargo.toml @@ -3,7 +3,7 @@ [package] name = "sea-orm-cli" -version = "0.2.5" +version = "0.2.6" authors = [ "Billy Chan " ] edition = "2018" description = "Command line utility for SeaORM" @@ -21,7 +21,7 @@ path = "src/main.rs" clap = { version = "^2.33.3" } dotenv = { version = "^0.15" } async-std = { version = "^1.9", features = [ "attributes" ] } -sea-orm-codegen = { version = "^0.2.5", path = "../sea-orm-codegen" } +sea-orm-codegen = { version = "^0.2.6", path = "../sea-orm-codegen" } sea-schema = { version = "^0.2.9", default-features = false, features = [ "debug-print", "sqlx-mysql", From f0aea3bf105dfe1cbeb91c2608842b8ac80720c6 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 23:15:55 +0800 Subject: [PATCH 53/65] sea-orm-macros 0.2.6 --- sea-orm-macros/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sea-orm-macros/Cargo.toml b/sea-orm-macros/Cargo.toml index 22b58c01..cde1575c 100644 --- a/sea-orm-macros/Cargo.toml +++ b/sea-orm-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sea-orm-macros" -version = "0.2.5" +version = "0.2.6" authors = [ "Billy Chan " ] edition = "2018" description = "Derive macros for SeaORM" From bfb83044f16b55fb98f2c0634b8462c3b546ec04 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Sat, 9 Oct 2021 23:17:06 +0800 Subject: [PATCH 54/65] 0.2.6 --- CHANGELOG.md | 7 +++++++ Cargo.toml | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 693c1d0e..ee544831 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## 0.2.6 - 2021-10-09 + +- [[#224]] [sea-orm-cli] Date & Time column type mapping +- Escape rust keywords with `r#` raw identifier + +[#224]: https://github.com/SeaQL/sea-orm/pull/224 + ## 0.2.5 - 2021-10-06 - [[#227]] Resolve "Inserting actual none value of Option results in panic" diff --git a/Cargo.toml b/Cargo.toml index a8d7a8ca..0c8de3ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [".", "sea-orm-macros", "sea-orm-codegen"] [package] name = "sea-orm" -version = "0.2.5" +version = "0.2.6" authors = ["Chris Tsang "] edition = "2018" description = "🐚 An async & dynamic ORM for Rust" @@ -29,7 +29,7 @@ futures = { version = "^0.3" } futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } -sea-orm-macros = { version = "^0.2.5", path = "sea-orm-macros", optional = true } +sea-orm-macros = { version = "^0.2.6", path = "sea-orm-macros", optional = true } sea-query = { version = "^0.17.0", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } From f94c33d1eaf7978b39d3ad41df32b185d77d4fde Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 11 Oct 2021 12:11:40 +0800 Subject: [PATCH 55/65] Explicit COMMIT and ROLLBACK in Mock --- src/database/mock.rs | 23 ++++++++++++++++++++--- src/database/stream/transaction.rs | 2 ++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index 8bde4725..bdab864c 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -117,7 +117,8 @@ impl MockDatabaseTrait for MockDatabase { fn commit(&mut self) { if self.transaction.is_some() { - let transaction = self.transaction.take().unwrap(); + let mut transaction = self.transaction.take().unwrap(); + transaction.push(Statement::from_string(self.db_backend, "COMMIT".to_owned())); self.transaction_log.push(transaction.into_transaction()); } else { panic!("There is no open transaction to commit"); @@ -126,7 +127,12 @@ impl MockDatabaseTrait for MockDatabase { fn rollback(&mut self) { if self.transaction.is_some() { - self.transaction = None; + let mut transaction = self.transaction.take().unwrap(); + transaction.push(Statement::from_string( + self.db_backend, + "ROLLBACK".to_owned(), + )); + self.transaction_log.push(transaction.into_transaction()); } else { panic!("There is no open transaction to rollback"); } @@ -283,6 +289,7 @@ mod tests { r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, vec![] ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), ]), Transaction::from_sql_and_values( DbBackend::Postgres, @@ -313,7 +320,17 @@ mod tests { _ => panic!(), } - assert_eq!(db.into_transaction_log(), vec![]); + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "ROLLBACK".to_owned()), + ]),] + ); } #[smol_potat::test] diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index 651f3d11..2dddc59c 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -12,6 +12,8 @@ use futures::lock::MutexGuard; use crate::{DbErr, InnerConnection, QueryResult, Statement}; #[ouroboros::self_referencing] +/// `TransactionStream` cannot be used in a `transaction` closure as it does not impl `Send`. +/// It seems to be a Rust limitation right now, and solution to work around this deemed to be extremely hard. pub struct TransactionStream<'a> { stmt: Statement, conn: MutexGuard<'a, InnerConnection>, From 1a2bd13158edaa58d2bfe8739bdd15b27676cd32 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 11 Oct 2021 18:39:46 +0800 Subject: [PATCH 56/65] Nested transaction unit tests --- src/database/mock.rs | 176 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 166 insertions(+), 10 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index bdab864c..6bb710b7 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -109,7 +109,10 @@ impl MockDatabaseTrait for MockDatabase { fn begin(&mut self) { if self.transaction.is_some() { - panic!("There is uncommitted transaction"); + self.transaction + .as_mut() + .unwrap() + .begin_nested(self.db_backend); } else { self.transaction = Some(OpenTransaction::init()); } @@ -117,9 +120,10 @@ impl MockDatabaseTrait for MockDatabase { fn commit(&mut self) { if self.transaction.is_some() { - let mut transaction = self.transaction.take().unwrap(); - transaction.push(Statement::from_string(self.db_backend, "COMMIT".to_owned())); - self.transaction_log.push(transaction.into_transaction()); + if self.transaction.as_mut().unwrap().commit(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } } else { panic!("There is no open transaction to commit"); } @@ -127,12 +131,10 @@ impl MockDatabaseTrait for MockDatabase { fn rollback(&mut self) { if self.transaction.is_some() { - let mut transaction = self.transaction.take().unwrap(); - transaction.push(Statement::from_string( - self.db_backend, - "ROLLBACK".to_owned(), - )); - self.transaction_log.push(transaction.into_transaction()); + if self.transaction.as_mut().unwrap().rollback(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } } else { panic!("There is no open transaction to rollback"); } @@ -230,11 +232,50 @@ impl OpenTransaction { } } + fn begin_nested(&mut self, db_backend: DbBackend) { + self.transaction_depth += 1; + self.push(Statement::from_string( + db_backend, + format!("SAVEPOINT savepoint_{}", self.transaction_depth), + )); + } + + fn commit(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "COMMIT".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("RELEASE SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + + fn rollback(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("ROLLBACK TO SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + fn push(&mut self, stmt: Statement) { self.stmts.push(stmt); } fn into_transaction(self) -> Transaction { + if self.transaction_depth != 0 { + panic!("There is uncommitted nested transaction."); + } Transaction { stmts: self.stmts } } } @@ -246,6 +287,7 @@ mod tests { entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, Statement, Transaction, TransactionError, }; + use pretty_assertions::assert_eq; #[derive(Debug, PartialEq)] pub struct MyErr(String); @@ -333,6 +375,120 @@ mod tests { ); } + #[smol_potat::test] + async fn test_nested_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + + #[smol_potat::test] + async fn test_nested_transaction_2() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_2".to_owned() + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + #[smol_potat::test] async fn test_stream_1() -> Result<(), DbErr> { use futures::TryStreamExt; From ecaa0dca188f4f50690d949ffceeef0c0d3fbdbb Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Mon, 11 Oct 2021 18:50:05 +0800 Subject: [PATCH 57/65] Edit --- src/database/mock.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/database/mock.rs b/src/database/mock.rs index 6bb710b7..f9703d9e 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -542,15 +542,15 @@ mod tests { let txn = db.begin().await?; - let mut stream = fruit::Entity::find().stream(&txn).await?; + if let Ok(mut stream) = fruit::Entity::find().stream(&txn).await { + assert_eq!(stream.try_next().await?, Some(apple)); - assert_eq!(stream.try_next().await?, Some(apple)); + assert_eq!(stream.try_next().await?, Some(orange)); - assert_eq!(stream.try_next().await?, Some(orange)); + assert_eq!(stream.try_next().await?, None); - assert_eq!(stream.try_next().await?, None); - - std::mem::drop(stream); + // stream will be dropped end of scope OR std::mem::drop(stream); + } txn.commit().await?; From 12800468b1b9c72be8a28cf5a993c74a9cd1a8ed Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Mon, 11 Oct 2021 23:00:24 +0800 Subject: [PATCH 58/65] Test nested transaction --- tests/transaction_tests.rs | 245 +++++++++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index 61d194c4..57739dde 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -101,3 +101,248 @@ fn _transaction_with_reference<'a>( Ok(()) }) } + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn transaction_nested() { + let ctx = TestContext::new("transaction_nested_test").await; + + ctx.db + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; + + let _ = bakery::ActiveModel { + name: Set("Top Bakery".to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; + + // Try nested transaction committed + txn.transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Nested Bakery".to_owned()), + profit_margin: Set(88.88), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 3); + + // Try nested-nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + if true { + Err(DbErr::Query("Force Rollback!".to_owned())) + } else { + Ok(()) + } + }) + }) + .await + .is_err(); + + assert!(is_err); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 3); + + // Try nested-nested transaction committed + txn.transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) + }) + }) + .await + .unwrap(); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Try nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 5); + + // Try nested-nested transaction committed + txn.transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 6); + + Ok(()) + }) + }) + .await + .unwrap(); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 6); + + // Try nested-nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(|txn| { + Box::pin(async move { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 7); + + if true { + Err(DbErr::Query("Force Rollback!".to_owned())) + } else { + Ok(()) + } + }) + }) + .await + .is_err(); + + assert!(is_err); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 6); + + if true { + Err(DbErr::Query("Force Rollback!".to_owned())) + } else { + Ok(()) + } + }) + }) + .await + .is_err(); + + assert!(is_err); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) + }) + }) + .await + .unwrap(); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(&ctx.db) + .await + .unwrap(); + + assert_eq!(bakeries.len(), 4); + + ctx.delete().await; +} From 4a1b8fabc5b9dfb494c53d650c9055732076c0a9 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 03:05:00 +0800 Subject: [PATCH 59/65] Rework Rocket Db Pool --- examples/rocket_example/Cargo.toml | 16 +- examples/rocket_example/Rocket.toml | 2 +- examples/rocket_example/src/main.rs | 50 ++--- examples/rocket_example/src/pool.rs | 16 +- sea-orm-rocket/Cargo.toml | 2 + sea-orm-rocket/README.md | 1 + sea-orm-rocket/codegen/Cargo.toml | 22 +++ sea-orm-rocket/codegen/src/database.rs | 110 +++++++++++ sea-orm-rocket/codegen/src/lib.rs | 52 ++++++ sea-orm-rocket/lib/Cargo.toml | 27 +++ sea-orm-rocket/lib/src/config.rs | 83 +++++++++ sea-orm-rocket/lib/src/database.rs | 249 +++++++++++++++++++++++++ sea-orm-rocket/lib/src/error.rs | 35 ++++ sea-orm-rocket/lib/src/lib.rs | 20 ++ sea-orm-rocket/lib/src/pool.rs | 63 +++++++ 15 files changed, 715 insertions(+), 33 deletions(-) create mode 100644 sea-orm-rocket/Cargo.toml create mode 100644 sea-orm-rocket/README.md create mode 100644 sea-orm-rocket/codegen/Cargo.toml create mode 100644 sea-orm-rocket/codegen/src/database.rs create mode 100644 sea-orm-rocket/codegen/src/lib.rs create mode 100644 sea-orm-rocket/lib/Cargo.toml create mode 100644 sea-orm-rocket/lib/src/config.rs create mode 100644 sea-orm-rocket/lib/src/database.rs create mode 100644 sea-orm-rocket/lib/src/error.rs create mode 100644 sea-orm-rocket/lib/src/lib.rs create mode 100644 sea-orm-rocket/lib/src/pool.rs diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index 22dc9a6b..c896fe2d 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -15,15 +15,21 @@ futures-util = { version = "^0.3" } rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "json", ] } -rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git" } rocket_dyn_templates = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "tera", ] } -# remove `path = ""` in your own project -sea-orm = { path = "../../", version = "^0.2.3", features = ["macros"], default-features = false } serde_json = { version = "^1" } +[dependencies.sea-orm] +path = "../../" # remove this line in your own project +version = "^0.2.3" +features = ["macros", "runtime-tokio-native-tls"] +default-features = false + +[dependencies.sea-orm-rocket] +path = "../../sea-orm-rocket/lib" + [features] default = ["sqlx-postgres"] -sqlx-mysql = ["sea-orm/sqlx-mysql", "rocket_db_pools/sqlx_mysql"] -sqlx-postgres = ["sea-orm/sqlx-postgres", "rocket_db_pools/sqlx_postgres"] +sqlx-mysql = ["sea-orm/sqlx-mysql"] +sqlx-postgres = ["sea-orm/sqlx-postgres"] diff --git a/examples/rocket_example/Rocket.toml b/examples/rocket_example/Rocket.toml index b7fcc12a..fc294bd2 100644 --- a/examples/rocket_example/Rocket.toml +++ b/examples/rocket_example/Rocket.toml @@ -1,7 +1,7 @@ [default] template_dir = "templates/" -[default.databases.rocket_example] +[default.databases.sea_orm] # Mysql # make sure to enable "sqlx-mysql" feature in Cargo.toml, i.e default = ["sqlx-mysql"] # url = "mysql://root:@localhost/rocket_example" diff --git a/examples/rocket_example/src/main.rs b/examples/rocket_example/src/main.rs index 43f227c6..e2ef9254 100644 --- a/examples/rocket_example/src/main.rs +++ b/examples/rocket_example/src/main.rs @@ -7,21 +7,17 @@ use rocket::fs::{relative, FileServer}; use rocket::request::FlashMessage; use rocket::response::{Flash, Redirect}; use rocket::{Build, Request, Rocket}; -use rocket_db_pools::{sqlx, Connection, Database}; use rocket_dyn_templates::{context, Template}; use sea_orm::{entity::*, query::*}; +use sea_orm_rocket::{Connection, Database}; mod pool; -use pool::RocketDbPool; +use pool::Db; mod setup; -#[derive(Database, Debug)] -#[database("rocket_example")] -struct Db(RocketDbPool); - -type Result> = std::result::Result; +type Result> = std::result::Result; mod post; pub use post::Entity as Post; @@ -34,7 +30,9 @@ async fn new() -> Template { } #[post("/", data = "")] -async fn create(conn: Connection, post_form: Form) -> Flash { +async fn create(conn: Connection<'_, Db>, post_form: Form) -> Flash { + let db = conn.into_inner(); + let form = post_form.into_inner(); post::ActiveModel { @@ -42,7 +40,7 @@ async fn create(conn: Connection, post_form: Form) -> Flash, post_form: Form) -> Flash", data = "")] -async fn update(conn: Connection, id: i32, post_form: Form) -> Flash { +async fn update(conn: Connection<'_, Db>, id: i32, post_form: Form) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .unwrap() .unwrap() @@ -65,7 +65,7 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> title: Set(form.title.to_owned()), text: Set(form.text.to_owned()), } - .save(&*conn) + .save(db) .await .expect("could not edit post"); @@ -74,11 +74,13 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> #[get("/?&")] async fn list( - conn: Connection, + conn: Connection<'_, Db>, posts_per_page: Option, page: Option, flash: Option>, ) -> Template { + let db = conn.into_inner(); + // Set page number and items per page let page = page.unwrap_or(1); let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE); @@ -89,7 +91,7 @@ async fn list( // Setup paginator let paginator = Post::find() .order_by_asc(post::Column::Id) - .paginate(&*conn, posts_per_page); + .paginate(db, posts_per_page); let num_pages = paginator.num_pages().await.ok().unwrap(); // Fetch paginated posts @@ -111,9 +113,11 @@ async fn list( } #[get("/")] -async fn edit(conn: Connection, id: i32) -> Template { +async fn edit(conn: Connection<'_, Db>, id: i32) -> Template { + let db = conn.into_inner(); + let post: Option = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .expect("could not find post"); @@ -126,22 +130,26 @@ async fn edit(conn: Connection, id: i32) -> Template { } #[delete("/")] -async fn delete(conn: Connection, id: i32) -> Flash { +async fn delete(conn: Connection<'_, Db>, id: i32) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .unwrap() .unwrap() .into(); - post.delete(&*conn).await.unwrap(); + post.delete(db).await.unwrap(); Flash::success(Redirect::to("/"), "Post successfully deleted.") } #[delete("/")] -async fn destroy(conn: Connection) -> Result<()> { - Post::delete_many().exec(&*conn).await.unwrap(); +async fn destroy(conn: Connection<'_, Db>) -> Result<()> { + let db = conn.into_inner(); + + Post::delete_many().exec(db).await.unwrap(); Ok(()) } diff --git a/examples/rocket_example/src/pool.rs b/examples/rocket_example/src/pool.rs index c4140c1f..931a4712 100644 --- a/examples/rocket_example/src/pool.rs +++ b/examples/rocket_example/src/pool.rs @@ -1,13 +1,17 @@ use async_trait::async_trait; -use rocket_db_pools::{rocket::figment::Figment, Config}; +use sea_orm_rocket::{rocket::figment::Figment, Config, Database}; + +#[derive(Database, Debug)] +#[database("sea_orm")] +pub struct Db(SeaOrmPool); #[derive(Debug)] -pub struct RocketDbPool { +pub struct SeaOrmPool { pub conn: sea_orm::DatabaseConnection, } #[async_trait] -impl rocket_db_pools::Pool for RocketDbPool { +impl sea_orm_rocket::Pool for SeaOrmPool { type Error = sea_orm::DbErr; type Connection = sea_orm::DatabaseConnection; @@ -16,10 +20,10 @@ impl rocket_db_pools::Pool for RocketDbPool { let config = figment.extract::().unwrap(); let conn = sea_orm::Database::connect(&config.url).await.unwrap(); - Ok(RocketDbPool { conn }) + Ok(SeaOrmPool { conn }) } - async fn get(&self) -> Result { - Ok(self.conn.clone()) + fn borrow(&self) -> &Self::Connection { + &self.conn } } diff --git a/sea-orm-rocket/Cargo.toml b/sea-orm-rocket/Cargo.toml new file mode 100644 index 00000000..4975f8e1 --- /dev/null +++ b/sea-orm-rocket/Cargo.toml @@ -0,0 +1,2 @@ +[workspace] +members = ["codegen", "lib"] \ No newline at end of file diff --git a/sea-orm-rocket/README.md b/sea-orm-rocket/README.md new file mode 100644 index 00000000..f8d94b21 --- /dev/null +++ b/sea-orm-rocket/README.md @@ -0,0 +1 @@ +# SeaORM Rocket support crate. \ No newline at end of file diff --git a/sea-orm-rocket/codegen/Cargo.toml b/sea-orm-rocket/codegen/Cargo.toml new file mode 100644 index 00000000..75656487 --- /dev/null +++ b/sea-orm-rocket/codegen/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "sea-orm-rocket-codegen" +version = "0.1.0-rc" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "Procedural macros for sea_orm_rocket." +repository = "https://github.com/SergioBenitez/Rocket/contrib/db_pools" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +devise = "0.3" +quote = "1" + +[dev-dependencies] +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", default-features = false } +trybuild = "1.0" +version_check = "0.9" diff --git a/sea-orm-rocket/codegen/src/database.rs b/sea-orm-rocket/codegen/src/database.rs new file mode 100644 index 00000000..a6f5a981 --- /dev/null +++ b/sea-orm-rocket/codegen/src/database.rs @@ -0,0 +1,110 @@ +use proc_macro::TokenStream; + +use devise::{DeriveGenerator, FromMeta, MapperBuild, Support, ValidatorBuild}; +use devise::proc_macro2_diagnostics::SpanDiagnosticExt; +use devise::syn::{self, spanned::Spanned}; + +const ONE_DATABASE_ATTR: &str = "missing `#[database(\"name\")]` attribute"; +const ONE_UNNAMED_FIELD: &str = "struct must have exactly one unnamed field"; + +#[derive(Debug, FromMeta)] +struct DatabaseAttribute { + #[meta(naked)] + name: String, +} + +pub fn derive_database(input: TokenStream) -> TokenStream { + DeriveGenerator::build_for(input, quote!(impl sea_orm_rocket::Database)) + .support(Support::TupleStruct) + .validator(ValidatorBuild::new() + .struct_validate(|_, s| { + if s.fields.len() == 1 { + Ok(()) + } else { + Err(s.span().error(ONE_UNNAMED_FIELD)) + } + }) + ) + .outer_mapper(MapperBuild::new() + .struct_map(|_, s| { + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + let decorated_type = &s.ident; + let db_ty = quote_spanned!(decorated_type.span() => + <#decorated_type as sea_orm_rocket::Database> + ); + + quote_spanned! { decorated_type.span() => + impl From<#pool_type> for #decorated_type { + fn from(pool: #pool_type) -> Self { + Self(pool) + } + } + + impl std::ops::Deref for #decorated_type { + type Target = #pool_type; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl std::ops::DerefMut for #decorated_type { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #[rocket::async_trait] + impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type { + type Error = (); + + async fn from_request( + req: &'r rocket::request::Request<'_> + ) -> rocket::request::Outcome { + match #db_ty::fetch(req.rocket()) { + Some(db) => rocket::outcome::Outcome::Success(db), + None => rocket::outcome::Outcome::Failure(( + rocket::http::Status::InternalServerError, ())) + } + } + } + + impl rocket::Sentinel for &#decorated_type { + fn abort(rocket: &rocket::Rocket) -> bool { + #db_ty::fetch(rocket).is_none() + } + } + } + }) + ) + .outer_mapper(quote!(#[rocket::async_trait])) + .inner_mapper(MapperBuild::new() + .try_struct_map(|_, s| { + let db_name = DatabaseAttribute::one_from_attrs("database", &s.attrs)? + .map(|attr| attr.name) + .ok_or_else(|| s.span().error(ONE_DATABASE_ATTR))?; + + let fairing_name = format!("'{}' Database Pool", db_name); + + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + Ok(quote_spanned! { pool_type.span() => + type Pool = #pool_type; + + const NAME: &'static str = #db_name; + + fn init() -> sea_orm_rocket::Initializer { + sea_orm_rocket::Initializer::with_name(#fairing_name) + } + }) + }) + ) + .to_tokens() +} diff --git a/sea-orm-rocket/codegen/src/lib.rs b/sea-orm-rocket/codegen/src/lib.rs new file mode 100644 index 00000000..7cb32f94 --- /dev/null +++ b/sea-orm-rocket/codegen/src/lib.rs @@ -0,0 +1,52 @@ +#![recursion_limit="256"] +#![warn(rust_2018_idioms)] + +//! # `sea_orm_rocket` - Code Generation +//! +//! Implements the code generation portion of the `sea_orm_rocket` crate. This +//! is an implementation detail. This create should never be depended on +//! directly. + +#[macro_use] extern crate quote; + +mod database; + +/// Automatic derive for the [`Database`] trait. +/// +/// The derive generates an implementation of [`Database`] as follows: +/// +/// * [`Database::NAME`] is set to the value in the `#[database("name")]` +/// attribute. +/// +/// This names the database, providing an anchor to configure the database via +/// `Rocket.toml` or any other configuration source. Specifically, the +/// configuration in `databases.name` is used to configure the driver. +/// +/// * [`Database::Pool`] is set to the wrapped type: `PoolType` above. The type +/// must implement [`Pool`]. +/// +/// To meet the required [`Database`] supertrait bounds, this derive also +/// generates implementations for: +/// +/// * `From` +/// +/// * `Deref` +/// +/// * `DerefMut` +/// +/// * `FromRequest<'_> for &Db` +/// +/// * `Sentinel for &Db` +/// +/// The `Deref` impls enable accessing the database pool directly from +/// references `&Db` or `&mut Db`. To force a dereference to the underlying +/// type, use `&db.0` or `&**db` or their `&mut` variants. +/// +/// [`Database`]: ../sea_orm_rocket/trait.Database.html +/// [`Database::NAME`]: ../sea_orm_rocket/trait.Database.html#associatedconstant.NAME +/// [`Database::Pool`]: ../sea_orm_rocket/trait.Database.html#associatedtype.Pool +/// [`Pool`]: ../sea_orm_rocket/trait.Pool.html +#[proc_macro_derive(Database, attributes(database))] +pub fn derive_database(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + crate::database::derive_database(input) +} diff --git a/sea-orm-rocket/lib/Cargo.toml b/sea-orm-rocket/lib/Cargo.toml new file mode 100644 index 00000000..3a586fe1 --- /dev/null +++ b/sea-orm-rocket/lib/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "sea-orm-rocket" +version = "0.1.0" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "SeaORM Rocket support crate" +repository = "https://github.com/SeaQL/sea-orm" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[package.metadata.docs.rs] +all-features = true + +[dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +version = "0.5.0-rc.1" +default-features = false + +[dependencies.sea-orm-rocket-codegen] +path = "../codegen" +version = "0.1.0-rc" + +[dev-dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +default-features = false +features = ["json"] diff --git a/sea-orm-rocket/lib/src/config.rs b/sea-orm-rocket/lib/src/config.rs new file mode 100644 index 00000000..b30c2ce5 --- /dev/null +++ b/sea-orm-rocket/lib/src/config.rs @@ -0,0 +1,83 @@ +use rocket::serde::{Deserialize, Serialize}; + +/// Base configuration for all database drivers. +/// +/// A dictionary matching this structure is extracted from the active +/// [`Figment`](crate::figment::Figment), scoped to `databases.name`, where +/// `name` is the name of the database, by the +/// [`Initializer`](crate::Initializer) fairing on ignition and used to +/// configure the relevant database and database pool. +/// +/// With the default provider, these parameters are typically configured in a +/// `Rocket.toml` file: +/// +/// ```toml +/// [default.databases.db_name] +/// url = "/path/to/db.sqlite" +/// +/// # only `url` is required. `Initializer` provides defaults for the rest. +/// min_connections = 64 +/// max_connections = 1024 +/// connect_timeout = 5 +/// idle_timeout = 120 +/// ``` +/// +/// Alternatively, a custom provider can be used. For example, a custom `Figment` +/// with a global `databases.name` configuration: +/// +/// ```rust +/// # use rocket::launch; +/// #[launch] +/// fn rocket() -> _ { +/// let figment = rocket::Config::figment() +/// .merge(("databases.name", sea_orm_rocket::Config { +/// url: "db:specific@config&url".into(), +/// min_connections: None, +/// max_connections: 1024, +/// connect_timeout: 3, +/// idle_timeout: None, +/// })); +/// +/// rocket::custom(figment) +/// } +/// ``` +/// +/// For general information on configuration in Rocket, see [`rocket::config`]. +/// For higher-level details on configuring a database, see the [crate-level +/// docs](crate#configuration). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(crate = "rocket::serde")] +pub struct Config { + /// Database-specific connection and configuration URL. + /// + /// The format of the URL is database specific; consult your database's + /// documentation. + pub url: String, + /// Minimum number of connections to maintain in the pool. + /// + /// **Note:** `deadpool` drivers do not support and thus ignore this value. + /// + /// _Default:_ `None`. + pub min_connections: Option, + /// Maximum number of connections to maintain in the pool. + /// + /// _Default:_ `workers * 4`. + pub max_connections: usize, + /// Number of seconds to wait for a connection before timing out. + /// + /// If the timeout elapses before a connection can be made or retrieved from + /// a pool, an error is returned. + /// + /// _Default:_ `5`. + pub connect_timeout: u64, + /// Maximum number of seconds to keep a connection alive for. + /// + /// After a connection is established, it is maintained in a pool for + /// efficient connection retrieval. When an `idle_timeout` is set, that + /// connection will be closed after the timeout elapses. If an + /// `idle_timeout` is not specified, the behavior is driver specific but + /// typically defaults to keeping a connection active indefinitely. + /// + /// _Default:_ `None`. + pub idle_timeout: Option, +} diff --git a/sea-orm-rocket/lib/src/database.rs b/sea-orm-rocket/lib/src/database.rs new file mode 100644 index 00000000..6eb98385 --- /dev/null +++ b/sea-orm-rocket/lib/src/database.rs @@ -0,0 +1,249 @@ +use std::marker::PhantomData; +use std::ops::{DerefMut}; + +use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel}; +use rocket::fairing::{self, Fairing, Info, Kind}; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::http::Status; + +use rocket::yansi::Paint; +use rocket::figment::providers::Serialized; + +use crate::Pool; + +/// Derivable trait which ties a database [`Pool`] with a configuration name. +/// +/// This trait should rarely, if ever, be implemented manually. Instead, it +/// should be derived: +/// +/// ```ignore +/// use sea_orm_rocket::{Database}; +/// # use sea_orm_rocket::MockPool as SeaOrmPool; +/// +/// #[derive(Database, Debug)] +/// #[database("sea_orm")] +/// struct Db(SeaOrmPool); +/// +/// #[launch] +/// fn rocket() -> _ { +/// rocket::build().attach(Db::init()) +/// } +/// ``` +/// +/// See the [`Database` derive](derive@crate::Database) for details. +pub trait Database: From + DerefMut + Send + Sync + 'static { + /// The [`Pool`] type of connections to this database. + /// + /// When `Database` is derived, this takes the value of the `Inner` type in + /// `struct Db(Inner)`. + type Pool: Pool; + + /// The configuration name for this database. + /// + /// When `Database` is derived, this takes the value `"name"` in the + /// `#[database("name")]` attribute. + const NAME: &'static str; + + /// Returns a fairing that initializes the database and its connection pool. + /// + /// # Example + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build().attach(Db::init()) + /// } + /// # } + /// ``` + fn init() -> Initializer { + Initializer::new() + } + + /// Returns a reference to the initialized database in `rocket`. The + /// initializer fairing returned by `init()` must have already executed for + /// `Option` to be `Some`. This is guaranteed to be the case if the fairing + /// is attached and either: + /// + /// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the + /// application is running. This is always the case in request guards + /// and liftoff fairings, + /// * _or_ Rocket is in the [`Build`](rocket::Build) or + /// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has + /// already been run. This is the case in all fairing callbacks + /// corresponding to fairings attached _after_ the `Initializer` + /// fairing. + /// + /// # Example + /// + /// Run database migrations in an ignite fairing. It is imperative that the + /// migration fairing be registered _after_ the `init()` fairing. + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use rocket::{Rocket, Build}; + /// use rocket::fairing::{self, AdHoc}; + /// + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// async fn run_migrations(rocket: Rocket) -> fairing::Result { + /// if let Some(db) = Db::fetch(&rocket) { + /// // run migrations using `db`. get the inner type with &db.0. + /// Ok(rocket) + /// } else { + /// Err(rocket) + /// } + /// } + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build() + /// .attach(Db::init()) + /// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations)) + /// } + /// # } + /// ``` + fn fetch(rocket: &Rocket

) -> Option<&Self> { + if let Some(db) = rocket.state() { + return Some(db); + } + + let dbtype = std::any::type_name::(); + let fairing = Paint::default(format!("{}::init()", dbtype)).bold(); + error!("Attempted to fetch unattached database `{}`.", Paint::default(dbtype).bold()); + info_!("`{}` fairing must be attached prior to using this database.", fairing); + None + } +} + +/// A [`Fairing`] which initializes a [`Database`] and its connection pool. +/// +/// A value of this type can be created for any type `D` that implements +/// [`Database`] via the [`Database::init()`] method on the type. Normally, a +/// value of this type _never_ needs to be constructed directly. This +/// documentation exists purely as a reference. +/// +/// This fairing initializes a database pool. Specifically, it: +/// +/// 1. Reads the configuration at `database.db_name`, where `db_name` is +/// [`Database::NAME`]. +/// +/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment. +/// +/// 3. Calls [`Pool::init()`]. +/// +/// 4. Stores the database instance in managed storage, retrievable via +/// [`Database::fetch()`]. +/// +/// The name of the fairing itself is `Initializer`, with `D` replaced with +/// the type name `D` unless a name is explicitly provided via +/// [`Self::with_name()`]. +pub struct Initializer(Option<&'static str>, PhantomData D>); + +/// A request guard which retrieves a single connection to a [`Database`]. +/// +/// For a database type of `Db`, a request guard of `Connection` retrieves a +/// single connection to `Db`. +/// +/// The request guard succeeds if the database was initialized by the +/// [`Initializer`] fairing and a connection is available within +/// [`connect_timeout`](crate::Config::connect_timeout) seconds. +/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with +/// status `InternalServerError`. A [`Sentinel`] guards this condition, and so +/// this type of failure is unlikely to occur. A `None` error is returned. +/// * If a connection is not available within `connect_timeout` seconds or +/// another error occurs, the gaurd _fails_ with status `ServiceUnavailable` +/// and the error is returned in `Some`. +/// +pub struct Connection<'a, D: Database>(&'a ::Connection); + +impl Initializer { + /// Returns a database initializer fairing for `D`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn new() -> Self { + Self(None, std::marker::PhantomData) + } + + /// Returns a database initializer fairing for `D` with name `name`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn with_name(name: &'static str) -> Self { + Self(Some(name), std::marker::PhantomData) + } +} + +impl<'a, D: Database> Connection<'a, D> { + /// Returns the internal connection value. See the [`Connection` Deref + /// column](crate#supported-drivers) for the expected type of this value. + /// + /// Note that `Connection` derefs to the internal connection type, so + /// using this method is likely unnecessary. See [deref](Connection#deref) + /// for examples. + pub fn into_inner(self) -> &'a ::Connection { + self.0 + } +} + +#[rocket::async_trait] +impl Fairing for Initializer { + fn info(&self) -> Info { + Info { + name: self.0.unwrap_or_else(std::any::type_name::), + kind: Kind::Ignite, + } + } + + async fn on_ignite(&self, rocket: Rocket) -> fairing::Result { + let workers: usize = rocket.figment() + .extract_inner(rocket::Config::WORKERS) + .unwrap_or_else(|_| rocket::Config::default().workers); + + let figment = rocket.figment() + .focus(&format!("databases.{}", D::NAME)) + .merge(Serialized::default("max_connections", workers * 4)) + .merge(Serialized::default("connect_timeout", 5)); + + match ::init(&figment).await { + Ok(pool) => Ok(rocket.manage(D::from(pool))), + Err(e) => { + error!("failed to initialize database: {}", e); + Err(rocket) + } + } + } +} + +#[rocket::async_trait] +impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> { + type Error = Option<::Error>; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + match D::fetch(req.rocket()) { + Some(pool) => Outcome::Success(Connection(pool.borrow())), + None => Outcome::Failure((Status::InternalServerError, None)), + } + } +} + +impl Sentinel for Connection<'_, D> { + fn abort(rocket: &Rocket) -> bool { + D::fetch(rocket).is_none() + } +} diff --git a/sea-orm-rocket/lib/src/error.rs b/sea-orm-rocket/lib/src/error.rs new file mode 100644 index 00000000..69bae106 --- /dev/null +++ b/sea-orm-rocket/lib/src/error.rs @@ -0,0 +1,35 @@ +use std::fmt; + +/// A general error type for use by [`Pool`](crate::Pool#implementing) +/// implementors and returned by the [`Connection`](crate::Connection) request +/// guard. +#[derive(Debug)] +pub enum Error { + /// An error that occured during database/pool initialization. + Init(A), + + /// An error that ocurred while retrieving a connection from the pool. + Get(B), + + /// A [`Figment`](crate::figment::Figment) configuration error. + Config(crate::figment::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Init(e) => write!(f, "failed to initialize database: {}", e), + Error::Get(e) => write!(f, "failed to get db connection: {}", e), + Error::Config(e) => write!(f, "bad configuration: {}", e), + } + } +} + +impl std::error::Error for Error + where A: fmt::Debug + fmt::Display, B: fmt::Debug + fmt::Display {} + +impl From for Error { + fn from(e: crate::figment::Error) -> Self { + Self::Config(e) + } +} diff --git a/sea-orm-rocket/lib/src/lib.rs b/sea-orm-rocket/lib/src/lib.rs new file mode 100644 index 00000000..4b98cd0b --- /dev/null +++ b/sea-orm-rocket/lib/src/lib.rs @@ -0,0 +1,20 @@ +//! SeaORM Rocket support crate. +#![deny(missing_docs)] + +/// Re-export of the `figment` crate. +#[doc(inline)] +pub use rocket::figment; + +pub use rocket; + +mod database; +mod error; +mod pool; +mod config; + +pub use self::database::{Connection, Database, Initializer}; +pub use self::error::Error; +pub use self::pool::{Pool, MockPool}; +pub use self::config::Config; + +pub use sea_orm_rocket_codegen::*; diff --git a/sea-orm-rocket/lib/src/pool.rs b/sea-orm-rocket/lib/src/pool.rs new file mode 100644 index 00000000..9dd8b75a --- /dev/null +++ b/sea-orm-rocket/lib/src/pool.rs @@ -0,0 +1,63 @@ +use rocket::figment::Figment; + +/// Generic [`Database`](crate::Database) driver connection pool trait. +/// +/// This trait provides a generic interface to various database pooling +/// implementations in the Rust ecosystem. It can be implemented by anyone, but +/// this crate provides implementations for common drivers. +/// ``` +#[rocket::async_trait] +pub trait Pool: Sized + Send + Sync + 'static { + /// The connection type managed by this pool, returned by [`Self::get()`]. + type Connection; + + /// The error type returned by [`Self::init()`] and [`Self::get()`]. + type Error: std::error::Error; + + /// Constructs a pool from a [Value](rocket::figment::value::Value). + /// + /// It is up to each implementor of `Pool` to define its accepted + /// configuration value(s) via the `Config` associated type. Most + /// integrations provided in `sea_orm_rocket` use [`Config`], which + /// accepts a (required) `url` and an (optional) `pool_size`. + /// + /// ## Errors + /// + /// This method returns an error if the configuration is not compatible, or + /// if creating a pool failed due to an unavailable database server, + /// insufficient resources, or another database-specific error. + async fn init(figment: &Figment) -> Result; + + /// Borrow the inner connection + fn borrow(&self) -> &Self::Connection; +} + +#[derive(Debug)] +/// A mock object which impl `Pool`, for testing only +pub struct MockPool; + +#[derive(Debug)] +pub struct MockPoolErr; + +impl std::fmt::Display for MockPoolErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for MockPoolErr {} + +#[rocket::async_trait] +impl Pool for MockPool { + type Error = MockPoolErr; + + type Connection = bool; + + async fn init(_figment: &Figment) -> Result { + Ok(MockPool) + } + + fn borrow(&self) -> &Self::Connection { + &true + } +} From 8feca6be7b64f2bfbc0ae7af48e6d9b214073e07 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 03:23:21 +0800 Subject: [PATCH 60/65] Edit --- examples/rocket_example/Cargo.toml | 3 ++- examples/rocket_example/src/pool.rs | 2 +- sea-orm-rocket/lib/src/database.rs | 4 ---- sea-orm-rocket/lib/src/pool.rs | 14 +++++++++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index c896fe2d..a4783cb0 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -27,7 +27,8 @@ features = ["macros", "runtime-tokio-native-tls"] default-features = false [dependencies.sea-orm-rocket] -path = "../../sea-orm-rocket/lib" +path = "../../sea-orm-rocket/lib" # remove this line in your own project +git = "https://github.com/SeaQL/sea-orm" [features] default = ["sqlx-postgres"] diff --git a/examples/rocket_example/src/pool.rs b/examples/rocket_example/src/pool.rs index 931a4712..afc8d48d 100644 --- a/examples/rocket_example/src/pool.rs +++ b/examples/rocket_example/src/pool.rs @@ -5,7 +5,7 @@ use sea_orm_rocket::{rocket::figment::Figment, Config, Database}; #[database("sea_orm")] pub struct Db(SeaOrmPool); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SeaOrmPool { pub conn: sea_orm::DatabaseConnection, } diff --git a/sea-orm-rocket/lib/src/database.rs b/sea-orm-rocket/lib/src/database.rs index 6eb98385..b3095ac9 100644 --- a/sea-orm-rocket/lib/src/database.rs +++ b/sea-orm-rocket/lib/src/database.rs @@ -192,10 +192,6 @@ impl Initializer { impl<'a, D: Database> Connection<'a, D> { /// Returns the internal connection value. See the [`Connection` Deref /// column](crate#supported-drivers) for the expected type of this value. - /// - /// Note that `Connection` derefs to the internal connection type, so - /// using this method is likely unnecessary. See [deref](Connection#deref) - /// for examples. pub fn into_inner(self) -> &'a ::Connection { self.0 } diff --git a/sea-orm-rocket/lib/src/pool.rs b/sea-orm-rocket/lib/src/pool.rs index 9dd8b75a..cbe2bb85 100644 --- a/sea-orm-rocket/lib/src/pool.rs +++ b/sea-orm-rocket/lib/src/pool.rs @@ -5,13 +5,21 @@ use rocket::figment::Figment; /// This trait provides a generic interface to various database pooling /// implementations in the Rust ecosystem. It can be implemented by anyone, but /// this crate provides implementations for common drivers. +/// +/// This is adapted from the original `rocket_db_pools`. But on top we require +/// `Connection` itself to be `Sync`. Hence, instead of cloning or allocating +/// a new connection per request, here we only borrow a reference to the pool. +/// +/// In SeaORM, only *when* you are about to execute a SQL statement will a +/// connection be acquired from the pool, and returned as soon as the query finishes. +/// This helps a bit with concurrency if the lifecycle of a request is long enough. /// ``` #[rocket::async_trait] pub trait Pool: Sized + Send + Sync + 'static { - /// The connection type managed by this pool, returned by [`Self::get()`]. + /// The connection type managed by this pool. type Connection; - /// The error type returned by [`Self::init()`] and [`Self::get()`]. + /// The error type returned by [`Self::init()`]. type Error: std::error::Error; /// Constructs a pool from a [Value](rocket::figment::value::Value). @@ -28,7 +36,7 @@ pub trait Pool: Sized + Send + Sync + 'static { /// insufficient resources, or another database-specific error. async fn init(figment: &Figment) -> Result; - /// Borrow the inner connection + /// Borrow a database connection fn borrow(&self) -> &Self::Connection; } From e57975930e09b382c891cc1df7f2f654defb4f9d Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 12:26:35 +0800 Subject: [PATCH 61/65] Unit Test Rocket --- .github/workflows/rust.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5a5e6e91..0b8a6ef0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -144,6 +144,12 @@ jobs: args: > --all + - uses: actions-rs/cargo@v1 + with: + command: test + args: > + --path sea-orm-rocket + cli: name: CLI runs-on: ${{ matrix.os }} From f9dd9d242a627a4935b0fa7f15ac7d9ee0c0d82f Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 12:32:25 +0800 Subject: [PATCH 62/65] Edit --- .github/workflows/rust.yml | 2 +- examples/rocket_example/Cargo.toml | 4 ++-- sea-orm-rocket/lib/src/pool.rs | 5 ++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 0b8a6ef0..82d1f6a9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -148,7 +148,7 @@ jobs: with: command: test args: > - --path sea-orm-rocket + --manifest-path sea-orm-rocket/Cargo.toml cli: name: CLI diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index a4783cb0..c0834609 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -27,8 +27,8 @@ features = ["macros", "runtime-tokio-native-tls"] default-features = false [dependencies.sea-orm-rocket] -path = "../../sea-orm-rocket/lib" # remove this line in your own project -git = "https://github.com/SeaQL/sea-orm" +path = "../../sea-orm-rocket/lib" # remove this line in your own project and use the git line +# git = "https://github.com/SeaQL/sea-orm" [features] default = ["sqlx-postgres"] diff --git a/sea-orm-rocket/lib/src/pool.rs b/sea-orm-rocket/lib/src/pool.rs index cbe2bb85..bdd8e638 100644 --- a/sea-orm-rocket/lib/src/pool.rs +++ b/sea-orm-rocket/lib/src/pool.rs @@ -3,8 +3,7 @@ use rocket::figment::Figment; /// Generic [`Database`](crate::Database) driver connection pool trait. /// /// This trait provides a generic interface to various database pooling -/// implementations in the Rust ecosystem. It can be implemented by anyone, but -/// this crate provides implementations for common drivers. +/// implementations in the Rust ecosystem. It can be implemented by anyone. /// /// This is adapted from the original `rocket_db_pools`. But on top we require /// `Connection` itself to be `Sync`. Hence, instead of cloning or allocating @@ -36,7 +35,7 @@ pub trait Pool: Sized + Send + Sync + 'static { /// insufficient resources, or another database-specific error. async fn init(figment: &Figment) -> Result; - /// Borrow a database connection + /// Borrows a reference to the pool fn borrow(&self) -> &Self::Connection; } From 92795a022adf3c94e42275f1387ce3dc3076ef22 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 13:58:26 +0800 Subject: [PATCH 63/65] Bump SeaQuery --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8b813343..588658e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.2.5", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.17.0", git = "https://github.com/SeaQL/sea-query.git", branch = "from-value-tuple", features = ["thread-safe"] } +sea-query = { version = "^0.17.1", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } From 3a6e55ced1d13c06073067d96b449b5b17795d7e Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 14:52:11 +0800 Subject: [PATCH 64/65] cargo fmt --- src/executor/insert.rs | 10 ++-------- src/executor/update.rs | 5 +---- src/query/util.rs | 12 ++++++------ 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/executor/insert.rs b/src/executor/insert.rs index d9230d11..e7437854 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -28,10 +28,7 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a, C>( - self, - db: &'a C, - ) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + 'a where C: ConnectionTrait<'a>, A: 'a, @@ -64,10 +61,7 @@ where } } - pub fn exec<'a, C>( - self, - db: &'a C, - ) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + 'a where C: ConnectionTrait<'a>, A: 'a, diff --git a/src/executor/update.rs b/src/executor/update.rs index 9777ff79..3fb4a5c8 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -54,10 +54,7 @@ impl Updater { self } - pub fn exec<'a, C>( - self, - db: &'a C - ) -> impl Future> + '_ + pub fn exec<'a, C>(self, db: &'a C) -> impl Future> + '_ where C: ConnectionTrait<'a>, { diff --git a/src/query/util.rs b/src/query/util.rs index 545f8376..a6133725 100644 --- a/src/query/util.rs +++ b/src/query/util.rs @@ -23,14 +23,14 @@ macro_rules! debug_query_build { debug_query_build!(DbBackend, |x: &DebugQuery<_, DbBackend>| x.value); debug_query_build!(&DbBackend, |x: &DebugQuery<_, &DbBackend>| *x.value); -debug_query_build!( +debug_query_build!(DatabaseConnection, |x: &DebugQuery< + _, DatabaseConnection, - |x: &DebugQuery<_, DatabaseConnection>| x.value.get_database_backend() -); -debug_query_build!( +>| x.value.get_database_backend()); +debug_query_build!(&DatabaseConnection, |x: &DebugQuery< + _, &DatabaseConnection, - |x: &DebugQuery<_, &DatabaseConnection>| x.value.get_database_backend() -); +>| x.value.get_database_backend()); /// Helper to get a `Statement` from an object that impl `QueryTrait`. /// From 069040be8b9d89ac28f42430b10d043c3b7e3018 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Tue, 12 Oct 2021 14:56:22 +0800 Subject: [PATCH 65/65] Tweak lifetime --- src/executor/delete.rs | 7 +++---- src/executor/insert.rs | 5 ++--- src/executor/update.rs | 3 +-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 2fa5f64a..c4a8c7de 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -18,7 +18,7 @@ impl<'a, A: 'a> DeleteOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a C) -> impl Future> + 'a + pub fn exec(self, db: &'a C) -> impl Future> + '_ where C: ConnectionTrait<'a>, { @@ -31,7 +31,7 @@ impl<'a, E> DeleteMany where E: EntityTrait, { - pub fn exec(self, db: &'a C) -> impl Future> + 'a + pub fn exec(self, db: &'a C) -> impl Future> + '_ where C: ConnectionTrait<'a>, { @@ -61,8 +61,7 @@ where Deleter::new(query).exec(db).await } -// Only Statement impl Send -async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result +async fn exec_delete<'a, C>(statement: Statement, db: &'a C) -> Result where C: ConnectionTrait<'a>, { diff --git a/src/executor/insert.rs b/src/executor/insert.rs index e7437854..b4da10c0 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -28,7 +28,7 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + '_ where C: ConnectionTrait<'a>, A: 'a, @@ -61,7 +61,7 @@ where } } - pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + '_ where C: ConnectionTrait<'a>, A: 'a, @@ -71,7 +71,6 @@ where } } -// Only Statement impl Send async fn exec_insert<'a, A, C>( primary_key: Option, statement: Statement, diff --git a/src/executor/update.rs b/src/executor/update.rs index 3fb4a5c8..c228730b 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -32,7 +32,7 @@ impl<'a, E> UpdateMany where E: EntityTrait, { - pub fn exec(self, db: &'a C) -> impl Future> + 'a + pub fn exec(self, db: &'a C) -> impl Future> + '_ where C: ConnectionTrait<'a>, { @@ -83,7 +83,6 @@ where Ok(model) } -// Only Statement impl Send async fn exec_update<'a, C>( statement: Statement, db: &'a C,