diff --git a/Cargo.toml b/Cargo.toml index 2a34aa36..1690ef5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ futures = { version = "^0.3", default-features = false, features = ["std"] } log = { version = "^0.4", default-features = false } tracing = { version = "^0.1", default-features = false, features = ["attributes", "log"] } rust_decimal = { version = "^1", default-features = false, optional = true } +bigdecimal = { version = "^0.3", default-features = false, optional = true } sea-orm-macros = { version = "^0.10.3", path = "sea-orm-macros", default-features = false, optional = true } sea-query = { version = "^0.27.2", features = ["thread-safe"] } sea-query-binder = { version = "^0.2.2", default-features = false, optional = true } @@ -67,6 +68,7 @@ default = [ "with-json", "with-chrono", "with-rust_decimal", + "with-bigdecimal", "with-uuid", "with-time", ] @@ -75,6 +77,7 @@ mock = [] with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "time?/serde", "uuid?/serde", "sea-query-binder?/with-json", "sqlx?/json"] with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"] with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-binder?/with-rust_decimal", "sqlx?/decimal"] +with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"] with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"] with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"] postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros?/postgres-array"] diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 1d68cba1..8fa4e3df 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -72,5 +72,8 @@ pub use time::OffsetDateTime as TimeDateTimeWithTimeZone; #[cfg(feature = "with-rust_decimal")] pub use rust_decimal::Decimal; +#[cfg(feature = "with-bigdecimal")] +pub use bigdecimal::BigDecimal; + #[cfg(feature = "with-uuid")] pub use uuid::Uuid; diff --git a/src/executor/query.rs b/src/executor/query.rs index 88087f53..63459686 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -359,6 +359,58 @@ impl TryGetable for Decimal { } } +#[cfg(feature = "with-bigdecimal")] +use bigdecimal::BigDecimal; + +#[cfg(feature = "with-bigdecimal")] +impl TryGetable for BigDecimal { + #[allow(unused_variables)] + fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { + let column = format!("{}{}", pre, col); + match &res.row { + #[cfg(feature = "sqlx-mysql")] + QueryResultRow::SqlxMySql(row) => { + use sqlx::Row; + row.try_get::, _>(column.as_str()) + .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) + .and_then(|opt| opt.ok_or(TryGetError::Null(column))) + } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use sqlx::Row; + row.try_get::, _>(column.as_str()) + .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) + .and_then(|opt| opt.ok_or(TryGetError::Null(column))) + } + #[cfg(feature = "sqlx-sqlite")] + QueryResultRow::SqlxSqlite(row) => { + use sqlx::Row; + let val: Option = row + .try_get(column.as_str()) + .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))?; + match val { + Some(v) => BigDecimal::try_from(v).map_err(|e| { + TryGetError::DbErr(DbErr::TryIntoErr { + from: "f64", + into: "BigDecimal", + source: Box::new(e), + }) + }), + None => Err(TryGetError::Null(column)), + } + } + #[cfg(feature = "mock")] + #[allow(unused_variables)] + QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + debug_print!("{:#?}", e.to_string()); + TryGetError::Null(column) + }), + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } +} + #[cfg(feature = "with-uuid")] try_getable_all!(uuid::Uuid); @@ -489,6 +541,9 @@ mod postgres_array { #[cfg(feature = "with-rust_decimal")] try_getable_postgres_array!(rust_decimal::Decimal); + #[cfg(feature = "with-bigdecimal")] + try_getable_postgres_array!(bigdecimal::BigDecimal); + #[cfg(feature = "with-uuid")] try_getable_postgres_array!(uuid::Uuid); diff --git a/tests/common/features/mod.rs b/tests/common/features/mod.rs index 1e90f5c7..ea6a39bc 100644 --- a/tests/common/features/mod.rs +++ b/tests/common/features/mod.rs @@ -8,6 +8,7 @@ pub mod insert_default; pub mod json_struct; pub mod json_vec; pub mod metadata; +pub mod pi; pub mod repository; pub mod satellite; pub mod schema; @@ -24,6 +25,7 @@ pub use insert_default::Entity as InsertDefault; pub use json_struct::Entity as JsonStruct; pub use json_vec::Entity as JsonVec; pub use metadata::Entity as Metadata; +pub use pi::Entity as Pi; pub use repository::Entity as Repository; pub use satellite::Entity as Satellite; pub use schema::*; diff --git a/tests/common/features/pi.rs b/tests/common/features/pi.rs new file mode 100644 index 00000000..0324382a --- /dev/null +++ b/tests/common/features/pi.rs @@ -0,0 +1,21 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "pi")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(column_type = "Decimal(Some((11, 10)))")] + pub decimal: Decimal, + #[sea_orm(column_type = "Decimal(Some((11, 10)))")] + pub big_decimal: BigDecimal, + #[sea_orm(column_type = "Decimal(Some((11, 10)))")] + pub decimal_opt: Option, + #[sea_orm(column_type = "Decimal(Some((11, 10)))")] + pub big_decimal_opt: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 15af0172..13ca8c23 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -41,6 +41,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_active_enum_table(db).await?; create_active_enum_child_table(db).await?; create_insert_default_table(db).await?; + create_pi_table(db).await?; if DbBackend::Postgres == db_backend { create_collection_table(db).await?; @@ -383,3 +384,30 @@ pub async fn create_collection_table(db: &DbConn) -> Result { create_table(db, &stmt, Collection).await } + +pub async fn create_pi_table(db: &DbConn) -> Result { + let stmt = sea_query::Table::create() + .table(pi::Entity) + .col( + ColumnDef::new(pi::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(pi::Column::Decimal) + .decimal_len(11, 10) + .not_null(), + ) + .col( + ColumnDef::new(pi::Column::BigDecimal) + .decimal_len(11, 10) + .not_null(), + ) + .col(ColumnDef::new(pi::Column::DecimalOpt).decimal_len(11, 10)) + .col(ColumnDef::new(pi::Column::BigDecimalOpt).decimal_len(11, 10)) + .to_owned(); + + create_table(db, &stmt, Pi).await +} diff --git a/tests/pi_tests.rs b/tests/pi_tests.rs new file mode 100644 index 00000000..0d344cee --- /dev/null +++ b/tests/pi_tests.rs @@ -0,0 +1,61 @@ +pub mod common; + +use common::{features::*, TestContext}; +use pretty_assertions::assert_eq; +use rust_decimal_macros::dec; +use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; +use std::str::FromStr; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("pi_tests").await; + create_tables(&ctx.db).await?; + create_and_update_pi(&ctx.db).await?; + ctx.delete().await; + + Ok(()) +} + +pub async fn create_and_update_pi(db: &DatabaseConnection) -> Result<(), DbErr> { + let pi = pi::Model { + id: 1, + decimal: dec!(3.1415926536), + big_decimal: BigDecimal::from_str("3.1415926536").unwrap(), + decimal_opt: None, + big_decimal_opt: None, + }; + + let res = pi.clone().into_active_model().insert(db).await?; + + let model = Pi::find().one(db).await?; + assert_eq!(model, Some(res)); + assert_eq!(model, Some(pi.clone())); + + let res = pi::ActiveModel { + decimal_opt: Set(Some(dec!(3.1415926536))), + big_decimal_opt: Set(Some(BigDecimal::from_str("3.1415926536").unwrap())), + ..pi.clone().into_active_model() + } + .update(db) + .await?; + + let model = Pi::find().one(db).await?; + assert_eq!(model, Some(res)); + assert_eq!( + model, + Some(pi::Model { + id: 1, + decimal: dec!(3.1415926536), + big_decimal: BigDecimal::from_str("3.1415926536").unwrap(), + decimal_opt: Some(dec!(3.1415926536)), + big_decimal_opt: Some(BigDecimal::from_str("3.1415926536").unwrap()), + }) + ); + + Ok(()) +}