From d664985ea9fd1a80dd49c1ee798b807ca959eff8 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 1 Sep 2021 23:24:43 +0800 Subject: [PATCH] WIP --- sea-orm-codegen/src/entity/base_entity.rs | 25 +++++++--- src/executor/insert.rs | 2 +- src/executor/query.rs | 56 ++++++++++------------- 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/sea-orm-codegen/src/entity/base_entity.rs b/sea-orm-codegen/src/entity/base_entity.rs index 2cc6dbb4..7b28f9e4 100644 --- a/sea-orm-codegen/src/entity/base_entity.rs +++ b/sea-orm-codegen/src/entity/base_entity.rs @@ -118,12 +118,25 @@ impl Entity { } pub fn get_primary_key_rs_type(&self) -> TokenStream { - if let Some(primary_key) = self.primary_keys.first() { - self.columns - .iter() - .find(|col| col.name.eq(&primary_key.name)) - .unwrap() - .get_rs_type() + let types = self + .primary_keys + .iter() + .map(|primary_key| { + self.columns + .iter() + .find(|col| col.name.eq(&primary_key.name)) + .unwrap() + .get_rs_type() + .to_string() + }) + .collect::>(); + if !types.is_empty() { + let value_type = if types.len() > 1 { + vec!["(".to_owned(), types.join(", "), ")".to_owned()] + } else { + types + }; + value_type.join("").parse().unwrap() } else { TokenStream::new() } diff --git a/src/executor/insert.rs b/src/executor/insert.rs index ce8f32d7..a0577e4b 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -88,7 +88,7 @@ where #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(conn) => { let res = conn.query_one(statement).await?.unwrap(); - res.try_get("", "last_insert_id").unwrap_or_default() + res.try_get_many("", "last_insert_id").unwrap_or_default() } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); diff --git a/src/executor/query.rs b/src/executor/query.rs index 7be17b4b..407e963c 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -262,9 +262,12 @@ impl TryGetable for Decimal { .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))?; use rust_decimal::prelude::FromPrimitive; match val { - Some(v) => Decimal::from_f64(v) - .ok_or_else(|| TryGetError::DbErr(DbErr::Query("Failed to convert f64 into Decimal".to_owned()))), - None => Err(TryGetError::Null) + Some(v) => Decimal::from_f64(v).ok_or_else(|| { + TryGetError::DbErr(DbErr::Query( + "Failed to convert f64 into Decimal".to_owned(), + )) + }), + None => Err(TryGetError::Null), } } #[cfg(feature = "mock")] @@ -282,22 +285,15 @@ try_getable_all!(uuid::Uuid); // TryGetableMany // pub trait TryGetableMany: Sized { - fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result; + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result; } impl TryGetableMany for T where T: TryGetable, { - fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { - let expect_len = 1; - if cols.len() < expect_len { - return Err(DbErr::Query(format!( - "Expect {} column names supplied but got slice of length {}", - expect_len, - cols.len() - ))); - } + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(1, cols)?; T::try_get(res, pre, &cols[0]) } } @@ -306,15 +302,8 @@ impl TryGetableMany for (T, T) where T: TryGetable, { - fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { - let expect_len = 2; - if cols.len() < expect_len { - return Err(DbErr::Query(format!( - "Expect {} column names supplied but got slice of length {}", - expect_len, - cols.len() - ))); - } + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(2, cols)?; Ok(( T::try_get(res, pre, &cols[0])?, T::try_get(res, pre, &cols[1])?, @@ -326,15 +315,8 @@ impl TryGetableMany for (T, T, T) where T: TryGetable, { - fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { - let expect_len = 3; - if cols.len() < expect_len { - return Err(DbErr::Query(format!( - "Expect {} column names supplied but got slice of length {}", - expect_len, - cols.len() - ))); - } + fn try_get_many(res: &QueryResult, pre: &str, cols: &[String]) -> Result { + try_get_many_with_slice_len_of(3, cols)?; Ok(( T::try_get(res, pre, &cols[0])?, T::try_get(res, pre, &cols[1])?, @@ -343,6 +325,18 @@ where } } +fn try_get_many_with_slice_len_of(len: usize, cols: &[String]) -> Result<(), TryGetError> { + if cols.len() < len { + Err(TryGetError::DbErr(DbErr::Query(format!( + "Expect {} column names supplied but got slice of length {}", + len, + cols.len() + )))) + } else { + Ok(()) + } +} + // TryFromU64 // pub trait TryFromU64: Sized {