diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 6b11432a..c37cac7c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -5,6 +5,7 @@ on: push: branches: - master + - 0.2.x env: CARGO_TERM_COLOR: always @@ -143,6 +144,12 @@ jobs: args: > --all + - uses: actions-rs/cargo@v1 + with: + command: test + args: > + --manifest-path sea-orm-rocket/Cargo.toml + cli: name: CLI runs-on: ${{ matrix.os }} @@ -170,7 +177,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - path: [async-std, tokio, actix_example, actix4_example, rocket_example] + path: [basic, actix_example, actix4_example, rocket_example] steps: - uses: actions/checkout@v2 @@ -186,6 +193,28 @@ jobs: args: > --manifest-path examples/${{ matrix.path }}/Cargo.toml + issues: + name: Issues + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + path: [86] + steps: + - uses: actions/checkout@v2 + + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - uses: actions-rs/cargo@v1 + with: + command: build + args: > + --manifest-path issues/${{ matrix.path }}/Cargo.toml + sqlite: name: SQLite runs-on: ubuntu-20.04 diff --git a/CHANGELOG.md b/CHANGELOG.md index 0398854a..ee544831 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## 0.2.6 - 2021-10-09 + +- [[#224]] [sea-orm-cli] Date & Time column type mapping +- Escape rust keywords with `r#` raw identifier + +[#224]: https://github.com/SeaQL/sea-orm/pull/224 + +## 0.2.5 - 2021-10-06 + +- [[#227]] Resolve "Inserting actual none value of Option results in panic" +- [[#219]] [sea-orm-cli] Add `--tables` option +- [[#189]] Add `debug_query` and `debug_query_stmt` macro + +[#227]: https://github.com/SeaQL/sea-orm/issues/227 +[#219]: https://github.com/SeaQL/sea-orm/pull/219 +[#189]: https://github.com/SeaQL/sea-orm/pull/189 + ## 0.2.4 - 2021-10-01 - [[#186]] [sea-orm-cli] Foreign key handling diff --git a/Cargo.toml b/Cargo.toml index a4466152..28bc14e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [".", "sea-orm-macros", "sea-orm-codegen"] [package] name = "sea-orm" -version = "0.2.4" +version = "0.2.6" authors = ["Chris Tsang "] edition = "2018" description = "🐚 An async & dynamic ORM for Rust" @@ -29,13 +29,14 @@ futures = { version = "^0.3" } futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } -sea-orm-macros = { version = "^0.2.4", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.16.5", features = ["thread-safe"] } +sea-orm-macros = { version = "^0.2.6", path = "sea-orm-macros", optional = true } +sea-query = { version = "^0.17.1", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +ouroboros = "0.11" [dev-dependencies] smol = { version = "^1.2" } diff --git a/README.md b/README.md index b4525b31..d6bfcdff 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) -[![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) +[![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic) [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) diff --git a/examples/actix4_example/src/setup.rs b/examples/actix4_example/src/setup.rs index 034e8b53..04677af4 100644 --- a/examples/actix4_example/src/setup.rs +++ b/examples/actix4_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/examples/actix_example/src/setup.rs b/examples/actix_example/src/setup.rs index 034e8b53..04677af4 100644 --- a/examples/actix_example/src/setup.rs +++ b/examples/actix_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/examples/async-std/Cargo.toml b/examples/basic/Cargo.toml similarity index 100% rename from examples/async-std/Cargo.toml rename to examples/basic/Cargo.toml diff --git a/examples/async-std/Readme.md b/examples/basic/Readme.md similarity index 100% rename from examples/async-std/Readme.md rename to examples/basic/Readme.md diff --git a/examples/async-std/bakery.sql b/examples/basic/bakery.sql similarity index 100% rename from examples/async-std/bakery.sql rename to examples/basic/bakery.sql diff --git a/examples/async-std/import.sh b/examples/basic/import.sh similarity index 100% rename from examples/async-std/import.sh rename to examples/basic/import.sh diff --git a/examples/async-std/src/entities.rs b/examples/basic/src/entities.rs similarity index 100% rename from examples/async-std/src/entities.rs rename to examples/basic/src/entities.rs diff --git a/examples/async-std/src/example_cake.rs b/examples/basic/src/example_cake.rs similarity index 100% rename from examples/async-std/src/example_cake.rs rename to examples/basic/src/example_cake.rs diff --git a/examples/async-std/src/example_cake_filling.rs b/examples/basic/src/example_cake_filling.rs similarity index 100% rename from examples/async-std/src/example_cake_filling.rs rename to examples/basic/src/example_cake_filling.rs diff --git a/examples/async-std/src/example_filling.rs b/examples/basic/src/example_filling.rs similarity index 100% rename from examples/async-std/src/example_filling.rs rename to examples/basic/src/example_filling.rs diff --git a/examples/async-std/src/example_fruit.rs b/examples/basic/src/example_fruit.rs similarity index 100% rename from examples/async-std/src/example_fruit.rs rename to examples/basic/src/example_fruit.rs diff --git a/examples/async-std/src/main.rs b/examples/basic/src/main.rs similarity index 100% rename from examples/async-std/src/main.rs rename to examples/basic/src/main.rs diff --git a/examples/async-std/src/operation.rs b/examples/basic/src/operation.rs similarity index 100% rename from examples/async-std/src/operation.rs rename to examples/basic/src/operation.rs diff --git a/examples/async-std/src/select.rs b/examples/basic/src/select.rs similarity index 100% rename from examples/async-std/src/select.rs rename to examples/basic/src/select.rs diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index 22dc9a6b..c0834609 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -15,15 +15,22 @@ futures-util = { version = "^0.3" } rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "json", ] } -rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git" } rocket_dyn_templates = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "tera", ] } -# remove `path = ""` in your own project -sea-orm = { path = "../../", version = "^0.2.3", features = ["macros"], default-features = false } serde_json = { version = "^1" } +[dependencies.sea-orm] +path = "../../" # remove this line in your own project +version = "^0.2.3" +features = ["macros", "runtime-tokio-native-tls"] +default-features = false + +[dependencies.sea-orm-rocket] +path = "../../sea-orm-rocket/lib" # remove this line in your own project and use the git line +# git = "https://github.com/SeaQL/sea-orm" + [features] default = ["sqlx-postgres"] -sqlx-mysql = ["sea-orm/sqlx-mysql", "rocket_db_pools/sqlx_mysql"] -sqlx-postgres = ["sea-orm/sqlx-postgres", "rocket_db_pools/sqlx_postgres"] +sqlx-mysql = ["sea-orm/sqlx-mysql"] +sqlx-postgres = ["sea-orm/sqlx-postgres"] diff --git a/examples/rocket_example/Rocket.toml b/examples/rocket_example/Rocket.toml index b7fcc12a..fc294bd2 100644 --- a/examples/rocket_example/Rocket.toml +++ b/examples/rocket_example/Rocket.toml @@ -1,7 +1,7 @@ [default] template_dir = "templates/" -[default.databases.rocket_example] +[default.databases.sea_orm] # Mysql # make sure to enable "sqlx-mysql" feature in Cargo.toml, i.e default = ["sqlx-mysql"] # url = "mysql://root:@localhost/rocket_example" diff --git a/examples/rocket_example/src/main.rs b/examples/rocket_example/src/main.rs index 853eaaaa..e2ef9254 100644 --- a/examples/rocket_example/src/main.rs +++ b/examples/rocket_example/src/main.rs @@ -7,21 +7,17 @@ use rocket::fs::{relative, FileServer}; use rocket::request::FlashMessage; use rocket::response::{Flash, Redirect}; use rocket::{Build, Request, Rocket}; -use rocket_db_pools::{sqlx, Connection, Database}; use rocket_dyn_templates::{context, Template}; use sea_orm::{entity::*, query::*}; +use sea_orm_rocket::{Connection, Database}; mod pool; -use pool::RocketDbPool; +use pool::Db; mod setup; -#[derive(Database, Debug)] -#[database("rocket_example")] -struct Db(RocketDbPool); - -type Result> = std::result::Result; +type Result> = std::result::Result; mod post; pub use post::Entity as Post; @@ -34,7 +30,9 @@ async fn new() -> Template { } #[post("/", data = "")] -async fn create(conn: Connection, post_form: Form) -> Flash { +async fn create(conn: Connection<'_, Db>, post_form: Form) -> Flash { + let db = conn.into_inner(); + let form = post_form.into_inner(); post::ActiveModel { @@ -42,7 +40,7 @@ async fn create(conn: Connection, post_form: Form) -> Flash, post_form: Form) -> Flash", data = "")] -async fn update(conn: Connection, id: i32, post_form: Form) -> Flash { +async fn update(conn: Connection<'_, Db>, id: i32, post_form: Form) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(db) .await .unwrap() .unwrap() @@ -65,7 +65,7 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> title: Set(form.title.to_owned()), text: Set(form.text.to_owned()), } - .save(&conn) + .save(db) .await .expect("could not edit post"); @@ -74,11 +74,13 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> #[get("/?&")] async fn list( - conn: Connection, + conn: Connection<'_, Db>, posts_per_page: Option, page: Option, flash: Option>, ) -> Template { + let db = conn.into_inner(); + // Set page number and items per page let page = page.unwrap_or(1); let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE); @@ -89,7 +91,7 @@ async fn list( // Setup paginator let paginator = Post::find() .order_by_asc(post::Column::Id) - .paginate(&conn, posts_per_page); + .paginate(db, posts_per_page); let num_pages = paginator.num_pages().await.ok().unwrap(); // Fetch paginated posts @@ -111,9 +113,11 @@ async fn list( } #[get("/")] -async fn edit(conn: Connection, id: i32) -> Template { +async fn edit(conn: Connection<'_, Db>, id: i32) -> Template { + let db = conn.into_inner(); + let post: Option = Post::find_by_id(id) - .one(&conn) + .one(db) .await .expect("could not find post"); @@ -126,22 +130,26 @@ async fn edit(conn: Connection, id: i32) -> Template { } #[delete("/")] -async fn delete(conn: Connection, id: i32) -> Flash { +async fn delete(conn: Connection<'_, Db>, id: i32) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&conn) + .one(db) .await .unwrap() .unwrap() .into(); - post.delete(&conn).await.unwrap(); + post.delete(db).await.unwrap(); Flash::success(Redirect::to("/"), "Post successfully deleted.") } #[delete("/")] -async fn destroy(conn: Connection) -> Result<()> { - Post::delete_many().exec(&conn).await.unwrap(); +async fn destroy(conn: Connection<'_, Db>) -> Result<()> { + let db = conn.into_inner(); + + Post::delete_many().exec(db).await.unwrap(); Ok(()) } diff --git a/examples/rocket_example/src/pool.rs b/examples/rocket_example/src/pool.rs index 7c8e37cd..afc8d48d 100644 --- a/examples/rocket_example/src/pool.rs +++ b/examples/rocket_example/src/pool.rs @@ -1,13 +1,17 @@ use async_trait::async_trait; -use rocket_db_pools::{rocket::figment::Figment, Config}; +use sea_orm_rocket::{rocket::figment::Figment, Config, Database}; -#[derive(Debug)] -pub struct RocketDbPool { +#[derive(Database, Debug)] +#[database("sea_orm")] +pub struct Db(SeaOrmPool); + +#[derive(Debug, Clone)] +pub struct SeaOrmPool { pub conn: sea_orm::DatabaseConnection, } #[async_trait] -impl rocket_db_pools::Pool for RocketDbPool { +impl sea_orm_rocket::Pool for SeaOrmPool { type Error = sea_orm::DbErr; type Connection = sea_orm::DatabaseConnection; @@ -16,12 +20,10 @@ impl rocket_db_pools::Pool for RocketDbPool { let config = figment.extract::().unwrap(); let conn = sea_orm::Database::connect(&config.url).await.unwrap(); - Ok(RocketDbPool { - conn, - }) + Ok(SeaOrmPool { conn }) } - async fn get(&self) -> Result { - Ok(self.conn.clone()) + fn borrow(&self) -> &Self::Connection { + &self.conn } } diff --git a/examples/rocket_example/src/setup.rs b/examples/rocket_example/src/setup.rs index 034e8b53..f5b5a99e 100644 --- a/examples/rocket_example/src/setup.rs +++ b/examples/rocket_example/src/setup.rs @@ -1,5 +1,5 @@ use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; -use sea_orm::{error::*, sea_query, DbConn, ExecResult}; +use sea_orm::{error::*, query::*, sea_query, DbConn, ExecResult}; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { let builder = db.get_database_backend(); diff --git a/examples/tokio/Cargo.toml b/issues/86/Cargo.toml similarity index 100% rename from examples/tokio/Cargo.toml rename to issues/86/Cargo.toml diff --git a/examples/tokio/src/cake.rs b/issues/86/src/cake.rs similarity index 100% rename from examples/tokio/src/cake.rs rename to issues/86/src/cake.rs diff --git a/examples/tokio/src/main.rs b/issues/86/src/main.rs similarity index 100% rename from examples/tokio/src/main.rs rename to issues/86/src/main.rs diff --git a/sea-orm-cli/Cargo.toml b/sea-orm-cli/Cargo.toml index 4b1fc2cd..07db2e4b 100644 --- a/sea-orm-cli/Cargo.toml +++ b/sea-orm-cli/Cargo.toml @@ -3,7 +3,7 @@ [package] name = "sea-orm-cli" -version = "0.2.4" +version = "0.2.6" authors = [ "Billy Chan " ] edition = "2018" description = "Command line utility for SeaORM" @@ -21,7 +21,7 @@ path = "src/main.rs" clap = { version = "^2.33.3" } dotenv = { version = "^0.15" } async-std = { version = "^1.9", features = [ "attributes" ] } -sea-orm-codegen = { version = "^0.2.4", path = "../sea-orm-codegen" } +sea-orm-codegen = { version = "^0.2.6", path = "../sea-orm-codegen" } sea-schema = { version = "^0.2.9", default-features = false, features = [ "debug-print", "sqlx-mysql", diff --git a/sea-orm-codegen/Cargo.toml b/sea-orm-codegen/Cargo.toml index 46e81401..9013cea4 100644 --- a/sea-orm-codegen/Cargo.toml +++ b/sea-orm-codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sea-orm-codegen" -version = "0.2.4" +version = "0.2.6" authors = ["Billy Chan "] edition = "2018" description = "Code Generator for SeaORM" diff --git a/sea-orm-codegen/src/entity/column.rs b/sea-orm-codegen/src/entity/column.rs index 532f2e91..39eb340c 100644 --- a/sea-orm-codegen/src/entity/column.rs +++ b/sea-orm-codegen/src/entity/column.rs @@ -1,3 +1,4 @@ +use crate::util::escape_rust_keyword; use heck::{CamelCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; @@ -14,11 +15,11 @@ pub struct Column { impl Column { pub fn get_name_snake_case(&self) -> Ident { - format_ident!("{}", self.name.to_snake_case()) + format_ident!("{}", escape_rust_keyword(self.name.to_snake_case())) } pub fn get_name_camel_case(&self) -> Ident { - format_ident!("{}", self.name.to_camel_case()) + format_ident!("{}", escape_rust_keyword(self.name.to_camel_case())) } pub fn get_rs_type(&self) -> TokenStream { @@ -27,8 +28,6 @@ impl Column { ColumnType::Char(_) | ColumnType::String(_) | ColumnType::Text - | ColumnType::Time(_) - | ColumnType::Date | ColumnType::Custom(_) => "String", ColumnType::TinyInteger(_) => "i8", ColumnType::SmallInteger(_) => "i16", @@ -37,6 +36,8 @@ impl Column { ColumnType::Float(_) => "f32", ColumnType::Double(_) => "f64", ColumnType::Json | ColumnType::JsonBinary => "Json", + ColumnType::Date => "Date", + ColumnType::Time(_) => "Time", ColumnType::DateTime(_) | ColumnType::Timestamp(_) => "DateTime", ColumnType::TimestampWithTimeZone(_) => "DateTimeWithTimeZone", ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal", @@ -194,6 +195,11 @@ mod tests { make_col!("CAKE_FILLING_ID", ColumnType::Double(None)), make_col!("CAKE-FILLING-ID", ColumnType::Binary(None)), make_col!("CAKE", ColumnType::Boolean), + make_col!("date", ColumnType::Date), + make_col!("time", ColumnType::Time(None)), + make_col!("date_time", ColumnType::DateTime(None)), + make_col!("timestamp", ColumnType::Timestamp(None)), + make_col!("timestamp_tz", ColumnType::TimestampWithTimeZone(None)), ] } @@ -211,6 +217,11 @@ mod tests { "cake_filling_id", "cake_filling_id", "cake", + "date", + "time", + "date_time", + "timestamp", + "timestamp_tz", ]; for (col, snack_case) in columns.into_iter().zip(snack_cases) { assert_eq!(col.get_name_snake_case().to_string(), snack_case); @@ -231,6 +242,11 @@ mod tests { "CakeFillingId", "CakeFillingId", "Cake", + "Date", + "Time", + "DateTime", + "Timestamp", + "TimestampTz", ]; for (col, camel_case) in columns.into_iter().zip(camel_cases) { assert_eq!(col.get_name_camel_case().to_string(), camel_case); @@ -241,7 +257,21 @@ mod tests { fn test_get_rs_type() { let columns = setup(); let rs_types = vec![ - "String", "String", "i8", "i16", "i32", "i64", "f32", "f64", "Vec", "bool", + "String", + "String", + "i8", + "i16", + "i32", + "i64", + "f32", + "f64", + "Vec", + "bool", + "Date", + "Time", + "DateTime", + "DateTime", + "DateTimeWithTimeZone", ]; for (mut col, rs_type) in columns.into_iter().zip(rs_types) { let rs_type: TokenStream = rs_type.parse().unwrap(); @@ -271,6 +301,11 @@ mod tests { "ColumnType::Double.def()", "ColumnType::Binary.def()", "ColumnType::Boolean.def()", + "ColumnType::Date.def()", + "ColumnType::Time.def()", + "ColumnType::DateTime.def()", + "ColumnType::Timestamp.def()", + "ColumnType::TimestampWithTimeZone.def()", ]; for (mut col, col_def) in columns.into_iter().zip(col_defs) { let mut col_def: TokenStream = col_def.parse().unwrap(); diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index 59f54537..17e74130 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -597,18 +597,85 @@ mod tests { name: "id".to_owned(), }], }, + Entity { + table_name: "rust_keyword".to_owned(), + columns: vec![ + Column { + name: "id".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: true, + not_null: true, + unique: false, + }, + Column { + name: "testing".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "rust".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "keywords".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "type".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "typeof".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "crate".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + Column { + name: "self".to_owned(), + col_type: ColumnType::Integer(Some(11)), + auto_increment: false, + not_null: true, + unique: false, + }, + ], + relations: vec![], + conjunct_relations: vec![], + primary_keys: vec![PrimaryKey { + name: "id".to_owned(), + }], + }, ] } #[test] fn test_gen_expanded_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 5] = [ + const ENTITY_FILES: [&str; 6] = [ include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake_filling.rs"), include_str!("../../tests/expanded/filling.rs"), include_str!("../../tests/expanded/fruit.rs"), include_str!("../../tests/expanded/vendor.rs"), + include_str!("../../tests/expanded/rust_keyword.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); @@ -642,12 +709,13 @@ mod tests { #[test] fn test_gen_compact_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 5] = [ + const ENTITY_FILES: [&str; 6] = [ include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake_filling.rs"), include_str!("../../tests/compact/filling.rs"), include_str!("../../tests/compact/fruit.rs"), include_str!("../../tests/compact/vendor.rs"), + include_str!("../../tests/compact/rust_keyword.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); diff --git a/sea-orm-codegen/src/lib.rs b/sea-orm-codegen/src/lib.rs index 07e167bc..5e637de1 100644 --- a/sea-orm-codegen/src/lib.rs +++ b/sea-orm-codegen/src/lib.rs @@ -1,5 +1,6 @@ mod entity; mod error; +mod util; pub use entity::*; pub use error::*; diff --git a/sea-orm-codegen/src/util.rs b/sea-orm-codegen/src/util.rs new file mode 100644 index 00000000..34c46c54 --- /dev/null +++ b/sea-orm-codegen/src/util.rs @@ -0,0 +1,23 @@ +pub(crate) fn escape_rust_keyword(string: T) -> String +where + T: ToString, +{ + let string = string.to_string(); + if RUST_KEYWORDS.iter().any(|s| s.eq(&string)) { + format!("r#{}", string) + } else if RUST_SPECIAL_KEYWORDS.iter().any(|s| s.eq(&string)) { + format!("{}_", string) + } else { + string + } +} + +pub(crate) const RUST_KEYWORDS: [&str; 49] = [ + "as", "async", "await", "break", "const", "continue", "dyn", "else", "enum", "extern", "false", + "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "static", "struct", "super", "trait", "true", "type", "union", "unsafe", "use", + "where", "while", "abstract", "become", "box", "do", "final", "macro", "override", "priv", + "try", "typeof", "unsized", "virtual", "yield", +]; + +pub(crate) const RUST_SPECIAL_KEYWORDS: [&str; 3] = ["crate", "Self", "self"]; diff --git a/sea-orm-codegen/tests/compact/rust_keyword.rs b/sea-orm-codegen/tests/compact/rust_keyword.rs new file mode 100644 index 00000000..229eae22 --- /dev/null +++ b/sea-orm-codegen/tests/compact/rust_keyword.rs @@ -0,0 +1,30 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "rust_keyword")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#type: i32, + pub r#typeof: i32, + pub crate_: i32, + pub self_: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + _ => panic!("No RelationDef"), + } + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/expanded/rust_keyword.rs b/sea-orm-codegen/tests/expanded/rust_keyword.rs new file mode 100644 index 00000000..1ab8a627 --- /dev/null +++ b/sea-orm-codegen/tests/expanded/rust_keyword.rs @@ -0,0 +1,79 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> &str { + "rust_keyword" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#type: i32, + pub r#typeof: i32, + pub crate_: i32, + pub self_: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + Id, + Testing, + Rust, + Keywords, + Type, + Typeof, + Crate, + Self_, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + Id, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + + fn auto_increment() -> bool { + true + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::Id => ColumnType::Integer.def(), + Self::Testing => ColumnType::Integer.def(), + Self::Rust => ColumnType::Integer.def(), + Self::Keywords => ColumnType::Integer.def(), + Self::Type => ColumnType::Integer.def(), + Self::Typeof => ColumnType::Integer.def(), + Self::Crate => ColumnType::Integer.def(), + Self::Self_ => ColumnType::Integer.def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + _ => panic!("No RelationDef"), + } + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-macros/Cargo.toml b/sea-orm-macros/Cargo.toml index e37c8053..cde1575c 100644 --- a/sea-orm-macros/Cargo.toml +++ b/sea-orm-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sea-orm-macros" -version = "0.2.4" +version = "0.2.6" authors = [ "Billy Chan " ] edition = "2018" description = "Derive macros for SeaORM" diff --git a/sea-orm-macros/src/derives/active_model.rs b/sea-orm-macros/src/derives/active_model.rs index 2227f09b..85bdcb69 100644 --- a/sea-orm-macros/src/derives/active_model.rs +++ b/sea-orm-macros/src/derives/active_model.rs @@ -1,4 +1,4 @@ -use crate::util::field_not_ignored; +use crate::util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}; use heck::CamelCase; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, quote_spanned}; @@ -29,10 +29,10 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result) -> syn::Result { // if #[sea_orm(table_name = "foo", schema_name = "bar")] specified, create Entity struct let mut table_name = None; @@ -60,8 +60,10 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res if let Fields::Named(fields) = item_struct.fields { for field in fields.named { if let Some(ident) = &field.ident { - let mut field_name = - Ident::new(&ident.to_string().to_case(Case::Pascal), Span::call_site()); + let mut field_name = Ident::new( + &trim_starting_raw_identifier(&ident).to_case(Case::Pascal), + Span::call_site(), + ); let mut nullable = false; let mut default_value = None; @@ -168,6 +170,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res field_name = enum_name; } + field_name = Ident::new(&escape_rust_keyword(field_name), Span::call_site()); + if ignore { continue; } else { diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index a43b487f..29a597b9 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -1,4 +1,7 @@ -use crate::{attributes::derive_attr, util::field_not_ignored}; +use crate::{ + attributes::derive_attr, + util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}, +}; use heck::CamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; @@ -43,10 +46,10 @@ impl DeriveModel { let column_idents = fields .iter() .map(|field| { - let mut ident = format_ident!( - "{}", - field.ident.as_ref().unwrap().to_string().to_camel_case() - ); + let ident = field.ident.as_ref().unwrap().to_string(); + let ident = trim_starting_raw_identifier(ident).to_camel_case(); + let ident = escape_rust_keyword(ident); + let mut ident = format_ident!("{}", &ident); for attr in field.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { diff --git a/sea-orm-macros/src/util.rs b/sea-orm-macros/src/util.rs index 7dda1087..379b486c 100644 --- a/sea-orm-macros/src/util.rs +++ b/sea-orm-macros/src/util.rs @@ -24,3 +24,39 @@ pub(crate) fn field_not_ignored(field: &Field) -> bool { } true } + +pub(crate) fn trim_starting_raw_identifier(string: T) -> String +where + T: ToString, +{ + string + .to_string() + .trim_start_matches(RAW_IDENTIFIER) + .to_string() +} + +pub(crate) fn escape_rust_keyword(string: T) -> String +where + T: ToString, +{ + let string = string.to_string(); + if RUST_KEYWORDS.iter().any(|s| s.eq(&string)) { + format!("r#{}", string) + } else if RUST_SPECIAL_KEYWORDS.iter().any(|s| s.eq(&string)) { + format!("{}_", string) + } else { + string + } +} + +pub(crate) const RAW_IDENTIFIER: &str = "r#"; + +pub(crate) const RUST_KEYWORDS: [&str; 49] = [ + "as", "async", "await", "break", "const", "continue", "dyn", "else", "enum", "extern", "false", + "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "static", "struct", "super", "trait", "true", "type", "union", "unsafe", "use", + "where", "while", "abstract", "become", "box", "do", "final", "macro", "override", "priv", + "try", "typeof", "unsized", "virtual", "yield", +]; + +pub(crate) const RUST_SPECIAL_KEYWORDS: [&str; 3] = ["crate", "Self", "self"]; diff --git a/sea-orm-rocket/Cargo.toml b/sea-orm-rocket/Cargo.toml new file mode 100644 index 00000000..4975f8e1 --- /dev/null +++ b/sea-orm-rocket/Cargo.toml @@ -0,0 +1,2 @@ +[workspace] +members = ["codegen", "lib"] \ No newline at end of file diff --git a/sea-orm-rocket/README.md b/sea-orm-rocket/README.md new file mode 100644 index 00000000..f8d94b21 --- /dev/null +++ b/sea-orm-rocket/README.md @@ -0,0 +1 @@ +# SeaORM Rocket support crate. \ No newline at end of file diff --git a/sea-orm-rocket/codegen/Cargo.toml b/sea-orm-rocket/codegen/Cargo.toml new file mode 100644 index 00000000..75656487 --- /dev/null +++ b/sea-orm-rocket/codegen/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "sea-orm-rocket-codegen" +version = "0.1.0-rc" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "Procedural macros for sea_orm_rocket." +repository = "https://github.com/SergioBenitez/Rocket/contrib/db_pools" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +devise = "0.3" +quote = "1" + +[dev-dependencies] +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", default-features = false } +trybuild = "1.0" +version_check = "0.9" diff --git a/sea-orm-rocket/codegen/src/database.rs b/sea-orm-rocket/codegen/src/database.rs new file mode 100644 index 00000000..a6f5a981 --- /dev/null +++ b/sea-orm-rocket/codegen/src/database.rs @@ -0,0 +1,110 @@ +use proc_macro::TokenStream; + +use devise::{DeriveGenerator, FromMeta, MapperBuild, Support, ValidatorBuild}; +use devise::proc_macro2_diagnostics::SpanDiagnosticExt; +use devise::syn::{self, spanned::Spanned}; + +const ONE_DATABASE_ATTR: &str = "missing `#[database(\"name\")]` attribute"; +const ONE_UNNAMED_FIELD: &str = "struct must have exactly one unnamed field"; + +#[derive(Debug, FromMeta)] +struct DatabaseAttribute { + #[meta(naked)] + name: String, +} + +pub fn derive_database(input: TokenStream) -> TokenStream { + DeriveGenerator::build_for(input, quote!(impl sea_orm_rocket::Database)) + .support(Support::TupleStruct) + .validator(ValidatorBuild::new() + .struct_validate(|_, s| { + if s.fields.len() == 1 { + Ok(()) + } else { + Err(s.span().error(ONE_UNNAMED_FIELD)) + } + }) + ) + .outer_mapper(MapperBuild::new() + .struct_map(|_, s| { + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + let decorated_type = &s.ident; + let db_ty = quote_spanned!(decorated_type.span() => + <#decorated_type as sea_orm_rocket::Database> + ); + + quote_spanned! { decorated_type.span() => + impl From<#pool_type> for #decorated_type { + fn from(pool: #pool_type) -> Self { + Self(pool) + } + } + + impl std::ops::Deref for #decorated_type { + type Target = #pool_type; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl std::ops::DerefMut for #decorated_type { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #[rocket::async_trait] + impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type { + type Error = (); + + async fn from_request( + req: &'r rocket::request::Request<'_> + ) -> rocket::request::Outcome { + match #db_ty::fetch(req.rocket()) { + Some(db) => rocket::outcome::Outcome::Success(db), + None => rocket::outcome::Outcome::Failure(( + rocket::http::Status::InternalServerError, ())) + } + } + } + + impl rocket::Sentinel for &#decorated_type { + fn abort(rocket: &rocket::Rocket) -> bool { + #db_ty::fetch(rocket).is_none() + } + } + } + }) + ) + .outer_mapper(quote!(#[rocket::async_trait])) + .inner_mapper(MapperBuild::new() + .try_struct_map(|_, s| { + let db_name = DatabaseAttribute::one_from_attrs("database", &s.attrs)? + .map(|attr| attr.name) + .ok_or_else(|| s.span().error(ONE_DATABASE_ATTR))?; + + let fairing_name = format!("'{}' Database Pool", db_name); + + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + Ok(quote_spanned! { pool_type.span() => + type Pool = #pool_type; + + const NAME: &'static str = #db_name; + + fn init() -> sea_orm_rocket::Initializer { + sea_orm_rocket::Initializer::with_name(#fairing_name) + } + }) + }) + ) + .to_tokens() +} diff --git a/sea-orm-rocket/codegen/src/lib.rs b/sea-orm-rocket/codegen/src/lib.rs new file mode 100644 index 00000000..7cb32f94 --- /dev/null +++ b/sea-orm-rocket/codegen/src/lib.rs @@ -0,0 +1,52 @@ +#![recursion_limit="256"] +#![warn(rust_2018_idioms)] + +//! # `sea_orm_rocket` - Code Generation +//! +//! Implements the code generation portion of the `sea_orm_rocket` crate. This +//! is an implementation detail. This create should never be depended on +//! directly. + +#[macro_use] extern crate quote; + +mod database; + +/// Automatic derive for the [`Database`] trait. +/// +/// The derive generates an implementation of [`Database`] as follows: +/// +/// * [`Database::NAME`] is set to the value in the `#[database("name")]` +/// attribute. +/// +/// This names the database, providing an anchor to configure the database via +/// `Rocket.toml` or any other configuration source. Specifically, the +/// configuration in `databases.name` is used to configure the driver. +/// +/// * [`Database::Pool`] is set to the wrapped type: `PoolType` above. The type +/// must implement [`Pool`]. +/// +/// To meet the required [`Database`] supertrait bounds, this derive also +/// generates implementations for: +/// +/// * `From` +/// +/// * `Deref` +/// +/// * `DerefMut` +/// +/// * `FromRequest<'_> for &Db` +/// +/// * `Sentinel for &Db` +/// +/// The `Deref` impls enable accessing the database pool directly from +/// references `&Db` or `&mut Db`. To force a dereference to the underlying +/// type, use `&db.0` or `&**db` or their `&mut` variants. +/// +/// [`Database`]: ../sea_orm_rocket/trait.Database.html +/// [`Database::NAME`]: ../sea_orm_rocket/trait.Database.html#associatedconstant.NAME +/// [`Database::Pool`]: ../sea_orm_rocket/trait.Database.html#associatedtype.Pool +/// [`Pool`]: ../sea_orm_rocket/trait.Pool.html +#[proc_macro_derive(Database, attributes(database))] +pub fn derive_database(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + crate::database::derive_database(input) +} diff --git a/sea-orm-rocket/lib/Cargo.toml b/sea-orm-rocket/lib/Cargo.toml new file mode 100644 index 00000000..3a586fe1 --- /dev/null +++ b/sea-orm-rocket/lib/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "sea-orm-rocket" +version = "0.1.0" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "SeaORM Rocket support crate" +repository = "https://github.com/SeaQL/sea-orm" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[package.metadata.docs.rs] +all-features = true + +[dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +version = "0.5.0-rc.1" +default-features = false + +[dependencies.sea-orm-rocket-codegen] +path = "../codegen" +version = "0.1.0-rc" + +[dev-dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +default-features = false +features = ["json"] diff --git a/sea-orm-rocket/lib/src/config.rs b/sea-orm-rocket/lib/src/config.rs new file mode 100644 index 00000000..b30c2ce5 --- /dev/null +++ b/sea-orm-rocket/lib/src/config.rs @@ -0,0 +1,83 @@ +use rocket::serde::{Deserialize, Serialize}; + +/// Base configuration for all database drivers. +/// +/// A dictionary matching this structure is extracted from the active +/// [`Figment`](crate::figment::Figment), scoped to `databases.name`, where +/// `name` is the name of the database, by the +/// [`Initializer`](crate::Initializer) fairing on ignition and used to +/// configure the relevant database and database pool. +/// +/// With the default provider, these parameters are typically configured in a +/// `Rocket.toml` file: +/// +/// ```toml +/// [default.databases.db_name] +/// url = "/path/to/db.sqlite" +/// +/// # only `url` is required. `Initializer` provides defaults for the rest. +/// min_connections = 64 +/// max_connections = 1024 +/// connect_timeout = 5 +/// idle_timeout = 120 +/// ``` +/// +/// Alternatively, a custom provider can be used. For example, a custom `Figment` +/// with a global `databases.name` configuration: +/// +/// ```rust +/// # use rocket::launch; +/// #[launch] +/// fn rocket() -> _ { +/// let figment = rocket::Config::figment() +/// .merge(("databases.name", sea_orm_rocket::Config { +/// url: "db:specific@config&url".into(), +/// min_connections: None, +/// max_connections: 1024, +/// connect_timeout: 3, +/// idle_timeout: None, +/// })); +/// +/// rocket::custom(figment) +/// } +/// ``` +/// +/// For general information on configuration in Rocket, see [`rocket::config`]. +/// For higher-level details on configuring a database, see the [crate-level +/// docs](crate#configuration). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(crate = "rocket::serde")] +pub struct Config { + /// Database-specific connection and configuration URL. + /// + /// The format of the URL is database specific; consult your database's + /// documentation. + pub url: String, + /// Minimum number of connections to maintain in the pool. + /// + /// **Note:** `deadpool` drivers do not support and thus ignore this value. + /// + /// _Default:_ `None`. + pub min_connections: Option, + /// Maximum number of connections to maintain in the pool. + /// + /// _Default:_ `workers * 4`. + pub max_connections: usize, + /// Number of seconds to wait for a connection before timing out. + /// + /// If the timeout elapses before a connection can be made or retrieved from + /// a pool, an error is returned. + /// + /// _Default:_ `5`. + pub connect_timeout: u64, + /// Maximum number of seconds to keep a connection alive for. + /// + /// After a connection is established, it is maintained in a pool for + /// efficient connection retrieval. When an `idle_timeout` is set, that + /// connection will be closed after the timeout elapses. If an + /// `idle_timeout` is not specified, the behavior is driver specific but + /// typically defaults to keeping a connection active indefinitely. + /// + /// _Default:_ `None`. + pub idle_timeout: Option, +} diff --git a/sea-orm-rocket/lib/src/database.rs b/sea-orm-rocket/lib/src/database.rs new file mode 100644 index 00000000..b3095ac9 --- /dev/null +++ b/sea-orm-rocket/lib/src/database.rs @@ -0,0 +1,245 @@ +use std::marker::PhantomData; +use std::ops::{DerefMut}; + +use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel}; +use rocket::fairing::{self, Fairing, Info, Kind}; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::http::Status; + +use rocket::yansi::Paint; +use rocket::figment::providers::Serialized; + +use crate::Pool; + +/// Derivable trait which ties a database [`Pool`] with a configuration name. +/// +/// This trait should rarely, if ever, be implemented manually. Instead, it +/// should be derived: +/// +/// ```ignore +/// use sea_orm_rocket::{Database}; +/// # use sea_orm_rocket::MockPool as SeaOrmPool; +/// +/// #[derive(Database, Debug)] +/// #[database("sea_orm")] +/// struct Db(SeaOrmPool); +/// +/// #[launch] +/// fn rocket() -> _ { +/// rocket::build().attach(Db::init()) +/// } +/// ``` +/// +/// See the [`Database` derive](derive@crate::Database) for details. +pub trait Database: From + DerefMut + Send + Sync + 'static { + /// The [`Pool`] type of connections to this database. + /// + /// When `Database` is derived, this takes the value of the `Inner` type in + /// `struct Db(Inner)`. + type Pool: Pool; + + /// The configuration name for this database. + /// + /// When `Database` is derived, this takes the value `"name"` in the + /// `#[database("name")]` attribute. + const NAME: &'static str; + + /// Returns a fairing that initializes the database and its connection pool. + /// + /// # Example + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build().attach(Db::init()) + /// } + /// # } + /// ``` + fn init() -> Initializer { + Initializer::new() + } + + /// Returns a reference to the initialized database in `rocket`. The + /// initializer fairing returned by `init()` must have already executed for + /// `Option` to be `Some`. This is guaranteed to be the case if the fairing + /// is attached and either: + /// + /// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the + /// application is running. This is always the case in request guards + /// and liftoff fairings, + /// * _or_ Rocket is in the [`Build`](rocket::Build) or + /// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has + /// already been run. This is the case in all fairing callbacks + /// corresponding to fairings attached _after_ the `Initializer` + /// fairing. + /// + /// # Example + /// + /// Run database migrations in an ignite fairing. It is imperative that the + /// migration fairing be registered _after_ the `init()` fairing. + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use rocket::{Rocket, Build}; + /// use rocket::fairing::{self, AdHoc}; + /// + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// async fn run_migrations(rocket: Rocket) -> fairing::Result { + /// if let Some(db) = Db::fetch(&rocket) { + /// // run migrations using `db`. get the inner type with &db.0. + /// Ok(rocket) + /// } else { + /// Err(rocket) + /// } + /// } + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build() + /// .attach(Db::init()) + /// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations)) + /// } + /// # } + /// ``` + fn fetch(rocket: &Rocket

) -> Option<&Self> { + if let Some(db) = rocket.state() { + return Some(db); + } + + let dbtype = std::any::type_name::(); + let fairing = Paint::default(format!("{}::init()", dbtype)).bold(); + error!("Attempted to fetch unattached database `{}`.", Paint::default(dbtype).bold()); + info_!("`{}` fairing must be attached prior to using this database.", fairing); + None + } +} + +/// A [`Fairing`] which initializes a [`Database`] and its connection pool. +/// +/// A value of this type can be created for any type `D` that implements +/// [`Database`] via the [`Database::init()`] method on the type. Normally, a +/// value of this type _never_ needs to be constructed directly. This +/// documentation exists purely as a reference. +/// +/// This fairing initializes a database pool. Specifically, it: +/// +/// 1. Reads the configuration at `database.db_name`, where `db_name` is +/// [`Database::NAME`]. +/// +/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment. +/// +/// 3. Calls [`Pool::init()`]. +/// +/// 4. Stores the database instance in managed storage, retrievable via +/// [`Database::fetch()`]. +/// +/// The name of the fairing itself is `Initializer`, with `D` replaced with +/// the type name `D` unless a name is explicitly provided via +/// [`Self::with_name()`]. +pub struct Initializer(Option<&'static str>, PhantomData D>); + +/// A request guard which retrieves a single connection to a [`Database`]. +/// +/// For a database type of `Db`, a request guard of `Connection` retrieves a +/// single connection to `Db`. +/// +/// The request guard succeeds if the database was initialized by the +/// [`Initializer`] fairing and a connection is available within +/// [`connect_timeout`](crate::Config::connect_timeout) seconds. +/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with +/// status `InternalServerError`. A [`Sentinel`] guards this condition, and so +/// this type of failure is unlikely to occur. A `None` error is returned. +/// * If a connection is not available within `connect_timeout` seconds or +/// another error occurs, the gaurd _fails_ with status `ServiceUnavailable` +/// and the error is returned in `Some`. +/// +pub struct Connection<'a, D: Database>(&'a ::Connection); + +impl Initializer { + /// Returns a database initializer fairing for `D`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn new() -> Self { + Self(None, std::marker::PhantomData) + } + + /// Returns a database initializer fairing for `D` with name `name`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn with_name(name: &'static str) -> Self { + Self(Some(name), std::marker::PhantomData) + } +} + +impl<'a, D: Database> Connection<'a, D> { + /// Returns the internal connection value. See the [`Connection` Deref + /// column](crate#supported-drivers) for the expected type of this value. + pub fn into_inner(self) -> &'a ::Connection { + self.0 + } +} + +#[rocket::async_trait] +impl Fairing for Initializer { + fn info(&self) -> Info { + Info { + name: self.0.unwrap_or_else(std::any::type_name::), + kind: Kind::Ignite, + } + } + + async fn on_ignite(&self, rocket: Rocket) -> fairing::Result { + let workers: usize = rocket.figment() + .extract_inner(rocket::Config::WORKERS) + .unwrap_or_else(|_| rocket::Config::default().workers); + + let figment = rocket.figment() + .focus(&format!("databases.{}", D::NAME)) + .merge(Serialized::default("max_connections", workers * 4)) + .merge(Serialized::default("connect_timeout", 5)); + + match ::init(&figment).await { + Ok(pool) => Ok(rocket.manage(D::from(pool))), + Err(e) => { + error!("failed to initialize database: {}", e); + Err(rocket) + } + } + } +} + +#[rocket::async_trait] +impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> { + type Error = Option<::Error>; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + match D::fetch(req.rocket()) { + Some(pool) => Outcome::Success(Connection(pool.borrow())), + None => Outcome::Failure((Status::InternalServerError, None)), + } + } +} + +impl Sentinel for Connection<'_, D> { + fn abort(rocket: &Rocket) -> bool { + D::fetch(rocket).is_none() + } +} diff --git a/sea-orm-rocket/lib/src/error.rs b/sea-orm-rocket/lib/src/error.rs new file mode 100644 index 00000000..69bae106 --- /dev/null +++ b/sea-orm-rocket/lib/src/error.rs @@ -0,0 +1,35 @@ +use std::fmt; + +/// A general error type for use by [`Pool`](crate::Pool#implementing) +/// implementors and returned by the [`Connection`](crate::Connection) request +/// guard. +#[derive(Debug)] +pub enum Error { + /// An error that occured during database/pool initialization. + Init(A), + + /// An error that ocurred while retrieving a connection from the pool. + Get(B), + + /// A [`Figment`](crate::figment::Figment) configuration error. + Config(crate::figment::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Init(e) => write!(f, "failed to initialize database: {}", e), + Error::Get(e) => write!(f, "failed to get db connection: {}", e), + Error::Config(e) => write!(f, "bad configuration: {}", e), + } + } +} + +impl std::error::Error for Error + where A: fmt::Debug + fmt::Display, B: fmt::Debug + fmt::Display {} + +impl From for Error { + fn from(e: crate::figment::Error) -> Self { + Self::Config(e) + } +} diff --git a/sea-orm-rocket/lib/src/lib.rs b/sea-orm-rocket/lib/src/lib.rs new file mode 100644 index 00000000..4b98cd0b --- /dev/null +++ b/sea-orm-rocket/lib/src/lib.rs @@ -0,0 +1,20 @@ +//! SeaORM Rocket support crate. +#![deny(missing_docs)] + +/// Re-export of the `figment` crate. +#[doc(inline)] +pub use rocket::figment; + +pub use rocket; + +mod database; +mod error; +mod pool; +mod config; + +pub use self::database::{Connection, Database, Initializer}; +pub use self::error::Error; +pub use self::pool::{Pool, MockPool}; +pub use self::config::Config; + +pub use sea_orm_rocket_codegen::*; diff --git a/sea-orm-rocket/lib/src/pool.rs b/sea-orm-rocket/lib/src/pool.rs new file mode 100644 index 00000000..bdd8e638 --- /dev/null +++ b/sea-orm-rocket/lib/src/pool.rs @@ -0,0 +1,70 @@ +use rocket::figment::Figment; + +/// Generic [`Database`](crate::Database) driver connection pool trait. +/// +/// This trait provides a generic interface to various database pooling +/// implementations in the Rust ecosystem. It can be implemented by anyone. +/// +/// This is adapted from the original `rocket_db_pools`. But on top we require +/// `Connection` itself to be `Sync`. Hence, instead of cloning or allocating +/// a new connection per request, here we only borrow a reference to the pool. +/// +/// In SeaORM, only *when* you are about to execute a SQL statement will a +/// connection be acquired from the pool, and returned as soon as the query finishes. +/// This helps a bit with concurrency if the lifecycle of a request is long enough. +/// ``` +#[rocket::async_trait] +pub trait Pool: Sized + Send + Sync + 'static { + /// The connection type managed by this pool. + type Connection; + + /// The error type returned by [`Self::init()`]. + type Error: std::error::Error; + + /// Constructs a pool from a [Value](rocket::figment::value::Value). + /// + /// It is up to each implementor of `Pool` to define its accepted + /// configuration value(s) via the `Config` associated type. Most + /// integrations provided in `sea_orm_rocket` use [`Config`], which + /// accepts a (required) `url` and an (optional) `pool_size`. + /// + /// ## Errors + /// + /// This method returns an error if the configuration is not compatible, or + /// if creating a pool failed due to an unavailable database server, + /// insufficient resources, or another database-specific error. + async fn init(figment: &Figment) -> Result; + + /// Borrows a reference to the pool + fn borrow(&self) -> &Self::Connection; +} + +#[derive(Debug)] +/// A mock object which impl `Pool`, for testing only +pub struct MockPool; + +#[derive(Debug)] +pub struct MockPoolErr; + +impl std::fmt::Display for MockPoolErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for MockPoolErr {} + +#[rocket::async_trait] +impl Pool for MockPool { + type Error = MockPoolErr; + + type Connection = bool; + + async fn init(_figment: &Figment) -> Result { + Ok(MockPool) + } + + fn borrow(&self) -> &Self::Connection { + &true + } +} diff --git a/src/database/connection.rs b/src/database/connection.rs index e5ec4e2c..d90c72a9 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,168 +1,40 @@ -use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder}; -use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use crate::{ + DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError, +}; +use futures::Stream; +use std::{future::Future, pin::Pin}; -#[cfg_attr(not(feature = "mock"), derive(Clone))] -pub enum DatabaseConnection { - #[cfg(feature = "sqlx-mysql")] - SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), - #[cfg(feature = "sqlx-postgres")] - SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), - #[cfg(feature = "sqlx-sqlite")] - SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), - #[cfg(feature = "mock")] - MockDatabaseConnection(crate::MockDatabaseConnection), - Disconnected, -} +#[async_trait::async_trait] +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; -pub type DbConn = DatabaseConnection; + fn get_database_backend(&self) -> DbBackend; -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum DatabaseBackend { - MySql, - Postgres, - Sqlite, -} + async fn execute(&self, stmt: Statement) -> Result; -pub type DbBackend = DatabaseBackend; + async fn query_one(&self, stmt: Statement) -> Result, DbErr>; -impl Default for DatabaseConnection { - fn default() -> Self { - Self::Disconnected - } -} + async fn query_all(&self, stmt: Statement) -> Result, DbErr>; -impl std::fmt::Debug for DatabaseConnection { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{}", - match self { - #[cfg(feature = "sqlx-mysql")] - Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", - #[cfg(feature = "sqlx-postgres")] - Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", - #[cfg(feature = "sqlx-sqlite")] - Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", - #[cfg(feature = "mock")] - Self::MockDatabaseConnection(_) => "MockDatabaseConnection", - Self::Disconnected => "Disconnected", - } - ) - } -} + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>>; -impl DatabaseConnection { - pub fn get_database_backend(&self) -> DbBackend { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), - DatabaseConnection::Disconnected => panic!("Disconnected"), - } - } + async fn begin(&self) -> Result; - pub async fn execute(&self, stmt: Statement) -> Result { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } - - pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } - - pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, - DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), - } - } - - #[cfg(feature = "mock")] - pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { - match self { - DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, - _ => panic!("not mock connection"), - } - } - - #[cfg(not(feature = "mock"))] - pub fn as_mock_connection(&self) -> Option { - None - } - - #[cfg(feature = "mock")] - pub fn into_transaction_log(self) -> Vec { - let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); - mocker.drain_transaction_log() - } -} - -impl DbBackend { - pub fn is_prefix_of(self, base_url: &str) -> bool { - match self { - Self::Postgres => { - base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") - } - Self::MySql => base_url.starts_with("mysql://"), - Self::Sqlite => base_url.starts_with("sqlite:"), - } - } - - pub fn build(&self, statement: &S) -> Statement + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, callback: F) -> Result> where - S: StatementBuilder, - { - statement.build(self) - } + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, + T: Send, + E: std::error::Error + Send; - pub fn get_query_builder(&self) -> Box { - match self { - Self::MySql => Box::new(MysqlQueryBuilder), - Self::Postgres => Box::new(PostgresQueryBuilder), - Self::Sqlite => Box::new(SqliteQueryBuilder), - } - } -} - -#[cfg(test)] -mod tests { - use crate::DatabaseConnection; - - #[test] - fn assert_database_connection_traits() { - fn assert_send_sync() {} - - assert_send_sync::(); + fn is_mock_connection(&self) -> bool { + false } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs new file mode 100644 index 00000000..e8f37546 --- /dev/null +++ b/src/database/db_connection.rs @@ -0,0 +1,261 @@ +use crate::{ + error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement, + StatementBuilder, TransactionError, +}; +use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; +use std::{future::Future, pin::Pin}; + +#[cfg(feature = "sqlx-dep")] +use sqlx::pool::PoolConnection; + +#[cfg(feature = "mock")] +use std::sync::Arc; + +#[cfg_attr(not(feature = "mock"), derive(Clone))] +pub enum DatabaseConnection { + #[cfg(feature = "sqlx-mysql")] + SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), + #[cfg(feature = "sqlx-sqlite")] + SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), + #[cfg(feature = "mock")] + MockDatabaseConnection(Arc), + Disconnected, +} + +pub type DbConn = DatabaseConnection; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum DatabaseBackend { + MySql, + Postgres, + Sqlite, +} + +pub type DbBackend = DatabaseBackend; + +pub(crate) enum InnerConnection { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), + #[cfg(feature = "mock")] + Mock(std::sync::Arc), +} + +impl Default for DatabaseConnection { + fn default() -> Self { + Self::Disconnected + } +} + +impl std::fmt::Debug for DatabaseConnection { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "sqlx-mysql")] + Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", + #[cfg(feature = "sqlx-postgres")] + Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection", + #[cfg(feature = "sqlx-sqlite")] + Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", + #[cfg(feature = "mock")] + Self::MockDatabaseConnection(_) => "MockDatabaseConnection", + Self::Disconnected => "Disconnected", + } + ) + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; + + fn get_database_backend(&self) -> DbBackend { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(), + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + async fn execute(&self, stmt: Statement) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } + + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), + DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), + } + } + + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + crate::QueryStream::from((Arc::clone(conn), stmt)) + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } + + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + DatabaseTransaction::new_mock(Arc::clone(conn)).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> + where + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, + T: Send, + E: std::error::Error + Send, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + conn.transaction(_callback).await + } + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) + .await + .map_err(TransactionError::Connection)?; + transaction.run(_callback).await + } + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + #[cfg(feature = "mock")] + fn is_mock_connection(&self) -> bool { + matches!(self, DatabaseConnection::MockDatabaseConnection(_)) + } +} + +#[cfg(feature = "mock")] +impl DatabaseConnection { + pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { + match self { + DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, + _ => panic!("not mock connection"), + } + } + + pub fn into_transaction_log(self) -> Vec { + let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); + mocker.drain_transaction_log() + } +} + +impl DbBackend { + pub fn is_prefix_of(self, base_url: &str) -> bool { + match self { + Self::Postgres => { + base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") + } + Self::MySql => base_url.starts_with("mysql://"), + Self::Sqlite => base_url.starts_with("sqlite:"), + } + } + + pub fn build(&self, statement: &S) -> Statement + where + S: StatementBuilder, + { + statement.build(self) + } + + pub fn get_query_builder(&self) -> Box { + match self { + Self::MySql => Box::new(MysqlQueryBuilder), + Self::Postgres => Box::new(PostgresQueryBuilder), + Self::Sqlite => Box::new(SqliteQueryBuilder), + } + } +} + +#[cfg(test)] +mod tests { + use crate::DatabaseConnection; + + #[test] + fn assert_database_connection_traits() { + fn assert_send_sync() {} + + assert_send_sync::(); + } +} diff --git a/src/database/mock.rs b/src/database/mock.rs index ccb34a49..f9703d9e 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -1,14 +1,15 @@ use crate::{ error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden, Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow, - Statement, Transaction, + Statement, }; -use sea_query::{Value, ValueType}; -use std::collections::BTreeMap; +use sea_query::{Value, ValueType, Values}; +use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { db_backend: DbBackend, + transaction: Option, transaction_log: Vec, exec_results: Vec, query_results: Vec>, @@ -29,10 +30,22 @@ pub trait IntoMockRow { fn into_mock_row(self) -> MockRow; } +#[derive(Debug)] +pub struct OpenTransaction { + stmts: Vec, + transaction_depth: usize, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Transaction { + stmts: Vec, +} + impl MockDatabase { pub fn new(db_backend: DbBackend) -> Self { Self { db_backend, + transaction: None, transaction_log: Vec::new(), exec_results: Vec::new(), query_results: Vec::new(), @@ -40,7 +53,7 @@ impl MockDatabase { } pub fn into_connection(self) -> DatabaseConnection { - DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) + DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self))) } pub fn append_exec_results(mut self, mut vec: Vec) -> Self { @@ -62,7 +75,11 @@ impl MockDatabase { impl MockDatabaseTrait for MockDatabase { fn execute(&mut self, counter: usize, statement: Statement) -> Result { - self.transaction_log.push(Transaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(Transaction::one(statement)); + } if counter < self.exec_results.len() { Ok(ExecResult { result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), @@ -73,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase { } fn query(&mut self, counter: usize, statement: Statement) -> Result, DbErr> { - self.transaction_log.push(Transaction::one(statement)); + if let Some(transaction) = &mut self.transaction { + transaction.push(statement); + } else { + self.transaction_log.push(Transaction::one(statement)); + } if counter < self.query_results.len() { Ok(std::mem::take(&mut self.query_results[counter]) .into_iter() @@ -86,6 +107,39 @@ impl MockDatabaseTrait for MockDatabase { } } + fn begin(&mut self) { + if self.transaction.is_some() { + self.transaction + .as_mut() + .unwrap() + .begin_nested(self.db_backend); + } else { + self.transaction = Some(OpenTransaction::init()); + } + } + + fn commit(&mut self) { + if self.transaction.is_some() { + if self.transaction.as_mut().unwrap().commit(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } + } else { + panic!("There is no open transaction to commit"); + } + } + + fn rollback(&mut self) { + if self.transaction.is_some() { + if self.transaction.as_mut().unwrap().rollback(self.db_backend) { + let transaction = self.transaction.take().unwrap(); + self.transaction_log.push(transaction.into_transaction()); + } + } else { + panic!("There is no open transaction to rollback"); + } + } + fn drain_transaction_log(&mut self) -> Vec { std::mem::take(&mut self.transaction_log) } @@ -100,7 +154,7 @@ impl MockRow { where T: ValueType, { - Ok(self.values.get(col).unwrap().clone().unwrap()) + T::try_from(self.values.get(col).unwrap().clone()).map_err(|e| DbErr::Query(e.to_string())) } pub fn into_column_value_tuples(self) -> impl Iterator { @@ -134,3 +188,372 @@ impl IntoMockRow for BTreeMap<&str, Value> { } } } + +impl Transaction { + pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self + where + I: IntoIterator, + { + Self::one(Statement::from_string_values_tuple( + db_backend, + (sql.to_string(), Values(values.into_iter().collect())), + )) + } + + /// Create a Transaction with one statement + pub fn one(stmt: Statement) -> Self { + Self { stmts: vec![stmt] } + } + + /// Create a Transaction with many statements + pub fn many(stmts: I) -> Self + where + I: IntoIterator, + { + Self { + stmts: stmts.into_iter().collect(), + } + } + + /// Wrap each Statement as a single-statement Transaction + pub fn wrap(stmts: I) -> Vec + where + I: IntoIterator, + { + stmts.into_iter().map(Self::one).collect() + } +} + +impl OpenTransaction { + fn init() -> Self { + Self { + stmts: Vec::new(), + transaction_depth: 0, + } + } + + fn begin_nested(&mut self, db_backend: DbBackend) { + self.transaction_depth += 1; + self.push(Statement::from_string( + db_backend, + format!("SAVEPOINT savepoint_{}", self.transaction_depth), + )); + } + + fn commit(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "COMMIT".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("RELEASE SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + + fn rollback(&mut self, db_backend: DbBackend) -> bool { + if self.transaction_depth == 0 { + self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned())); + true + } else { + self.push(Statement::from_string( + db_backend, + format!("ROLLBACK TO SAVEPOINT savepoint_{}", self.transaction_depth), + )); + self.transaction_depth -= 1; + false + } + } + + fn push(&mut self, stmt: Statement) { + self.stmts.push(stmt); + } + + fn into_transaction(self) -> Transaction { + if self.transaction_depth != 0 { + panic!("There is uncommitted nested transaction."); + } + Transaction { stmts: self.stmts } + } +} + +#[cfg(test)] +#[cfg(feature = "mock")] +mod tests { + use crate::{ + entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, Statement, + Transaction, TransactionError, + }; + use pretty_assertions::assert_eq; + + #[derive(Debug, PartialEq)] + pub struct MyErr(String); + + impl std::error::Error for MyErr {} + + impl std::fmt::Display for MyErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0.as_str()) + } + } + + #[smol_potat::test] + async fn test_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _1 = cake::Entity::find().one(txn).await; + let _2 = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + let _ = cake::Entity::find().all(&db).await; + + assert_eq!( + db.into_transaction_log(), + vec![ + Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + ] + ); + } + + #[smol_potat::test] + async fn test_transaction_2() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + let result = db + .transaction::<_, (), MyErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + Err(MyErr("test".to_owned())) + }) + }) + .await; + + match result { + Err(TransactionError::Transaction(err)) => { + assert_eq!(err, MyErr("test".to_owned())) + } + _ => panic!(), + } + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "ROLLBACK".to_owned()), + ]),] + ); + } + + #[smol_potat::test] + async fn test_nested_transaction_1() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + + #[smol_potat::test] + async fn test_nested_transaction_2() { + let db = MockDatabase::new(DbBackend::Postgres).into_connection(); + + db.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().one(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = fruit::Entity::find().all(txn).await; + + txn.transaction::<_, (), DbErr>(|txn| { + Box::pin(async move { + let _ = cake::Entity::find().all(txn).await; + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + assert_eq!( + db.into_transaction_log(), + vec![Transaction::many(vec![ + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#, + vec![1u64.into()] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![] + ), + Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()), + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, + vec![] + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_2".to_owned() + ), + Statement::from_string( + DbBackend::Postgres, + "RELEASE SAVEPOINT savepoint_1".to_owned() + ), + Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()), + ]),] + ); + } + + #[smol_potat::test] + async fn test_stream_1() -> Result<(), DbErr> { + use futures::TryStreamExt; + + let apple = fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }; + + let orange = fruit::Model { + id: 2, + name: "orange".to_owned(), + cake_id: None, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![apple.clone(), orange.clone()]]) + .into_connection(); + + let mut stream = fruit::Entity::find().stream(&db).await?; + + assert_eq!(stream.try_next().await?, Some(apple)); + + assert_eq!(stream.try_next().await?, Some(orange)); + + assert_eq!(stream.try_next().await?, None); + + Ok(()) + } + + #[smol_potat::test] + async fn test_stream_in_transaction() -> Result<(), DbErr> { + use futures::TryStreamExt; + + let apple = fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }; + + let orange = fruit::Model { + id: 2, + name: "orange".to_owned(), + cake_id: None, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![apple.clone(), orange.clone()]]) + .into_connection(); + + let txn = db.begin().await?; + + if let Ok(mut stream) = fruit::Entity::find().stream(&txn).await { + assert_eq!(stream.try_next().await?, Some(apple)); + + assert_eq!(stream.try_next().await?, Some(orange)); + + assert_eq!(stream.try_next().await?, None); + + // stream will be dropped end of scope OR std::mem::drop(stream); + } + + txn.commit().await?; + + Ok(()) + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index f61343c1..a1dfea93 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,13 +1,17 @@ mod connection; +mod db_connection; #[cfg(feature = "mock")] mod mock; mod statement; +mod stream; mod transaction; pub use connection::*; +pub use db_connection::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; +pub use stream::*; pub use transaction::*; use crate::DbErr; diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs new file mode 100644 index 00000000..774cf45f --- /dev/null +++ b/src/database/stream/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod transaction; + +pub use query::*; +pub use transaction::*; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs new file mode 100644 index 00000000..8383659a --- /dev/null +++ b/src/database/stream/query.rs @@ -0,0 +1,109 @@ +use std::{pin::Pin, task::Poll}; + +#[cfg(feature = "mock")] +use std::sync::Arc; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, Executor}; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct QueryStream { + stmt: Statement, + conn: InnerConnection, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +#[cfg(feature = "sqlx-mysql")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::MySql(conn)) + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Postgres(conn)) + } +} + +#[cfg(feature = "sqlx-sqlite")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Sqlite(conn)) + } +} + +#[cfg(feature = "mock")] +impl From<(Arc, Statement)> for QueryStream { + fn from((conn, stmt): (Arc, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn)) + } +} + +impl std::fmt::Debug for QueryStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QueryStream") + } +} + +impl QueryStream { + fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { + QueryStreamBuilder { + stmt, + conn, + stream_builder: |conn, stmt| match conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => c.fetch(stmt), + }, + } + .build() + } +} + +impl Stream for QueryStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) + } +} diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs new file mode 100644 index 00000000..2dddc59c --- /dev/null +++ b/src/database/stream/transaction.rs @@ -0,0 +1,93 @@ +use std::{ops::DerefMut, pin::Pin, task::Poll}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::Executor; + +use futures::lock::MutexGuard; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +/// `TransactionStream` cannot be used in a `transaction` closure as it does not impl `Send`. +/// It seems to be a Rust limitation right now, and solution to work around this deemed to be extremely hard. +pub struct TransactionStream<'a> { + stmt: Statement, + conn: MutexGuard<'a, InnerConnection>, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +impl<'a> std::fmt::Debug for TransactionStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TransactionStream") + } +} + +impl<'a> TransactionStream<'a> { + pub(crate) async fn build( + conn: MutexGuard<'a, InnerConnection>, + stmt: Statement, + ) -> TransactionStream<'a> { + TransactionStreamAsyncBuilder { + stmt, + conn, + stream_builder: |conn, stmt| { + Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin>>> + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => c.fetch(stmt), + } + }) + }, + } + .build() + .await + } +} + +impl<'a> Stream for TransactionStream<'a> { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) + } +} diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 6bf06491..d7bbc058 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -1,42 +1,367 @@ -use crate::{DbBackend, Statement}; -use sea_query::{Value, Values}; +use crate::{ + debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, + Statement, TransactionStream, +}; +#[cfg(feature = "sqlx-dep")] +use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; +use futures::lock::Mutex; +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, TransactionManager}; +use std::{future::Future, pin::Pin, sync::Arc}; -#[derive(Debug, Clone, PartialEq)] -pub struct Transaction { - stmts: Vec, +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, } -impl Transaction { - pub fn from_sql_and_values(db_backend: DbBackend, sql: &str, values: I) -> Self - where - I: IntoIterator, - { - Self::one(Statement::from_string_values_tuple( - db_backend, - (sql.to_string(), Values(values.into_iter().collect())), - )) +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { + #[cfg(feature = "sqlx-mysql")] + pub(crate) async fn new_mysql( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::MySql(inner))), + DbBackend::MySql, + ) + .await } - /// Create a Transaction with one statement - pub fn one(stmt: Statement) -> Self { - Self { stmts: vec![stmt] } + #[cfg(feature = "sqlx-postgres")] + pub(crate) async fn new_postgres( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Postgres(inner))), + DbBackend::Postgres, + ) + .await } - /// Create a Transaction with many statements - pub fn many(stmts: I) -> Self + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite( + inner: PoolConnection, + ) -> Result { + Self::begin( + Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), + DbBackend::Sqlite, + ) + .await + } + + #[cfg(feature = "mock")] + pub(crate) async fn new_mock( + inner: Arc, + ) -> Result { + let backend = inner.get_database_backend(); + Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + } + + async fn begin( + conn: Arc>, + backend: DbBackend, + ) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "mock")] + InnerConnection::Mock(ref mut c) => { + c.begin(); + } + } + Ok(res) + } + + pub(crate) async fn run(self, callback: F) -> Result> where - I: IntoIterator, + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, { - Self { - stmts: stmts.into_iter().collect(), + let res = callback(&self).await.map_err(TransactionError::Transaction); + if res.is_ok() { + self.commit().await.map_err(TransactionError::Connection)?; + } else { + self.rollback() + .await + .map_err(TransactionError::Connection)?; + } + res + } + + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "mock")] + InnerConnection::Mock(ref mut c) => { + c.commit(); + } + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c) + .await + .map_err(sqlx_error_to_query_err)? + } + #[cfg(feature = "mock")] + InnerConnection::Mock(ref mut c) => { + c.rollback(); + } + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.rollback(); + } + } + } else { + //this should never happen + panic!("Dropping a locked Transaction"); + } + } + } +} + +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + + fn get_database_backend(&self) -> DbBackend { + // this way we don't need to lock + self.backend + } + + async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.execute(conn).await.map(Into::into) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_exec_err) + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await.map(|row| Some(row.into())) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + }; + #[cfg(feature = "sqlx-dep")] + if let Err(sqlx::Error::RowNotFound) = _res { + Ok(None) + } else { + _res.map_err(sqlx_error_to_query_err) } } - /// Wrap each Statement as a single-statement Transaction - pub fn wrap(stmts: I) -> Vec + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let _res = match &mut *self.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(conn) => { + let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(conn) => { + let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(conn) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query + .fetch_all(conn) + .await + .map(|rows| rows.into_iter().map(|r| r.into()).collect()) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + }; + #[cfg(feature = "sqlx-dep")] + _res.map_err(sqlx_error_to_query_err) + } + + fn stream( + &'a self, + stmt: Statement, + ) -> Pin> + 'a>> { + Box::pin( + async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, + ) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await + } + + /// Execute the function inside a transaction. + /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + async fn transaction(&self, _callback: F) -> Result> where - I: IntoIterator, + F: for<'c> FnOnce( + &'c DatabaseTransaction, + ) -> Pin> + Send + 'c>> + + Send, + T: Send, + E: std::error::Error + Send, { - stmts.into_iter().map(Self::one).collect() + let transaction = self.begin().await.map_err(TransactionError::Connection)?; + transaction.run(_callback).await } } + +#[derive(Debug)] +pub enum TransactionError +where + E: std::error::Error, +{ + Connection(DbErr), + Transaction(E), +} + +impl std::fmt::Display for TransactionError +where + E: std::error::Error, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransactionError::Connection(e) => std::fmt::Display::fmt(e, f), + TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for TransactionError where E: std::error::Error {} diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 0e398586..ad0c35cf 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,10 +2,14 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::fmt::Debug; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Mutex, +use futures::Stream; +use std::{ + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, }; #[derive(Debug)] @@ -22,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug { fn query(&mut self, counter: usize, stmt: Statement) -> Result, DbErr>; + fn begin(&mut self); + + fn commit(&mut self); + + fn rollback(&mut self); + fn drain_transaction_log(&mut self) -> Vec; fn get_database_backend(&self) -> DbBackend; @@ -49,9 +59,9 @@ impl MockDatabaseConnector { pub async fn connect(string: &str) -> Result { macro_rules! connect_mock_db { ( $syntax: expr ) => { - Ok(DatabaseConnection::MockDatabaseConnection( + Ok(DatabaseConnection::MockDatabaseConnection(Arc::new( MockDatabaseConnection::new(MockDatabase::new($syntax)), - )) + ))) }; } @@ -82,30 +92,52 @@ impl MockDatabaseConnection { } } - pub fn get_mocker_mutex(&self) -> &Mutex> { + pub(crate) fn get_mocker_mutex(&self) -> &Mutex> { &self.mocker } - pub async fn execute(&self, statement: Statement) -> Result { + pub fn get_database_backend(&self) -> DbBackend { + self.mocker.lock().unwrap().get_database_backend() + } + + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().execute(counter, statement) } - pub async fn query_one(&self, statement: Statement) -> Result, DbErr> { + pub fn query_one(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); let result = self.mocker.lock().unwrap().query(counter, statement)?; Ok(result.into_iter().next()) } - pub async fn query_all(&self, statement: Statement) -> Result, DbErr> { + pub fn query_all(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().query(counter, statement) } - pub fn get_database_backend(&self) -> DbBackend { - self.mocker.lock().unwrap().get_database_backend() + pub fn fetch( + &self, + statement: &Statement, + ) -> Pin>>> { + match self.query_all(statement.clone()) { + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))), + Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), + } + } + + pub fn begin(&self) { + self.mocker.lock().unwrap().begin() + } + + pub fn commit(&self) { + self.mocker.lock().unwrap().commit() + } + + pub fn rollback(&self) { + self.mocker.lock().unwrap().rollback() } } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 6f6cfb64..33b6c847 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -3,11 +3,11 @@ mod mock; #[cfg(feature = "sqlx-dep")] mod sqlx_common; #[cfg(feature = "sqlx-mysql")] -mod sqlx_mysql; +pub(crate) mod sqlx_mysql; #[cfg(feature = "sqlx-postgres")] -mod sqlx_postgres; +pub(crate) mod sqlx_postgres; #[cfg(feature = "sqlx-sqlite")] -mod sqlx_sqlite; +pub(crate) mod sqlx_sqlite; #[cfg(feature = "mock")] pub use mock::*; diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index c542a9b4..6b6f9507 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,3 +1,5 @@ +use std::{future::Future, pin::Pin}; + use sqlx::{ mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, MySql, MySqlPool, @@ -6,7 +8,10 @@ use sqlx::{ sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -20,7 +25,7 @@ pub struct SqlxMySqlPoolConnection { impl SqlxMySqlConnector { pub fn accepts(string: &str) -> bool { - DbBackend::MySql.is_prefix_of(string) + string.starts_with("mysql://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +96,49 @@ impl SqlxMySqlPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_mysql(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_mysql(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +157,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index fb5402eb..13cb51cd 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,3 +1,5 @@ +use std::{future::Future, pin::Pin}; + use sqlx::{ postgres::{PgArguments, PgQueryResult, PgRow}, PgPool, Postgres, @@ -6,7 +8,10 @@ use sqlx::{ sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -20,7 +25,7 @@ pub struct SqlxPostgresPoolConnection { impl SqlxPostgresConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Postgres.is_prefix_of(string) + string.starts_with("postgres://") } pub async fn connect(string: &str) -> Result { @@ -91,6 +96,49 @@ impl SqlxPostgresPoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_postgres(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_postgres(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +157,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b02f4408..ea2e05e3 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,12 +1,17 @@ +use std::{future::Future, pin::Pin}; + use sqlx::{ - sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, + sqlite::{SqliteArguments, SqlitePoolOptions, SqliteQueryResult, SqliteRow}, Sqlite, SqlitePool, }; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; +use crate::{ + debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, + Statement, TransactionError, +}; use super::sqlx_common::*; @@ -20,11 +25,15 @@ pub struct SqlxSqlitePoolConnection { impl SqlxSqliteConnector { pub fn accepts(string: &str) -> bool { - DbBackend::Sqlite.is_prefix_of(string) + string.starts_with("sqlite:") } pub async fn connect(string: &str) -> Result { - if let Ok(pool) = SqlitePool::connect(string).await { + if let Ok(pool) = SqlitePoolOptions::new() + .max_connections(1) + .connect(string) + .await + { Ok(DatabaseConnection::SqlxSqlitePoolConnection( SqlxSqlitePoolConnection { pool }, )) @@ -91,6 +100,49 @@ impl SqlxSqlitePoolConnection { )) } } + + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_sqlite(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction(&self, callback: F) -> Result> + where + F: for<'b> FnOnce( + &'b DatabaseTransaction, + ) -> Pin> + Send + 'b>> + + Send, + T: Send, + E: std::error::Error + Send, + { + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_sqlite(conn) + .await + .map_err(|e| TransactionError::Connection(e))?; + transaction.run(callback).await + } else { + Err(TransactionError::Connection(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + ))) + } + } } impl From for QueryResult { @@ -109,7 +161,7 @@ impl From for ExecResult { } } -fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { +pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { let mut query = sqlx::query(&stmt.sql); if let Some(values) = &stmt.values { query = bind_query(query, values); diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 7cdd9071..d6625af7 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1,8 +1,8 @@ use crate::{ - error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, - PrimaryKeyTrait, Value, + error::*, ConnectionTrait, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, Value, }; use async_trait::async_trait; +use sea_query::ValueTuple; use std::fmt::Debug; #[derive(Clone, Debug, Default)] @@ -10,7 +10,8 @@ pub struct ActiveValue where V: Into, { - value: Option, + // Don't want to call ActiveValue::unwrap() and cause panic + pub(self) value: Option, state: ActiveValueState, } @@ -67,33 +68,64 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert(self, db: &DatabaseConnection) -> Result - where - Self: ActiveModelBehavior, - ::Model: IntoActiveModel, - { - let am = ActiveModelBehavior::before_save(self, true, db)?; - let res = ::insert(am).exec(db).await?; - // Assume valid last_insert_id is not equals to Default::default() - if res.last_insert_id - != <::PrimaryKey as PrimaryKeyTrait>::ValueType::default() - { - let found = ::find_by_id(res.last_insert_id) - .one(db) - .await?; - let am = match found { - Some(model) => Ok(model.into_active_model()), - None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), - }?; - ActiveModelBehavior::after_save(am, true, db) - } else { - Ok(Self::default()) + #[allow(clippy::question_mark)] + fn get_primary_key_value(&self) -> Option { + let mut cols = ::PrimaryKey::iter(); + macro_rules! next { + () => { + if let Some(col) = cols.next() { + if let Some(val) = self.get(col.into_column()).value { + val + } else { + return None; + } + } else { + return None; + } + }; + } + match ::PrimaryKey::iter().count() { + 1 => { + let s1 = next!(); + Some(ValueTuple::One(s1)) + } + 2 => { + let s1 = next!(); + let s2 = next!(); + Some(ValueTuple::Two(s1, s2)) + } + 3 => { + let s1 = next!(); + let s2 = next!(); + let s3 = next!(); + Some(ValueTuple::Three(s1, s2, s3)) + } + _ => panic!("The arity cannot be larger than 3"), } } - async fn update(self, db: &DatabaseConnection) -> Result + async fn insert<'a, C>(self, db: &'a C) -> Result where Self: ActiveModelBehavior, + ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, + Self: 'a, + { + let am = ActiveModelBehavior::before_save(self, true, db)?; + let res = ::insert(am).exec(db).await?; + let found = ::find_by_id(res.last_insert_id) + .one(db) + .await?; + match found { + Some(model) => Ok(model.into_active_model()), + None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), + } + } + + async fn update<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + Self: 'a, { let am = ActiveModelBehavior::before_save(self, false, db)?; let am = Self::Entity::update(am).exec(db).await?; @@ -102,10 +134,11 @@ pub trait ActiveModelTrait: Clone + Debug { /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save(self, db: &DatabaseConnection) -> Result + async fn save<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, + C: ConnectionTrait<'a>, { let mut am = self; let mut is_update = true; @@ -125,9 +158,10 @@ pub trait ActiveModelTrait: Clone + Debug { } /// Delete an active model by its primary key - async fn delete(self, db: &DatabaseConnection) -> Result + async fn delete<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, + C: ConnectionTrait<'a>, { let am = ActiveModelBehavior::before_delete(self, db)?; let am_clone = am.clone(); @@ -219,23 +253,23 @@ where matches!(self.state, ActiveValueState::Unset) } - pub fn take(&mut self) -> V { + pub fn take(&mut self) -> Option { self.state = ActiveValueState::Unset; - self.value.take().unwrap() + self.value.take() } pub fn unwrap(self) -> V { self.value.unwrap() } - pub fn into_value(self) -> Value { - self.value.unwrap().into() + pub fn into_value(self) -> Option { + self.value.map(Into::into) } pub fn into_wrapped_value(self) -> ActiveValue { match self.state { - ActiveValueState::Set => ActiveValue::set(self.into_value()), - ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value()), + ActiveValueState::Set => ActiveValue::set(self.into_value().unwrap()), + ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value().unwrap()), ActiveValueState::Unset => ActiveValue::unset(), } } diff --git a/src/entity/base_entity.rs b/src/entity/base_entity.rs index 764f2524..aef46207 100644 --- a/src/entity/base_entity.rs +++ b/src/entity/base_entity.rs @@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName { /// /// ``` /// # #[cfg(feature = "mock")] - /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; + /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; /// # /// # let db = MockDatabase::new(DbBackend::Postgres) /// # .append_exec_results(vec![ diff --git a/src/entity/primary_key.rs b/src/entity/primary_key.rs index 463f1482..a5e4cde0 100644 --- a/src/entity/primary_key.rs +++ b/src/entity/primary_key.rs @@ -1,16 +1,16 @@ use super::{ColumnTrait, IdenStatic, Iterable}; use crate::{TryFromU64, TryGetableMany}; -use sea_query::IntoValueTuple; +use sea_query::{FromValueTuple, IntoValueTuple}; use std::fmt::Debug; //LINT: composite primary key cannot auto increment pub trait PrimaryKeyTrait: IdenStatic + Iterable { type ValueType: Sized + Send - + Default + Debug + PartialEq + IntoValueTuple + + FromValueTuple + TryGetableMany + TryFromU64; diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 807bc544..c4a8c7de 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DeleteMany, DeleteOne, EntityTrait, Statement, + error::*, ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, }; use sea_query::DeleteStatement; use std::future::Future; @@ -18,10 +18,10 @@ impl<'a, A: 'a> DeleteOne where A: ActiveModelTrait, { - pub fn exec( - self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + pub fn exec(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -31,10 +31,10 @@ impl<'a, E> DeleteMany where E: EntityTrait, { - pub fn exec( - self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + pub fn exec(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,24 +45,26 @@ impl Deleter { Self { query } } - pub fn exec( - self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + pub fn exec<'a, C>(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only( - query: DeleteStatement, - db: &DatabaseConnection, -) -> Result { +async fn exec_delete_only<'a, C>(query: DeleteStatement, db: &'a C) -> Result +where + C: ConnectionTrait<'a>, +{ Deleter::new(query).exec(db).await } -// Only Statement impl Send -async fn exec_delete(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_delete<'a, C>(statement: Statement, db: &'a C) -> Result +where + C: ConnectionTrait<'a>, +{ let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index a44867f7..b4da10c0 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,15 +1,16 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, - PrimaryKeyTrait, Statement, TryFromU64, + error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, PrimaryKeyTrait, + Statement, TryFromU64, }; -use sea_query::InsertStatement; +use sea_query::{FromValueTuple, InsertStatement, ValueTuple}; use std::{future::Future, marker::PhantomData}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Inserter where A: ActiveModelTrait, { + primary_key: Option, query: InsertStatement, model: PhantomData, } @@ -27,14 +28,11 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a>( - self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + '_ where + C: ConnectionTrait<'a>, A: 'a, { - // TODO: extract primary key's value from query // so that self is dropped before entering await let mut query = self.query; if db.get_database_backend() == DbBackend::Postgres { @@ -47,8 +45,7 @@ where ); } } - Inserter::::new(query).exec(db) - // TODO: return primary key if extracted before, otherwise use InsertResult + Inserter::::new(self.primary_key, query).exec(db) } } @@ -56,50 +53,55 @@ impl Inserter where A: ActiveModelTrait, { - pub fn new(query: InsertStatement) -> Self { + pub fn new(primary_key: Option, query: InsertStatement) -> Self { Self { + primary_key, query, model: PhantomData, } } - pub fn exec<'a>( - self, - db: &'a DatabaseConnection, - ) -> impl Future, DbErr>> + 'a + pub fn exec<'a, C>(self, db: &'a C) -> impl Future, DbErr>> + '_ where + C: ConnectionTrait<'a>, A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(self.primary_key, builder.build(&self.query), db) } } -// Only Statement impl Send -async fn exec_insert( +async fn exec_insert<'a, A, C>( + primary_key: Option, statement: Statement, - db: &DatabaseConnection, + db: &'a C, ) -> Result, DbErr> where + C: ConnectionTrait<'a>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; - let last_insert_id = match db.get_database_backend() { + let last_insert_id_opt = match db.get_database_backend() { DbBackend::Postgres => { use crate::{sea_query::Iden, Iterable}; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) .collect::>(); let res = db.query_one(statement).await?.unwrap(); - res.try_get_many("", cols.as_ref()).unwrap_or_default() + res.try_get_many("", cols.as_ref()).ok() } _ => { let last_insert_id = db.execute(statement).await?.last_insert_id(); - ValueTypeOf::::try_from_u64(last_insert_id) - .ok() - .unwrap_or_default() + ValueTypeOf::::try_from_u64(last_insert_id).ok() } }; + let last_insert_id = match last_insert_id_opt { + Some(last_insert_id) => last_insert_id, + None => match primary_key { + Some(value_tuple) => FromValueTuple::from_value_tuple(value_tuple), + None => return Err(DbErr::Exec("Fail to unpack last_insert_id".to_owned())), + }, + }; Ok(InsertResult { last_insert_id }) } diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 608d9dc1..28f8574b 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; +use crate::{error::*, ConnectionTrait, DbBackend, SelectorTrait}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -7,21 +7,23 @@ use std::{marker::PhantomData, pin::Pin}; pub type PinBoxStream<'db, Item> = Pin + 'db>>; #[derive(Clone, Debug)] -pub struct Paginator<'db, S> +pub struct Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { pub(crate) query: SelectStatement, pub(crate) page: usize, pub(crate) page_size: usize, - pub(crate) db: &'db DatabaseConnection, + pub(crate) db: &'db C, pub(crate) selector: PhantomData, } // LINT: warn if paginator is used without an order by clause -impl<'db, S> Paginator<'db, S> +impl<'db, C, S> Paginator<'db, C, S> where + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { /// Fetch a specific page; page index starts from zero @@ -155,7 +157,7 @@ where #[cfg(feature = "mock")] mod tests { use crate::entity::prelude::*; - use crate::tests_cfg::*; + use crate::{tests_cfg::*, ConnectionTrait}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/query.rs b/src/executor/query.rs index be00cb90..a164e911 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -126,12 +126,12 @@ macro_rules! try_getable_unsigned { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -142,13 +142,13 @@ macro_rules! try_getable_unsigned { #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), @@ -162,12 +162,12 @@ macro_rules! try_getable_mysql { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -181,7 +181,7 @@ macro_rules! try_getable_mysql { } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), @@ -195,7 +195,7 @@ macro_rules! try_getable_postgres { ( $type: ty ) => { impl TryGetable for $type { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { - let column = format!("{}{}", pre, col); + let _column = format!("{}{}", pre, col); match &res.row { #[cfg(feature = "sqlx-mysql")] QueryResultRow::SqlxMySql(_) => { @@ -204,7 +204,7 @@ macro_rules! try_getable_postgres { #[cfg(feature = "sqlx-postgres")] QueryResultRow::SqlxPostgres(row) => { use sqlx::Row; - row.try_get::, _>(column.as_str()) + row.try_get::, _>(_column.as_str()) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .and_then(|opt| opt.ok_or(TryGetError::Null)) } @@ -214,7 +214,7 @@ macro_rules! try_getable_postgres { } #[cfg(feature = "mock")] #[allow(unused_variables)] - QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { + QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| { debug_print!("{:#?}", e.to_string()); TryGetError::Null }), diff --git a/src/executor/select.rs b/src/executor/select.rs index 0db698f0..a95fe463 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,10 +1,12 @@ use crate::{ - error::*, DatabaseConnection, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, + error::*, ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, TryGetableMany, }; +use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; +use std::pin::Pin; #[derive(Clone, Debug)] pub struct Selector @@ -234,23 +236,45 @@ where Selector::>::with_columns(self.query) } - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().all(db).await } - pub fn paginate( + pub async fn stream<'a: 'b, 'b, C>( self, - db: &DatabaseConnection, + db: &'a C, + ) -> Result> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( + self, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectModel> { + ) -> Paginator<'a, C, SelectModel> + where + C: ConnectionTrait<'a>, + { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + { self.paginate(db, 1).num_items().await } } @@ -279,29 +303,45 @@ where } } - pub async fn one( - self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } - pub async fn all( - self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().all(db).await } - pub fn paginate( + pub async fn stream<'a: 'b, 'b, C>( self, - db: &DatabaseConnection, + db: &'a C, + ) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( + self, + db: &'a C, page_size: usize, - ) -> Paginator<'_, SelectTwoModel> { + ) -> Paginator<'a, C, SelectTwoModel> + where + C: ConnectionTrait<'a>, + { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &DatabaseConnection) -> Result { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + { self.paginate(db, 1).num_items().await } } @@ -330,17 +370,27 @@ where } } - pub async fn one( - self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { self.into_model().one(db).await } - pub async fn all( + pub async fn stream<'a: 'b, 'b, C>( self, - db: &DatabaseConnection, - ) -> Result)>, DbErr> { + db: &'a C, + ) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub async fn all<'a, C>(self, db: &C) -> Result)>, DbErr> + where + C: ConnectionTrait<'a>, + { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -375,7 +425,10 @@ where } } - pub async fn one(mut self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -385,7 +438,10 @@ where } } - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -395,7 +451,25 @@ where Ok(models) } - pub fn paginate(self, db: &DatabaseConnection, page_size: usize) -> Paginator<'_, S> { + pub async fn stream<'a: 'b, 'b, C>( + self, + db: &'a C, + ) -> Result> + 'b>>, DbErr> + where + C: ConnectionTrait<'a>, + S: 'b, + { + let builder = db.get_database_backend(); + let stream = db.stream(builder.build(&self.query)).await?; + Ok(Box::pin(stream.and_then(|row| { + futures::future::ready(S::from_raw_query_result(row)) + }))) + } + + pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> + where + C: ConnectionTrait<'a>, + { Paginator { query: self.query, page: 0, @@ -605,7 +679,10 @@ where /// ),] /// ); /// ``` - pub async fn one(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -644,7 +721,10 @@ where /// ),] /// ); /// ``` - pub async fn all(self, db: &DatabaseConnection) -> Result, DbErr> { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where + C: ConnectionTrait<'a>, + { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index 6c7a9873..c228730b 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Statement, UpdateMany, UpdateOne, + error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, }; use sea_query::UpdateStatement; use std::future::Future; @@ -7,9 +7,10 @@ use std::future::Future; #[derive(Clone, Debug)] pub struct Updater { query: UpdateStatement, + check_record_exists: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct UpdateResult { pub rows_affected: u64, } @@ -18,9 +19,12 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a DatabaseConnection) -> impl Future> + 'a { + pub async fn exec<'b, C>(self, db: &'b C) -> Result + where + C: ConnectionTrait<'b>, + { // so that self is dropped before entering await - exec_update_and_return_original(self.query, self.model, db) + exec_update_and_return_original(self.query, self.model, db).await } } @@ -28,10 +32,10 @@ impl<'a, E> UpdateMany where E: EntityTrait, { - pub fn exec( - self, - db: &'a DatabaseConnection, - ) -> impl Future> + 'a { + pub fn exec(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -39,41 +43,198 @@ where impl Updater { pub fn new(query: UpdateStatement) -> Self { - Self { query } + Self { + query, + check_record_exists: false, + } } - pub fn exec( - self, - db: &DatabaseConnection, - ) -> impl Future> + '_ { + pub fn check_record_exists(mut self) -> Self { + self.check_record_exists = true; + self + } + + pub fn exec<'a, C>(self, db: &'a C) -> impl Future> + '_ + where + C: ConnectionTrait<'a>, + { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db, self.check_record_exists) } } -async fn exec_update_only( - query: UpdateStatement, - db: &DatabaseConnection, -) -> Result { +async fn exec_update_only<'a, C>(query: UpdateStatement, db: &'a C) -> Result +where + C: ConnectionTrait<'a>, +{ Updater::new(query).exec(db).await } -async fn exec_update_and_return_original( +async fn exec_update_and_return_original<'a, A, C>( query: UpdateStatement, model: A, - db: &DatabaseConnection, + db: &'a C, ) -> Result where A: ActiveModelTrait, + C: ConnectionTrait<'a>, { - Updater::new(query).exec(db).await?; + Updater::new(query).check_record_exists().exec(db).await?; Ok(model) } -// Only Statement impl Send -async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result { +async fn exec_update<'a, C>( + statement: Statement, + db: &'a C, + check_record_exists: bool, +) -> Result +where + C: ConnectionTrait<'a>, +{ let result = db.execute(statement).await?; + if check_record_exists && result.rows_affected() == 0 { + return Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned(), + )); + } Ok(UpdateResult { rows_affected: result.rows_affected(), }) } + +#[cfg(test)] +mod tests { + use crate::{entity::prelude::*, tests_cfg::*, *}; + use pretty_assertions::assert_eq; + use sea_query::Expr; + + #[smol_potat::test] + async fn update_record_not_found_1() -> Result<(), DbErr> { + let db = MockDatabase::new(DbBackend::Postgres) + .append_exec_results(vec![ + MockExecResult { + last_insert_id: 0, + rows_affected: 1, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + MockExecResult { + last_insert_id: 0, + rows_affected: 0, + }, + ]) + .into_connection(); + + let model = cake::Model { + id: 1, + name: "New York Cheese".to_owned(), + }; + + assert_eq!( + cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.into_active_model() + } + .update(&db) + .await?, + cake::Model { + id: 1, + name: "Cheese Cake".to_owned(), + } + .into_active_model() + ); + + let model = cake::Model { + id: 2, + name: "New York Cheese".to_owned(), + }; + + assert_eq!( + cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.clone().into_active_model() + } + .update(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + cake::Entity::update(cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.clone().into_active_model() + }) + .exec(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + Update::one(cake::ActiveModel { + name: Set("Cheese Cake".to_owned()), + ..model.into_active_model() + }) + .exec(&db) + .await, + Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned() + )) + ); + + assert_eq!( + Update::many(cake::Entity) + .col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned())) + .filter(cake::Column::Id.eq(2)) + .exec(&db) + .await, + Ok(UpdateResult { rows_affected: 0 }) + ); + + assert_eq!( + db.into_transaction_log(), + vec![ + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 1i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, + vec!["Cheese Cake".into(), 2i32.into()] + ), + ] + ); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6ddc442c..1b78cf58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ //! SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. //! //! [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) -//! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) +//! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic) //! [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) //! [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) //! [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) diff --git a/src/query/insert.rs b/src/query/insert.rs index a65071e1..5e504a0c 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,14 +1,18 @@ -use crate::{ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, QueryTrait}; +use crate::{ + ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, + QueryTrait, +}; use core::marker::PhantomData; -use sea_query::InsertStatement; +use sea_query::{InsertStatement, ValueTuple}; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Insert where A: ActiveModelTrait, { pub(crate) query: InsertStatement, pub(crate) columns: Vec, + pub(crate) primary_key: Option, pub(crate) model: PhantomData, } @@ -31,6 +35,7 @@ where .into_table(A::Entity::default().table_ref()) .to_owned(), columns: Vec::new(), + primary_key: None, model: PhantomData, } } @@ -107,6 +112,12 @@ where M: IntoActiveModel, { let mut am: A = m.into_active_model(); + self.primary_key = + if !<::PrimaryKey as PrimaryKeyTrait>::auto_increment() { + am.get_primary_key_value() + } else { + None + }; let mut columns = Vec::new(); let mut values = Vec::new(); let columns_empty = self.columns.is_empty(); @@ -120,7 +131,7 @@ where } if av_has_val { columns.push(col); - values.push(av.into_value()); + values.push(av.into_value().unwrap()); } } self.query.columns(columns); diff --git a/src/query/mod.rs b/src/query/mod.rs index 54cc12dd..fcf8b168 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -8,6 +8,7 @@ mod json; mod select; mod traits; mod update; +mod util; pub use combine::{SelectA, SelectB}; pub use delete::*; @@ -19,5 +20,6 @@ pub use json::*; pub use select::*; pub use traits::*; pub use update::*; +pub use util::*; -pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; +pub use crate::{ConnectionTrait, InsertResult, Statement, UpdateResult, Value, Values}; diff --git a/src/query/util.rs b/src/query/util.rs new file mode 100644 index 00000000..a6133725 --- /dev/null +++ b/src/query/util.rs @@ -0,0 +1,112 @@ +use crate::{database::*, QueryTrait, Statement}; + +#[derive(Debug)] +pub struct DebugQuery<'a, Q, T> { + pub query: &'a Q, + pub value: T, +} + +macro_rules! debug_query_build { + ($impl_obj:ty, $db_expr:expr) => { + impl<'a, Q> DebugQuery<'a, Q, $impl_obj> + where + Q: QueryTrait, + { + pub fn build(&self) -> Statement { + let func = $db_expr; + let db_backend = func(self); + self.query.build(db_backend) + } + } + }; +} + +debug_query_build!(DbBackend, |x: &DebugQuery<_, DbBackend>| x.value); +debug_query_build!(&DbBackend, |x: &DebugQuery<_, &DbBackend>| *x.value); +debug_query_build!(DatabaseConnection, |x: &DebugQuery< + _, + DatabaseConnection, +>| x.value.get_database_backend()); +debug_query_build!(&DatabaseConnection, |x: &DebugQuery< + _, + &DatabaseConnection, +>| x.value.get_database_backend()); + +/// Helper to get a `Statement` from an object that impl `QueryTrait`. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "mock")] +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend}; +/// # +/// # let conn = MockDatabase::new(DbBackend::Postgres) +/// # .into_connection(); +/// # +/// use sea_orm::{entity::*, query::*, tests_cfg::cake, debug_query_stmt}; +/// +/// let c = cake::Entity::insert( +/// cake::ActiveModel { +/// id: ActiveValue::set(1), +/// name: ActiveValue::set("Apple Pie".to_owned()), +/// }); +/// +/// let raw_sql = debug_query_stmt!(&c, &conn).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, conn).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, DbBackend::MySql).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query_stmt!(&c, &DbBackend::MySql).to_string(); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// ``` +#[macro_export] +macro_rules! debug_query_stmt { + ($query:expr,$value:expr) => { + $crate::DebugQuery { + query: $query, + value: $value, + } + .build(); + }; +} + +/// Helper to get a raw SQL string from an object that impl `QueryTrait`. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "mock")] +/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend}; +/// # +/// # let conn = MockDatabase::new(DbBackend::Postgres) +/// # .into_connection(); +/// # +/// use sea_orm::{entity::*, query::*, tests_cfg::cake,debug_query}; +/// +/// let c = cake::Entity::insert( +/// cake::ActiveModel { +/// id: ActiveValue::set(1), +/// name: ActiveValue::set("Apple Pie".to_owned()), +/// }); +/// +/// let raw_sql = debug_query!(&c, &conn); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query!(&c, conn); +/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#); +/// +/// let raw_sql = debug_query!(&c, DbBackend::Sqlite); +/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#); +/// +/// ``` +#[macro_export] +macro_rules! debug_query { + ($query:expr,$value:expr) => { + $crate::debug_query_stmt!($query, $value).to_string(); + }; +} diff --git a/src/tests_cfg/mod.rs b/src/tests_cfg/mod.rs index 6bc86aed..d6c80b36 100644 --- a/src/tests_cfg/mod.rs +++ b/src/tests_cfg/mod.rs @@ -7,6 +7,7 @@ pub mod cake_filling_price; pub mod entity_linked; pub mod filling; pub mod fruit; +pub mod rust_keyword; pub mod vendor; pub use cake::Entity as Cake; @@ -15,4 +16,5 @@ pub use cake_filling::Entity as CakeFilling; pub use cake_filling_price::Entity as CakeFillingPrice; pub use filling::Entity as Filling; pub use fruit::Entity as Fruit; +pub use rust_keyword::Entity as RustKeyword; pub use vendor::Entity as Vendor; diff --git a/src/tests_cfg/rust_keyword.rs b/src/tests_cfg/rust_keyword.rs new file mode 100644 index 00000000..c8662347 --- /dev/null +++ b/src/tests_cfg/rust_keyword.rs @@ -0,0 +1,141 @@ +use crate as sea_orm; +use crate::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "rust_keyword")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub testing: i32, + pub rust: i32, + pub keywords: i32, + pub r#raw_identifier: i32, + pub r#as: i32, + pub r#async: i32, + pub r#await: i32, + pub r#break: i32, + pub r#const: i32, + pub r#continue: i32, + pub crate_: i32, + pub r#dyn: i32, + pub r#else: i32, + pub r#enum: i32, + pub r#extern: i32, + pub r#false: i32, + pub r#fn: i32, + pub r#for: i32, + pub r#if: i32, + pub r#impl: i32, + pub r#in: i32, + pub r#let: i32, + pub r#loop: i32, + pub r#match: i32, + pub r#mod: i32, + pub r#move: i32, + pub r#mut: i32, + pub r#pub: i32, + pub r#ref: i32, + pub r#return: i32, + pub self_: i32, + pub r#static: i32, + pub r#struct: i32, + pub r#trait: i32, + pub r#true: i32, + pub r#type: i32, + pub r#union: i32, + pub r#unsafe: i32, + pub r#use: i32, + pub r#where: i32, + pub r#while: i32, + pub r#abstract: i32, + pub r#become: i32, + pub r#box: i32, + pub r#do: i32, + pub r#final: i32, + pub r#macro: i32, + pub r#override: i32, + pub r#priv: i32, + pub r#try: i32, + pub r#typeof: i32, + pub r#unsized: i32, + pub r#virtual: i32, + pub r#yield: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + match self { + _ => panic!("No RelationDef"), + } + } +} + +impl ActiveModelBehavior for ActiveModel {} + +#[cfg(test)] +mod tests { + use crate::tests_cfg::rust_keyword::*; + use sea_query::Iden; + + #[test] + fn test_columns() { + assert_eq!(Column::Id.to_string().as_str(), "id"); + assert_eq!(Column::Testing.to_string().as_str(), "testing"); + assert_eq!(Column::Rust.to_string().as_str(), "rust"); + assert_eq!(Column::Keywords.to_string().as_str(), "keywords"); + assert_eq!(Column::RawIdentifier.to_string().as_str(), "raw_identifier"); + assert_eq!(Column::As.to_string().as_str(), "as"); + assert_eq!(Column::Async.to_string().as_str(), "async"); + assert_eq!(Column::Await.to_string().as_str(), "await"); + assert_eq!(Column::Break.to_string().as_str(), "break"); + assert_eq!(Column::Const.to_string().as_str(), "const"); + assert_eq!(Column::Continue.to_string().as_str(), "continue"); + assert_eq!(Column::Dyn.to_string().as_str(), "dyn"); + assert_eq!(Column::Crate.to_string().as_str(), "crate"); + assert_eq!(Column::Else.to_string().as_str(), "else"); + assert_eq!(Column::Enum.to_string().as_str(), "enum"); + assert_eq!(Column::Extern.to_string().as_str(), "extern"); + assert_eq!(Column::False.to_string().as_str(), "false"); + assert_eq!(Column::Fn.to_string().as_str(), "fn"); + assert_eq!(Column::For.to_string().as_str(), "for"); + assert_eq!(Column::If.to_string().as_str(), "if"); + assert_eq!(Column::Impl.to_string().as_str(), "impl"); + assert_eq!(Column::In.to_string().as_str(), "in"); + assert_eq!(Column::Let.to_string().as_str(), "let"); + assert_eq!(Column::Loop.to_string().as_str(), "loop"); + assert_eq!(Column::Match.to_string().as_str(), "match"); + assert_eq!(Column::Mod.to_string().as_str(), "mod"); + assert_eq!(Column::Move.to_string().as_str(), "move"); + assert_eq!(Column::Mut.to_string().as_str(), "mut"); + assert_eq!(Column::Pub.to_string().as_str(), "pub"); + assert_eq!(Column::Ref.to_string().as_str(), "ref"); + assert_eq!(Column::Return.to_string().as_str(), "return"); + assert_eq!(Column::Self_.to_string().as_str(), "self"); + assert_eq!(Column::Static.to_string().as_str(), "static"); + assert_eq!(Column::Struct.to_string().as_str(), "struct"); + assert_eq!(Column::Trait.to_string().as_str(), "trait"); + assert_eq!(Column::True.to_string().as_str(), "true"); + assert_eq!(Column::Type.to_string().as_str(), "type"); + assert_eq!(Column::Union.to_string().as_str(), "union"); + assert_eq!(Column::Unsafe.to_string().as_str(), "unsafe"); + assert_eq!(Column::Use.to_string().as_str(), "use"); + assert_eq!(Column::Where.to_string().as_str(), "where"); + assert_eq!(Column::While.to_string().as_str(), "while"); + assert_eq!(Column::Abstract.to_string().as_str(), "abstract"); + assert_eq!(Column::Become.to_string().as_str(), "become"); + assert_eq!(Column::Box.to_string().as_str(), "box"); + assert_eq!(Column::Do.to_string().as_str(), "do"); + assert_eq!(Column::Final.to_string().as_str(), "final"); + assert_eq!(Column::Macro.to_string().as_str(), "macro"); + assert_eq!(Column::Override.to_string().as_str(), "override"); + assert_eq!(Column::Priv.to_string().as_str(), "priv"); + assert_eq!(Column::Try.to_string().as_str(), "try"); + assert_eq!(Column::Typeof.to_string().as_str(), "typeof"); + assert_eq!(Column::Unsized.to_string().as_str(), "unsized"); + assert_eq!(Column::Virtual.to_string().as_str(), "virtual"); + assert_eq!(Column::Yield.to_string().as_str(), "yield"); + } +} diff --git a/tests/basic.rs b/tests/basic.rs index a0763d45..ef379779 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,6 +1,6 @@ pub mod common; -pub use sea_orm::{entity::*, error::*, sea_query, tests_cfg::*, Database, DbConn}; +pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Database, DbConn}; // cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic #[sea_orm_macros::test] diff --git a/tests/common/bakery_chain/metadata.rs b/tests/common/bakery_chain/metadata.rs index de513a22..2c297cd3 100644 --- a/tests/common/bakery_chain/metadata.rs +++ b/tests/common/bakery_chain/metadata.rs @@ -10,8 +10,8 @@ pub struct Model { pub key: String, pub value: String, pub bytes: Vec, - pub date: Date, - pub time: Time, + pub date: Option, + pub time: Option