diff --git a/examples/sqlx-mysql/src/main.rs b/examples/sqlx-mysql/src/main.rs index eb2c79ce..e180e10a 100644 --- a/examples/sqlx-mysql/src/main.rs +++ b/examples/sqlx-mysql/src/main.rs @@ -5,12 +5,14 @@ mod example_cake_filling; mod example_filling; mod example_fruit; mod select; +mod operation; use example_cake as cake; use example_cake_filling as cake_filling; use example_filling as filling; use example_fruit as fruit; use select::*; +use operation::*; #[async_std::main] async fn main() { @@ -25,4 +27,8 @@ async fn main() { println!("===== =====\n"); all_about_select(&db).await.unwrap(); + + println!("===== =====\n"); + + all_about_operation(&db).await.unwrap(); } diff --git a/examples/sqlx-mysql/src/operation.rs b/examples/sqlx-mysql/src/operation.rs new file mode 100644 index 00000000..e6ef0365 --- /dev/null +++ b/examples/sqlx-mysql/src/operation.rs @@ -0,0 +1,22 @@ +use crate::*; +use sea_orm::{entity::*, query::*, Database}; + +pub async fn all_about_operation(db: &Database) -> Result<(), ExecErr> { + let banana = fruit::ActiveModel { + name: Val::set("banana".to_owned()), + ..Default::default() + }; + let mut banana = banana.save(db).await?; + + println!(); + println!("Inserted: {:?}\n", banana); + + banana.name = Val::set("banana banana".to_owned()); + + let banana = banana.save(db).await?; + + println!(); + println!("Updated: {:?}\n", banana); + + Ok(()) +} \ No newline at end of file diff --git a/sea-orm-macros/src/derives/active_model.rs b/sea-orm-macros/src/derives/active_model.rs index 4a394623..081c3921 100644 --- a/sea-orm-macros/src/derives/active_model.rs +++ b/sea-orm-macros/src/derives/active_model.rs @@ -36,6 +36,12 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result),* } + impl ActiveModel { + pub async fn save(self, db: &sea_orm::Database) -> Result { + sea_orm::save_active_model::(self, db).await + } + } + impl Default for ActiveModel { fn default() -> Self { ::new() @@ -77,6 +83,12 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result::Column) -> bool { + match c { + #(::Column::#name => self.#field.is_unset()),* + } + } + fn default() -> Self { Self { #(#field: sea_orm::ActiveValue::unset()),* diff --git a/src/connector/insert.rs b/src/connector/insert.rs index 1debda32..0488be9c 100644 --- a/src/connector/insert.rs +++ b/src/connector/insert.rs @@ -1,5 +1,6 @@ -use crate::{Connection, Database, ExecErr, Statement}; +use crate::{ActiveModelTrait, Connection, Database, ExecErr, Insert, QueryTrait, Statement}; use sea_query::{InsertStatement, QueryBuilder}; +use std::future::Future; #[derive(Clone, Debug)] pub struct Inserter { @@ -11,6 +12,16 @@ pub struct InsertResult { pub last_insert_id: u64, } +impl Insert +where + A: ActiveModelTrait, +{ + pub fn exec(self, db: &Database) -> impl Future> + '_ { + // so that self is dropped before entering await + Inserter::new(self.into_query()).exec(db) + } +} + impl Inserter { pub fn new(query: InsertStatement) -> Self { Self { query } @@ -23,12 +34,17 @@ impl Inserter { self.query.build(builder).into() } - pub async fn exec(self, db: &Database) -> Result { + pub fn exec(self, db: &Database) -> impl Future> + '_ { let builder = db.get_query_builder_backend(); - let result = db.get_connection().execute(self.build(builder)).await?; - // TODO: Postgres instead use query_one + returning clause - Ok(InsertResult { - last_insert_id: result.last_insert_id(), - }) + exec_insert(self.build(builder), db) } } + +// Only Statement impl Send +async fn exec_insert(statement: Statement, db: &Database) -> Result { + let result = db.get_connection().execute(statement).await?; + // TODO: Postgres instead use query_one + returning clause + Ok(InsertResult { + last_insert_id: result.last_insert_id(), + }) +} diff --git a/src/connector/mod.rs b/src/connector/mod.rs index fcebe06d..4027e650 100644 --- a/src/connector/mod.rs +++ b/src/connector/mod.rs @@ -1,10 +1,12 @@ mod executor; mod insert; mod select; +mod update; pub use executor::*; pub use insert::*; pub use select::*; +pub use update::*; use crate::{DatabaseConnection, QueryResult, Statement, TypeErr}; use async_trait::async_trait; diff --git a/src/connector/update.rs b/src/connector/update.rs new file mode 100644 index 00000000..fef52268 --- /dev/null +++ b/src/connector/update.rs @@ -0,0 +1,61 @@ +use crate::{ActiveModelTrait, Connection, Database, ExecErr, Statement, Update}; +use sea_query::{QueryBuilder, UpdateStatement}; +use std::future::Future; + +#[derive(Clone, Debug)] +pub struct Updater { + query: UpdateStatement, +} + +#[derive(Clone, Debug)] +pub struct UpdateResult { + pub rows_affected: u64, +} + +impl<'a, A: 'a> Update +where + A: ActiveModelTrait, +{ + pub fn exec(self, db: &'a Database) -> impl Future> + 'a { + // so that self is dropped before entering await + exec_update_and_return_original(self.query, self.model, db) + } +} + +impl Updater { + pub fn new(query: UpdateStatement) -> Self { + Self { query } + } + + pub fn build(&self, builder: B) -> Statement + where + B: QueryBuilder, + { + self.query.build(builder).into() + } + + pub fn exec(self, db: &Database) -> impl Future> + '_ { + let builder = db.get_query_builder_backend(); + exec_update(self.build(builder), db) + } +} + +async fn exec_update_and_return_original( + query: UpdateStatement, + model: A, + db: &Database, +) -> Result +where + A: ActiveModelTrait, +{ + Updater::new(query).exec(db).await?; + Ok(model) +} + +// Only Statement impl Send +async fn exec_update(statement: Statement, db: &Database) -> Result { + let result = db.get_connection().execute(statement).await?; + Ok(UpdateResult { + rows_affected: result.rows_affected(), + }) +} diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 2f7ae3d9..d851781f 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,4 +1,5 @@ -use crate::{EntityTrait, Value}; +use crate::{Database, EntityTrait, ExecErr, Iterable, PrimaryKeyToColumn, Value}; +use async_trait::async_trait; use std::fmt::Debug; #[derive(Clone, Debug, Default)] @@ -44,6 +45,7 @@ where ActiveValue::unchanged(value) } +#[async_trait] pub trait ActiveModelTrait: Clone + Debug { type Entity: EntityTrait; @@ -55,15 +57,29 @@ pub trait ActiveModelTrait: Clone + Debug { fn unset(&mut self, c: ::Column); + fn is_unset(&self, c: ::Column) -> bool; + fn default() -> Self; } +/// Behaviors for users to override pub trait ActiveModelBehavior: ActiveModelTrait { type Entity: EntityTrait; + /// Create a new ActiveModel with default values. Also used by `Default::default()`. fn new() -> Self { ::default() } + + /// Will be called before saving to database + fn before_save(self) -> Self { + self + } + + /// Will be called after saving to database + fn after_save(self) -> Self { + self + } } impl ActiveValue @@ -119,7 +135,7 @@ where pub fn into_wrapped_value(self) -> ActiveValue { match self.state { ActiveValueState::Set => ActiveValue::set(self.into_value()), - ActiveValueState::Unchanged => ActiveValue::set(self.into_value()), + ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value()), ActiveValueState::Unset => ActiveValue::unset(), } } @@ -171,3 +187,53 @@ where self } } + +/// Insert the model if primary key is unset, update otherwise +pub async fn save_active_model(mut am: A, db: &Database) -> Result +where + A: ActiveModelBehavior + ActiveModelTrait + From, + E: EntityTrait, +{ + am = ActiveModelBehavior::before_save(am); + let mut is_update = true; + for key in E::PrimaryKey::iter() { + let col = key.into_column(); + if am.is_unset(col) { + is_update = false; + break; + } + } + if !is_update { + am = insert_and_select_active_model::(am, db).await?; + } else { + am = update_active_model::(am, db).await?; + } + am = ActiveModelBehavior::after_save(am); + Ok(am) +} + +async fn insert_and_select_active_model(am: A, db: &Database) -> Result +where + A: ActiveModelTrait + From, + E: EntityTrait, +{ + let exec = E::insert(am).exec(db); + let res = exec.await?; + if res.last_insert_id != 0 { + let find = E::find_by(res.last_insert_id).one(db); + let res = find.await; + let model: E::Model = res.map_err(|_| ExecErr)?; + Ok(model.into()) + } else { + Ok(A::default()) + } +} + +async fn update_active_model(am: A, db: &Database) -> Result +where + A: ActiveModelTrait, + E: EntityTrait, +{ + let exec = E::update(am).exec(db); + exec.await +} diff --git a/src/query/insert.rs b/src/query/insert.rs index 302775ac..24d52e0f 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -50,7 +50,7 @@ where } else if self.columns[idx] != av.is_set() { panic!("columns mismatch"); } - if av.is_set() { + if av.is_set() || av.is_unchanged() { columns.push(col); values.push(av.into_value()); } diff --git a/src/query/mod.rs b/src/query/mod.rs index f064dab1..c4df1b7c 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -19,3 +19,5 @@ pub use result::*; pub use select::*; pub use traits::*; pub use update::*; + +pub use crate::connector::{QueryErr, ExecErr}; \ No newline at end of file diff --git a/src/query/update.rs b/src/query/update.rs index 7e40f1c0..1ac8eb8d 100644 --- a/src/query/update.rs +++ b/src/query/update.rs @@ -121,7 +121,7 @@ mod tests { assert_eq!( Update::::new(fruit::ActiveModel { id: Val::set(2), - name: Val::unset(), + name: Val::unchanged("Apple".to_owned()), cake_id: Val::set(Some(3)), }) .build(PostgresQueryBuilder)