Merge branch 'master' into active-model-behavior

This commit is contained in:
Billy Chan 2021-10-12 20:37:00 +08:00
commit 5339696da2
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
92 changed files with 3706 additions and 494 deletions

View File

@ -5,6 +5,7 @@ on:
push: push:
branches: branches:
- master - master
- 0.2.x
env: env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
@ -143,6 +144,12 @@ jobs:
args: > args: >
--all --all
- uses: actions-rs/cargo@v1
with:
command: test
args: >
--manifest-path sea-orm-rocket/Cargo.toml
cli: cli:
name: CLI name: CLI
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
@ -170,7 +177,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
path: [async-std, tokio, actix_example, actix4_example, rocket_example] path: [basic, actix_example, actix4_example, rocket_example]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
@ -186,6 +193,28 @@ jobs:
args: > args: >
--manifest-path examples/${{ matrix.path }}/Cargo.toml --manifest-path examples/${{ matrix.path }}/Cargo.toml
issues:
name: Issues
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
path: [86]
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: build
args: >
--manifest-path issues/${{ matrix.path }}/Cargo.toml
sqlite: sqlite:
name: SQLite name: SQLite
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View File

@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/) The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/). and this project adheres to [Semantic Versioning](http://semver.org/).
## 0.2.6 - 2021-10-09
- [[#224]] [sea-orm-cli] Date & Time column type mapping
- Escape rust keywords with `r#` raw identifier
[#224]: https://github.com/SeaQL/sea-orm/pull/224
## 0.2.5 - 2021-10-06
- [[#227]] Resolve "Inserting actual none value of Option<Date> results in panic"
- [[#219]] [sea-orm-cli] Add `--tables` option
- [[#189]] Add `debug_query` and `debug_query_stmt` macro
[#227]: https://github.com/SeaQL/sea-orm/issues/227
[#219]: https://github.com/SeaQL/sea-orm/pull/219
[#189]: https://github.com/SeaQL/sea-orm/pull/189
## 0.2.4 - 2021-10-01 ## 0.2.4 - 2021-10-01
- [[#186]] [sea-orm-cli] Foreign key handling - [[#186]] [sea-orm-cli] Foreign key handling

View File

@ -3,7 +3,7 @@ members = [".", "sea-orm-macros", "sea-orm-codegen"]
[package] [package]
name = "sea-orm" name = "sea-orm"
version = "0.2.4" version = "0.2.6"
authors = ["Chris Tsang <tyt2y7@gmail.com>"] authors = ["Chris Tsang <tyt2y7@gmail.com>"]
edition = "2018" edition = "2018"
description = "🐚 An async & dynamic ORM for Rust" description = "🐚 An async & dynamic ORM for Rust"
@ -29,13 +29,14 @@ futures = { version = "^0.3" }
futures-util = { version = "^0.3" } futures-util = { version = "^0.3" }
log = { version = "^0.4", optional = true } log = { version = "^0.4", optional = true }
rust_decimal = { version = "^1", optional = true } rust_decimal = { version = "^1", optional = true }
sea-orm-macros = { version = "^0.2.4", path = "sea-orm-macros", optional = true } sea-orm-macros = { version = "^0.2.6", path = "sea-orm-macros", optional = true }
sea-query = { version = "^0.16.5", features = ["thread-safe"] } sea-query = { version = "^0.17.1", features = ["thread-safe"] }
sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] }
serde = { version = "^1.0", features = ["derive"] } serde = { version = "^1.0", features = ["derive"] }
serde_json = { version = "^1", optional = true } serde_json = { version = "^1", optional = true }
sqlx = { version = "^0.5", optional = true } sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
[dev-dependencies] [dev-dependencies]
smol = { version = "^1.2" } smol = { version = "^1.2" }

View File

@ -21,7 +21,7 @@
SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust.
[![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index)
[![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic)
[![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example)
[![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example)
[![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv)

View File

@ -1,5 +1,5 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, DbConn, ExecResult}; use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult};
async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> { async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();

View File

@ -1,5 +1,5 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, DbConn, ExecResult}; use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult};
async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> { async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();

View File

@ -15,15 +15,22 @@ futures-util = { version = "^0.3" }
rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = [
"json", "json",
] } ] }
rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git" }
rocket_dyn_templates = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ rocket_dyn_templates = { git = "https://github.com/SergioBenitez/Rocket.git", features = [
"tera", "tera",
] } ] }
# remove `path = ""` in your own project
sea-orm = { path = "../../", version = "^0.2.3", features = ["macros"], default-features = false }
serde_json = { version = "^1" } serde_json = { version = "^1" }
[dependencies.sea-orm]
path = "../../" # remove this line in your own project
version = "^0.2.3"
features = ["macros", "runtime-tokio-native-tls"]
default-features = false
[dependencies.sea-orm-rocket]
path = "../../sea-orm-rocket/lib" # remove this line in your own project and use the git line
# git = "https://github.com/SeaQL/sea-orm"
[features] [features]
default = ["sqlx-postgres"] default = ["sqlx-postgres"]
sqlx-mysql = ["sea-orm/sqlx-mysql", "rocket_db_pools/sqlx_mysql"] sqlx-mysql = ["sea-orm/sqlx-mysql"]
sqlx-postgres = ["sea-orm/sqlx-postgres", "rocket_db_pools/sqlx_postgres"] sqlx-postgres = ["sea-orm/sqlx-postgres"]

View File

@ -1,7 +1,7 @@
[default] [default]
template_dir = "templates/" template_dir = "templates/"
[default.databases.rocket_example] [default.databases.sea_orm]
# Mysql # Mysql
# make sure to enable "sqlx-mysql" feature in Cargo.toml, i.e default = ["sqlx-mysql"] # make sure to enable "sqlx-mysql" feature in Cargo.toml, i.e default = ["sqlx-mysql"]
# url = "mysql://root:@localhost/rocket_example" # url = "mysql://root:@localhost/rocket_example"

View File

@ -7,21 +7,17 @@ use rocket::fs::{relative, FileServer};
use rocket::request::FlashMessage; use rocket::request::FlashMessage;
use rocket::response::{Flash, Redirect}; use rocket::response::{Flash, Redirect};
use rocket::{Build, Request, Rocket}; use rocket::{Build, Request, Rocket};
use rocket_db_pools::{sqlx, Connection, Database};
use rocket_dyn_templates::{context, Template}; use rocket_dyn_templates::{context, Template};
use sea_orm::{entity::*, query::*}; use sea_orm::{entity::*, query::*};
use sea_orm_rocket::{Connection, Database};
mod pool; mod pool;
use pool::RocketDbPool; use pool::Db;
mod setup; mod setup;
#[derive(Database, Debug)] type Result<T, E = rocket::response::Debug<sea_orm::DbErr>> = std::result::Result<T, E>;
#[database("rocket_example")]
struct Db(RocketDbPool);
type Result<T, E = rocket::response::Debug<sqlx::Error>> = std::result::Result<T, E>;
mod post; mod post;
pub use post::Entity as Post; pub use post::Entity as Post;
@ -34,7 +30,9 @@ async fn new() -> Template {
} }
#[post("/", data = "<post_form>")] #[post("/", data = "<post_form>")]
async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Redirect> { async fn create(conn: Connection<'_, Db>, post_form: Form<post::Model>) -> Flash<Redirect> {
let db = conn.into_inner();
let form = post_form.into_inner(); let form = post_form.into_inner();
post::ActiveModel { post::ActiveModel {
@ -42,7 +40,7 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
text: Set(form.text.to_owned()), text: Set(form.text.to_owned()),
..Default::default() ..Default::default()
} }
.save(&conn) .save(db)
.await .await
.expect("could not insert post"); .expect("could not insert post");
@ -50,9 +48,11 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
} }
#[post("/<id>", data = "<post_form>")] #[post("/<id>", data = "<post_form>")]
async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) -> Flash<Redirect> { async fn update(conn: Connection<'_, Db>, id: i32, post_form: Form<post::Model>) -> Flash<Redirect> {
let db = conn.into_inner();
let post: post::ActiveModel = Post::find_by_id(id) let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn) .one(db)
.await .await
.unwrap() .unwrap()
.unwrap() .unwrap()
@ -65,7 +65,7 @@ async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) ->
title: Set(form.title.to_owned()), title: Set(form.title.to_owned()),
text: Set(form.text.to_owned()), text: Set(form.text.to_owned()),
} }
.save(&conn) .save(db)
.await .await
.expect("could not edit post"); .expect("could not edit post");
@ -74,11 +74,13 @@ async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) ->
#[get("/?<page>&<posts_per_page>")] #[get("/?<page>&<posts_per_page>")]
async fn list( async fn list(
conn: Connection<Db>, conn: Connection<'_, Db>,
posts_per_page: Option<usize>, posts_per_page: Option<usize>,
page: Option<usize>, page: Option<usize>,
flash: Option<FlashMessage<'_>>, flash: Option<FlashMessage<'_>>,
) -> Template { ) -> Template {
let db = conn.into_inner();
// Set page number and items per page // Set page number and items per page
let page = page.unwrap_or(1); let page = page.unwrap_or(1);
let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE); let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE);
@ -89,7 +91,7 @@ async fn list(
// Setup paginator // Setup paginator
let paginator = Post::find() let paginator = Post::find()
.order_by_asc(post::Column::Id) .order_by_asc(post::Column::Id)
.paginate(&conn, posts_per_page); .paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await.ok().unwrap(); let num_pages = paginator.num_pages().await.ok().unwrap();
// Fetch paginated posts // Fetch paginated posts
@ -111,9 +113,11 @@ async fn list(
} }
#[get("/<id>")] #[get("/<id>")]
async fn edit(conn: Connection<Db>, id: i32) -> Template { async fn edit(conn: Connection<'_, Db>, id: i32) -> Template {
let db = conn.into_inner();
let post: Option<post::Model> = Post::find_by_id(id) let post: Option<post::Model> = Post::find_by_id(id)
.one(&conn) .one(db)
.await .await
.expect("could not find post"); .expect("could not find post");
@ -126,22 +130,26 @@ async fn edit(conn: Connection<Db>, id: i32) -> Template {
} }
#[delete("/<id>")] #[delete("/<id>")]
async fn delete(conn: Connection<Db>, id: i32) -> Flash<Redirect> { async fn delete(conn: Connection<'_, Db>, id: i32) -> Flash<Redirect> {
let db = conn.into_inner();
let post: post::ActiveModel = Post::find_by_id(id) let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn) .one(db)
.await .await
.unwrap() .unwrap()
.unwrap() .unwrap()
.into(); .into();
post.delete(&conn).await.unwrap(); post.delete(db).await.unwrap();
Flash::success(Redirect::to("/"), "Post successfully deleted.") Flash::success(Redirect::to("/"), "Post successfully deleted.")
} }
#[delete("/")] #[delete("/")]
async fn destroy(conn: Connection<Db>) -> Result<()> { async fn destroy(conn: Connection<'_, Db>) -> Result<()> {
Post::delete_many().exec(&conn).await.unwrap(); let db = conn.into_inner();
Post::delete_many().exec(db).await.unwrap();
Ok(()) Ok(())
} }

View File

@ -1,13 +1,17 @@
use async_trait::async_trait; use async_trait::async_trait;
use rocket_db_pools::{rocket::figment::Figment, Config}; use sea_orm_rocket::{rocket::figment::Figment, Config, Database};
#[derive(Debug)] #[derive(Database, Debug)]
pub struct RocketDbPool { #[database("sea_orm")]
pub struct Db(SeaOrmPool);
#[derive(Debug, Clone)]
pub struct SeaOrmPool {
pub conn: sea_orm::DatabaseConnection, pub conn: sea_orm::DatabaseConnection,
} }
#[async_trait] #[async_trait]
impl rocket_db_pools::Pool for RocketDbPool { impl sea_orm_rocket::Pool for SeaOrmPool {
type Error = sea_orm::DbErr; type Error = sea_orm::DbErr;
type Connection = sea_orm::DatabaseConnection; type Connection = sea_orm::DatabaseConnection;
@ -16,12 +20,10 @@ impl rocket_db_pools::Pool for RocketDbPool {
let config = figment.extract::<Config>().unwrap(); let config = figment.extract::<Config>().unwrap();
let conn = sea_orm::Database::connect(&config.url).await.unwrap(); let conn = sea_orm::Database::connect(&config.url).await.unwrap();
Ok(RocketDbPool { Ok(SeaOrmPool { conn })
conn,
})
} }
async fn get(&self) -> Result<Self::Connection, Self::Error> { fn borrow(&self) -> &Self::Connection {
Ok(self.conn.clone()) &self.conn
} }
} }

View File

@ -1,5 +1,5 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, DbConn, ExecResult}; use sea_orm::{error::*, query::*, sea_query, DbConn, ExecResult};
async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> { async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();

View File

@ -3,7 +3,7 @@
[package] [package]
name = "sea-orm-cli" name = "sea-orm-cli"
version = "0.2.4" version = "0.2.6"
authors = [ "Billy Chan <ccw.billy.123@gmail.com>" ] authors = [ "Billy Chan <ccw.billy.123@gmail.com>" ]
edition = "2018" edition = "2018"
description = "Command line utility for SeaORM" description = "Command line utility for SeaORM"
@ -21,7 +21,7 @@ path = "src/main.rs"
clap = { version = "^2.33.3" } clap = { version = "^2.33.3" }
dotenv = { version = "^0.15" } dotenv = { version = "^0.15" }
async-std = { version = "^1.9", features = [ "attributes" ] } async-std = { version = "^1.9", features = [ "attributes" ] }
sea-orm-codegen = { version = "^0.2.4", path = "../sea-orm-codegen" } sea-orm-codegen = { version = "^0.2.6", path = "../sea-orm-codegen" }
sea-schema = { version = "^0.2.9", default-features = false, features = [ sea-schema = { version = "^0.2.9", default-features = false, features = [
"debug-print", "debug-print",
"sqlx-mysql", "sqlx-mysql",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "sea-orm-codegen" name = "sea-orm-codegen"
version = "0.2.4" version = "0.2.6"
authors = ["Billy Chan <ccw.billy.123@gmail.com>"] authors = ["Billy Chan <ccw.billy.123@gmail.com>"]
edition = "2018" edition = "2018"
description = "Code Generator for SeaORM" description = "Code Generator for SeaORM"

View File

@ -1,3 +1,4 @@
use crate::util::escape_rust_keyword;
use heck::{CamelCase, SnakeCase}; use heck::{CamelCase, SnakeCase};
use proc_macro2::{Ident, TokenStream}; use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
@ -14,11 +15,11 @@ pub struct Column {
impl Column { impl Column {
pub fn get_name_snake_case(&self) -> Ident { pub fn get_name_snake_case(&self) -> Ident {
format_ident!("{}", self.name.to_snake_case()) format_ident!("{}", escape_rust_keyword(self.name.to_snake_case()))
} }
pub fn get_name_camel_case(&self) -> Ident { pub fn get_name_camel_case(&self) -> Ident {
format_ident!("{}", self.name.to_camel_case()) format_ident!("{}", escape_rust_keyword(self.name.to_camel_case()))
} }
pub fn get_rs_type(&self) -> TokenStream { pub fn get_rs_type(&self) -> TokenStream {
@ -27,8 +28,6 @@ impl Column {
ColumnType::Char(_) ColumnType::Char(_)
| ColumnType::String(_) | ColumnType::String(_)
| ColumnType::Text | ColumnType::Text
| ColumnType::Time(_)
| ColumnType::Date
| ColumnType::Custom(_) => "String", | ColumnType::Custom(_) => "String",
ColumnType::TinyInteger(_) => "i8", ColumnType::TinyInteger(_) => "i8",
ColumnType::SmallInteger(_) => "i16", ColumnType::SmallInteger(_) => "i16",
@ -37,6 +36,8 @@ impl Column {
ColumnType::Float(_) => "f32", ColumnType::Float(_) => "f32",
ColumnType::Double(_) => "f64", ColumnType::Double(_) => "f64",
ColumnType::Json | ColumnType::JsonBinary => "Json", ColumnType::Json | ColumnType::JsonBinary => "Json",
ColumnType::Date => "Date",
ColumnType::Time(_) => "Time",
ColumnType::DateTime(_) | ColumnType::Timestamp(_) => "DateTime", ColumnType::DateTime(_) | ColumnType::Timestamp(_) => "DateTime",
ColumnType::TimestampWithTimeZone(_) => "DateTimeWithTimeZone", ColumnType::TimestampWithTimeZone(_) => "DateTimeWithTimeZone",
ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal", ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal",
@ -194,6 +195,11 @@ mod tests {
make_col!("CAKE_FILLING_ID", ColumnType::Double(None)), make_col!("CAKE_FILLING_ID", ColumnType::Double(None)),
make_col!("CAKE-FILLING-ID", ColumnType::Binary(None)), make_col!("CAKE-FILLING-ID", ColumnType::Binary(None)),
make_col!("CAKE", ColumnType::Boolean), make_col!("CAKE", ColumnType::Boolean),
make_col!("date", ColumnType::Date),
make_col!("time", ColumnType::Time(None)),
make_col!("date_time", ColumnType::DateTime(None)),
make_col!("timestamp", ColumnType::Timestamp(None)),
make_col!("timestamp_tz", ColumnType::TimestampWithTimeZone(None)),
] ]
} }
@ -211,6 +217,11 @@ mod tests {
"cake_filling_id", "cake_filling_id",
"cake_filling_id", "cake_filling_id",
"cake", "cake",
"date",
"time",
"date_time",
"timestamp",
"timestamp_tz",
]; ];
for (col, snack_case) in columns.into_iter().zip(snack_cases) { for (col, snack_case) in columns.into_iter().zip(snack_cases) {
assert_eq!(col.get_name_snake_case().to_string(), snack_case); assert_eq!(col.get_name_snake_case().to_string(), snack_case);
@ -231,6 +242,11 @@ mod tests {
"CakeFillingId", "CakeFillingId",
"CakeFillingId", "CakeFillingId",
"Cake", "Cake",
"Date",
"Time",
"DateTime",
"Timestamp",
"TimestampTz",
]; ];
for (col, camel_case) in columns.into_iter().zip(camel_cases) { for (col, camel_case) in columns.into_iter().zip(camel_cases) {
assert_eq!(col.get_name_camel_case().to_string(), camel_case); assert_eq!(col.get_name_camel_case().to_string(), camel_case);
@ -241,7 +257,21 @@ mod tests {
fn test_get_rs_type() { fn test_get_rs_type() {
let columns = setup(); let columns = setup();
let rs_types = vec![ let rs_types = vec![
"String", "String", "i8", "i16", "i32", "i64", "f32", "f64", "Vec<u8>", "bool", "String",
"String",
"i8",
"i16",
"i32",
"i64",
"f32",
"f64",
"Vec<u8>",
"bool",
"Date",
"Time",
"DateTime",
"DateTime",
"DateTimeWithTimeZone",
]; ];
for (mut col, rs_type) in columns.into_iter().zip(rs_types) { for (mut col, rs_type) in columns.into_iter().zip(rs_types) {
let rs_type: TokenStream = rs_type.parse().unwrap(); let rs_type: TokenStream = rs_type.parse().unwrap();
@ -271,6 +301,11 @@ mod tests {
"ColumnType::Double.def()", "ColumnType::Double.def()",
"ColumnType::Binary.def()", "ColumnType::Binary.def()",
"ColumnType::Boolean.def()", "ColumnType::Boolean.def()",
"ColumnType::Date.def()",
"ColumnType::Time.def()",
"ColumnType::DateTime.def()",
"ColumnType::Timestamp.def()",
"ColumnType::TimestampWithTimeZone.def()",
]; ];
for (mut col, col_def) in columns.into_iter().zip(col_defs) { for (mut col, col_def) in columns.into_iter().zip(col_defs) {
let mut col_def: TokenStream = col_def.parse().unwrap(); let mut col_def: TokenStream = col_def.parse().unwrap();

View File

@ -597,18 +597,85 @@ mod tests {
name: "id".to_owned(), name: "id".to_owned(),
}], }],
}, },
Entity {
table_name: "rust_keyword".to_owned(),
columns: vec![
Column {
name: "id".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: true,
not_null: true,
unique: false,
},
Column {
name: "testing".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "rust".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "keywords".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "type".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "typeof".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "crate".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "self".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: false,
not_null: true,
unique: false,
},
],
relations: vec![],
conjunct_relations: vec![],
primary_keys: vec![PrimaryKey {
name: "id".to_owned(),
}],
},
] ]
} }
#[test] #[test]
fn test_gen_expanded_code_blocks() -> io::Result<()> { fn test_gen_expanded_code_blocks() -> io::Result<()> {
let entities = setup(); let entities = setup();
const ENTITY_FILES: [&str; 5] = [ const ENTITY_FILES: [&str; 6] = [
include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake.rs"),
include_str!("../../tests/expanded/cake_filling.rs"), include_str!("../../tests/expanded/cake_filling.rs"),
include_str!("../../tests/expanded/filling.rs"), include_str!("../../tests/expanded/filling.rs"),
include_str!("../../tests/expanded/fruit.rs"), include_str!("../../tests/expanded/fruit.rs"),
include_str!("../../tests/expanded/vendor.rs"), include_str!("../../tests/expanded/vendor.rs"),
include_str!("../../tests/expanded/rust_keyword.rs"),
]; ];
assert_eq!(entities.len(), ENTITY_FILES.len()); assert_eq!(entities.len(), ENTITY_FILES.len());
@ -642,12 +709,13 @@ mod tests {
#[test] #[test]
fn test_gen_compact_code_blocks() -> io::Result<()> { fn test_gen_compact_code_blocks() -> io::Result<()> {
let entities = setup(); let entities = setup();
const ENTITY_FILES: [&str; 5] = [ const ENTITY_FILES: [&str; 6] = [
include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake.rs"),
include_str!("../../tests/compact/cake_filling.rs"), include_str!("../../tests/compact/cake_filling.rs"),
include_str!("../../tests/compact/filling.rs"), include_str!("../../tests/compact/filling.rs"),
include_str!("../../tests/compact/fruit.rs"), include_str!("../../tests/compact/fruit.rs"),
include_str!("../../tests/compact/vendor.rs"), include_str!("../../tests/compact/vendor.rs"),
include_str!("../../tests/compact/rust_keyword.rs"),
]; ];
assert_eq!(entities.len(), ENTITY_FILES.len()); assert_eq!(entities.len(), ENTITY_FILES.len());

View File

@ -1,5 +1,6 @@
mod entity; mod entity;
mod error; mod error;
mod util;
pub use entity::*; pub use entity::*;
pub use error::*; pub use error::*;

View File

@ -0,0 +1,23 @@
pub(crate) fn escape_rust_keyword<T>(string: T) -> String
where
T: ToString,
{
let string = string.to_string();
if RUST_KEYWORDS.iter().any(|s| s.eq(&string)) {
format!("r#{}", string)
} else if RUST_SPECIAL_KEYWORDS.iter().any(|s| s.eq(&string)) {
format!("{}_", string)
} else {
string
}
}
pub(crate) const RUST_KEYWORDS: [&str; 49] = [
"as", "async", "await", "break", "const", "continue", "dyn", "else", "enum", "extern", "false",
"fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
"return", "static", "struct", "super", "trait", "true", "type", "union", "unsafe", "use",
"where", "while", "abstract", "become", "box", "do", "final", "macro", "override", "priv",
"try", "typeof", "unsized", "virtual", "yield",
];
pub(crate) const RUST_SPECIAL_KEYWORDS: [&str; 3] = ["crate", "Self", "self"];

View File

@ -0,0 +1,30 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "rust_keyword")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub testing: i32,
pub rust: i32,
pub keywords: i32,
pub r#type: i32,
pub r#typeof: i32,
pub crate_: i32,
pub self_: i32,
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
_ => panic!("No RelationDef"),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,79 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0
use sea_orm::entity::prelude::*;
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
impl EntityName for Entity {
fn table_name(&self) -> &str {
"rust_keyword"
}
}
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)]
pub struct Model {
pub id: i32,
pub testing: i32,
pub rust: i32,
pub keywords: i32,
pub r#type: i32,
pub r#typeof: i32,
pub crate_: i32,
pub self_: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
pub enum Column {
Id,
Testing,
Rust,
Keywords,
Type,
Typeof,
Crate,
Self_,
}
#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]
pub enum PrimaryKey {
Id,
}
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;
fn auto_increment() -> bool {
true
}
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {}
impl ColumnTrait for Column {
type EntityName = Entity;
fn def(&self) -> ColumnDef {
match self {
Self::Id => ColumnType::Integer.def(),
Self::Testing => ColumnType::Integer.def(),
Self::Rust => ColumnType::Integer.def(),
Self::Keywords => ColumnType::Integer.def(),
Self::Type => ColumnType::Integer.def(),
Self::Typeof => ColumnType::Integer.def(),
Self::Crate => ColumnType::Integer.def(),
Self::Self_ => ColumnType::Integer.def(),
}
}
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
_ => panic!("No RelationDef"),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,6 +1,6 @@
[package] [package]
name = "sea-orm-macros" name = "sea-orm-macros"
version = "0.2.4" version = "0.2.6"
authors = [ "Billy Chan <ccw.billy.123@gmail.com>" ] authors = [ "Billy Chan <ccw.billy.123@gmail.com>" ]
edition = "2018" edition = "2018"
description = "Derive macros for SeaORM" description = "Derive macros for SeaORM"

View File

@ -1,4 +1,4 @@
use crate::util::field_not_ignored; use crate::util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier};
use heck::CamelCase; use heck::CamelCase;
use proc_macro2::{Ident, TokenStream}; use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
@ -29,10 +29,10 @@ pub fn expand_derive_active_model(ident: Ident, data: Data) -> syn::Result<Token
.clone() .clone()
.into_iter() .into_iter()
.map(|field| { .map(|field| {
let mut ident = format_ident!( let ident = field.ident.as_ref().unwrap().to_string();
"{}", let ident = trim_starting_raw_identifier(ident).to_camel_case();
field.ident.as_ref().unwrap().to_string().to_camel_case() let ident = escape_rust_keyword(ident);
); let mut ident = format_ident!("{}", &ident);
for attr in field.attrs.iter() { for attr in field.attrs.iter() {
if let Some(ident) = attr.path.get_ident() { if let Some(ident) = attr.path.get_ident() {
if ident != "sea_orm" { if ident != "sea_orm" {

View File

@ -1,3 +1,5 @@
use crate::util::{escape_rust_keyword, trim_starting_raw_identifier};
use convert_case::{Case, Casing};
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::quote; use quote::quote;
use syn::{ use syn::{
@ -5,8 +7,6 @@ use syn::{
Lit, Meta, Lit, Meta,
}; };
use convert_case::{Case, Casing};
pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Result<TokenStream> { pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Result<TokenStream> {
// if #[sea_orm(table_name = "foo", schema_name = "bar")] specified, create Entity struct // if #[sea_orm(table_name = "foo", schema_name = "bar")] specified, create Entity struct
let mut table_name = None; let mut table_name = None;
@ -60,8 +60,10 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
if let Fields::Named(fields) = item_struct.fields { if let Fields::Named(fields) = item_struct.fields {
for field in fields.named { for field in fields.named {
if let Some(ident) = &field.ident { if let Some(ident) = &field.ident {
let mut field_name = let mut field_name = Ident::new(
Ident::new(&ident.to_string().to_case(Case::Pascal), Span::call_site()); &trim_starting_raw_identifier(&ident).to_case(Case::Pascal),
Span::call_site(),
);
let mut nullable = false; let mut nullable = false;
let mut default_value = None; let mut default_value = None;
@ -168,6 +170,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
field_name = enum_name; field_name = enum_name;
} }
field_name = Ident::new(&escape_rust_keyword(field_name), Span::call_site());
if ignore { if ignore {
continue; continue;
} else { } else {

View File

@ -1,4 +1,7 @@
use crate::{attributes::derive_attr, util::field_not_ignored}; use crate::{
attributes::derive_attr,
util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier},
};
use heck::CamelCase; use heck::CamelCase;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
@ -43,10 +46,10 @@ impl DeriveModel {
let column_idents = fields let column_idents = fields
.iter() .iter()
.map(|field| { .map(|field| {
let mut ident = format_ident!( let ident = field.ident.as_ref().unwrap().to_string();
"{}", let ident = trim_starting_raw_identifier(ident).to_camel_case();
field.ident.as_ref().unwrap().to_string().to_camel_case() let ident = escape_rust_keyword(ident);
); let mut ident = format_ident!("{}", &ident);
for attr in field.attrs.iter() { for attr in field.attrs.iter() {
if let Some(ident) = attr.path.get_ident() { if let Some(ident) = attr.path.get_ident() {
if ident != "sea_orm" { if ident != "sea_orm" {

View File

@ -24,3 +24,39 @@ pub(crate) fn field_not_ignored(field: &Field) -> bool {
} }
true true
} }
pub(crate) fn trim_starting_raw_identifier<T>(string: T) -> String
where
T: ToString,
{
string
.to_string()
.trim_start_matches(RAW_IDENTIFIER)
.to_string()
}
pub(crate) fn escape_rust_keyword<T>(string: T) -> String
where
T: ToString,
{
let string = string.to_string();
if RUST_KEYWORDS.iter().any(|s| s.eq(&string)) {
format!("r#{}", string)
} else if RUST_SPECIAL_KEYWORDS.iter().any(|s| s.eq(&string)) {
format!("{}_", string)
} else {
string
}
}
pub(crate) const RAW_IDENTIFIER: &str = "r#";
pub(crate) const RUST_KEYWORDS: [&str; 49] = [
"as", "async", "await", "break", "const", "continue", "dyn", "else", "enum", "extern", "false",
"fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
"return", "static", "struct", "super", "trait", "true", "type", "union", "unsafe", "use",
"where", "while", "abstract", "become", "box", "do", "final", "macro", "override", "priv",
"try", "typeof", "unsized", "virtual", "yield",
];
pub(crate) const RUST_SPECIAL_KEYWORDS: [&str; 3] = ["crate", "Self", "self"];

View File

@ -0,0 +1,2 @@
[workspace]
members = ["codegen", "lib"]

1
sea-orm-rocket/README.md Normal file
View File

@ -0,0 +1 @@
# SeaORM Rocket support crate.

View File

@ -0,0 +1,22 @@
[package]
name = "sea-orm-rocket-codegen"
version = "0.1.0-rc"
authors = ["Sergio Benitez <sb@sergio.bz>", "Jeb Rosen <jeb@jebrosen.com>"]
description = "Procedural macros for sea_orm_rocket."
repository = "https://github.com/SergioBenitez/Rocket/contrib/db_pools"
readme = "../README.md"
keywords = ["rocket", "framework", "database", "pools"]
license = "MIT OR Apache-2.0"
edition = "2018"
[lib]
proc-macro = true
[dependencies]
devise = "0.3"
quote = "1"
[dev-dependencies]
rocket = { git = "https://github.com/SergioBenitez/Rocket.git", default-features = false }
trybuild = "1.0"
version_check = "0.9"

View File

@ -0,0 +1,110 @@
use proc_macro::TokenStream;
use devise::{DeriveGenerator, FromMeta, MapperBuild, Support, ValidatorBuild};
use devise::proc_macro2_diagnostics::SpanDiagnosticExt;
use devise::syn::{self, spanned::Spanned};
const ONE_DATABASE_ATTR: &str = "missing `#[database(\"name\")]` attribute";
const ONE_UNNAMED_FIELD: &str = "struct must have exactly one unnamed field";
#[derive(Debug, FromMeta)]
struct DatabaseAttribute {
#[meta(naked)]
name: String,
}
pub fn derive_database(input: TokenStream) -> TokenStream {
DeriveGenerator::build_for(input, quote!(impl sea_orm_rocket::Database))
.support(Support::TupleStruct)
.validator(ValidatorBuild::new()
.struct_validate(|_, s| {
if s.fields.len() == 1 {
Ok(())
} else {
Err(s.span().error(ONE_UNNAMED_FIELD))
}
})
)
.outer_mapper(MapperBuild::new()
.struct_map(|_, s| {
let pool_type = match &s.fields {
syn::Fields::Unnamed(f) => &f.unnamed[0].ty,
_ => unreachable!("Support::TupleStruct"),
};
let decorated_type = &s.ident;
let db_ty = quote_spanned!(decorated_type.span() =>
<#decorated_type as sea_orm_rocket::Database>
);
quote_spanned! { decorated_type.span() =>
impl From<#pool_type> for #decorated_type {
fn from(pool: #pool_type) -> Self {
Self(pool)
}
}
impl std::ops::Deref for #decorated_type {
type Target = #pool_type;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for #decorated_type {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type {
type Error = ();
async fn from_request(
req: &'r rocket::request::Request<'_>
) -> rocket::request::Outcome<Self, Self::Error> {
match #db_ty::fetch(req.rocket()) {
Some(db) => rocket::outcome::Outcome::Success(db),
None => rocket::outcome::Outcome::Failure((
rocket::http::Status::InternalServerError, ()))
}
}
}
impl rocket::Sentinel for &#decorated_type {
fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
#db_ty::fetch(rocket).is_none()
}
}
}
})
)
.outer_mapper(quote!(#[rocket::async_trait]))
.inner_mapper(MapperBuild::new()
.try_struct_map(|_, s| {
let db_name = DatabaseAttribute::one_from_attrs("database", &s.attrs)?
.map(|attr| attr.name)
.ok_or_else(|| s.span().error(ONE_DATABASE_ATTR))?;
let fairing_name = format!("'{}' Database Pool", db_name);
let pool_type = match &s.fields {
syn::Fields::Unnamed(f) => &f.unnamed[0].ty,
_ => unreachable!("Support::TupleStruct"),
};
Ok(quote_spanned! { pool_type.span() =>
type Pool = #pool_type;
const NAME: &'static str = #db_name;
fn init() -> sea_orm_rocket::Initializer<Self> {
sea_orm_rocket::Initializer::with_name(#fairing_name)
}
})
})
)
.to_tokens()
}

View File

@ -0,0 +1,52 @@
#![recursion_limit="256"]
#![warn(rust_2018_idioms)]
//! # `sea_orm_rocket` - Code Generation
//!
//! Implements the code generation portion of the `sea_orm_rocket` crate. This
//! is an implementation detail. This create should never be depended on
//! directly.
#[macro_use] extern crate quote;
mod database;
/// Automatic derive for the [`Database`] trait.
///
/// The derive generates an implementation of [`Database`] as follows:
///
/// * [`Database::NAME`] is set to the value in the `#[database("name")]`
/// attribute.
///
/// This names the database, providing an anchor to configure the database via
/// `Rocket.toml` or any other configuration source. Specifically, the
/// configuration in `databases.name` is used to configure the driver.
///
/// * [`Database::Pool`] is set to the wrapped type: `PoolType` above. The type
/// must implement [`Pool`].
///
/// To meet the required [`Database`] supertrait bounds, this derive also
/// generates implementations for:
///
/// * `From<Db::Pool>`
///
/// * `Deref<Target = Db::Pool>`
///
/// * `DerefMut<Target = Db::Pool>`
///
/// * `FromRequest<'_> for &Db`
///
/// * `Sentinel for &Db`
///
/// The `Deref` impls enable accessing the database pool directly from
/// references `&Db` or `&mut Db`. To force a dereference to the underlying
/// type, use `&db.0` or `&**db` or their `&mut` variants.
///
/// [`Database`]: ../sea_orm_rocket/trait.Database.html
/// [`Database::NAME`]: ../sea_orm_rocket/trait.Database.html#associatedconstant.NAME
/// [`Database::Pool`]: ../sea_orm_rocket/trait.Database.html#associatedtype.Pool
/// [`Pool`]: ../sea_orm_rocket/trait.Pool.html
#[proc_macro_derive(Database, attributes(database))]
pub fn derive_database(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
crate::database::derive_database(input)
}

View File

@ -0,0 +1,27 @@
[package]
name = "sea-orm-rocket"
version = "0.1.0"
authors = ["Sergio Benitez <sb@sergio.bz>", "Jeb Rosen <jeb@jebrosen.com>"]
description = "SeaORM Rocket support crate"
repository = "https://github.com/SeaQL/sea-orm"
readme = "../README.md"
keywords = ["rocket", "framework", "database", "pools"]
license = "MIT OR Apache-2.0"
edition = "2018"
[package.metadata.docs.rs]
all-features = true
[dependencies.rocket]
git = "https://github.com/SergioBenitez/Rocket.git"
version = "0.5.0-rc.1"
default-features = false
[dependencies.sea-orm-rocket-codegen]
path = "../codegen"
version = "0.1.0-rc"
[dev-dependencies.rocket]
git = "https://github.com/SergioBenitez/Rocket.git"
default-features = false
features = ["json"]

View File

@ -0,0 +1,83 @@
use rocket::serde::{Deserialize, Serialize};
/// Base configuration for all database drivers.
///
/// A dictionary matching this structure is extracted from the active
/// [`Figment`](crate::figment::Figment), scoped to `databases.name`, where
/// `name` is the name of the database, by the
/// [`Initializer`](crate::Initializer) fairing on ignition and used to
/// configure the relevant database and database pool.
///
/// With the default provider, these parameters are typically configured in a
/// `Rocket.toml` file:
///
/// ```toml
/// [default.databases.db_name]
/// url = "/path/to/db.sqlite"
///
/// # only `url` is required. `Initializer` provides defaults for the rest.
/// min_connections = 64
/// max_connections = 1024
/// connect_timeout = 5
/// idle_timeout = 120
/// ```
///
/// Alternatively, a custom provider can be used. For example, a custom `Figment`
/// with a global `databases.name` configuration:
///
/// ```rust
/// # use rocket::launch;
/// #[launch]
/// fn rocket() -> _ {
/// let figment = rocket::Config::figment()
/// .merge(("databases.name", sea_orm_rocket::Config {
/// url: "db:specific@config&url".into(),
/// min_connections: None,
/// max_connections: 1024,
/// connect_timeout: 3,
/// idle_timeout: None,
/// }));
///
/// rocket::custom(figment)
/// }
/// ```
///
/// For general information on configuration in Rocket, see [`rocket::config`].
/// For higher-level details on configuring a database, see the [crate-level
/// docs](crate#configuration).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(crate = "rocket::serde")]
pub struct Config {
/// Database-specific connection and configuration URL.
///
/// The format of the URL is database specific; consult your database's
/// documentation.
pub url: String,
/// Minimum number of connections to maintain in the pool.
///
/// **Note:** `deadpool` drivers do not support and thus ignore this value.
///
/// _Default:_ `None`.
pub min_connections: Option<u32>,
/// Maximum number of connections to maintain in the pool.
///
/// _Default:_ `workers * 4`.
pub max_connections: usize,
/// Number of seconds to wait for a connection before timing out.
///
/// If the timeout elapses before a connection can be made or retrieved from
/// a pool, an error is returned.
///
/// _Default:_ `5`.
pub connect_timeout: u64,
/// Maximum number of seconds to keep a connection alive for.
///
/// After a connection is established, it is maintained in a pool for
/// efficient connection retrieval. When an `idle_timeout` is set, that
/// connection will be closed after the timeout elapses. If an
/// `idle_timeout` is not specified, the behavior is driver specific but
/// typically defaults to keeping a connection active indefinitely.
///
/// _Default:_ `None`.
pub idle_timeout: Option<u64>,
}

View File

@ -0,0 +1,245 @@
use std::marker::PhantomData;
use std::ops::{DerefMut};
use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel};
use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::request::{FromRequest, Outcome, Request};
use rocket::http::Status;
use rocket::yansi::Paint;
use rocket::figment::providers::Serialized;
use crate::Pool;
/// Derivable trait which ties a database [`Pool`] with a configuration name.
///
/// This trait should rarely, if ever, be implemented manually. Instead, it
/// should be derived:
///
/// ```ignore
/// use sea_orm_rocket::{Database};
/// # use sea_orm_rocket::MockPool as SeaOrmPool;
///
/// #[derive(Database, Debug)]
/// #[database("sea_orm")]
/// struct Db(SeaOrmPool);
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build().attach(Db::init())
/// }
/// ```
///
/// See the [`Database` derive](derive@crate::Database) for details.
pub trait Database: From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static {
/// The [`Pool`] type of connections to this database.
///
/// When `Database` is derived, this takes the value of the `Inner` type in
/// `struct Db(Inner)`.
type Pool: Pool;
/// The configuration name for this database.
///
/// When `Database` is derived, this takes the value `"name"` in the
/// `#[database("name")]` attribute.
const NAME: &'static str;
/// Returns a fairing that initializes the database and its connection pool.
///
/// # Example
///
/// ```rust
/// # mod _inner {
/// # use rocket::launch;
/// use sea_orm_rocket::{Database};
/// # use sea_orm_rocket::MockPool as SeaOrmPool;
///
/// #[derive(Database)]
/// #[database("sea_orm")]
/// struct Db(SeaOrmPool);
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build().attach(Db::init())
/// }
/// # }
/// ```
fn init() -> Initializer<Self> {
Initializer::new()
}
/// Returns a reference to the initialized database in `rocket`. The
/// initializer fairing returned by `init()` must have already executed for
/// `Option` to be `Some`. This is guaranteed to be the case if the fairing
/// is attached and either:
///
/// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
/// application is running. This is always the case in request guards
/// and liftoff fairings,
/// * _or_ Rocket is in the [`Build`](rocket::Build) or
/// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
/// already been run. This is the case in all fairing callbacks
/// corresponding to fairings attached _after_ the `Initializer`
/// fairing.
///
/// # Example
///
/// Run database migrations in an ignite fairing. It is imperative that the
/// migration fairing be registered _after_ the `init()` fairing.
///
/// ```rust
/// # mod _inner {
/// # use rocket::launch;
/// use rocket::{Rocket, Build};
/// use rocket::fairing::{self, AdHoc};
///
/// use sea_orm_rocket::{Database};
/// # use sea_orm_rocket::MockPool as SeaOrmPool;
///
/// #[derive(Database)]
/// #[database("sea_orm")]
/// struct Db(SeaOrmPool);
///
/// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
/// if let Some(db) = Db::fetch(&rocket) {
/// // run migrations using `db`. get the inner type with &db.0.
/// Ok(rocket)
/// } else {
/// Err(rocket)
/// }
/// }
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build()
/// .attach(Db::init())
/// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
/// }
/// # }
/// ```
fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
if let Some(db) = rocket.state() {
return Some(db);
}
let dbtype = std::any::type_name::<Self>();
let fairing = Paint::default(format!("{}::init()", dbtype)).bold();
error!("Attempted to fetch unattached database `{}`.", Paint::default(dbtype).bold());
info_!("`{}` fairing must be attached prior to using this database.", fairing);
None
}
}
/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
///
/// A value of this type can be created for any type `D` that implements
/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
/// value of this type _never_ needs to be constructed directly. This
/// documentation exists purely as a reference.
///
/// This fairing initializes a database pool. Specifically, it:
///
/// 1. Reads the configuration at `database.db_name`, where `db_name` is
/// [`Database::NAME`].
///
/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment.
///
/// 3. Calls [`Pool::init()`].
///
/// 4. Stores the database instance in managed storage, retrievable via
/// [`Database::fetch()`].
///
/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
/// the type name `D` unless a name is explicitly provided via
/// [`Self::with_name()`].
pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
/// A request guard which retrieves a single connection to a [`Database`].
///
/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
/// single connection to `Db`.
///
/// The request guard succeeds if the database was initialized by the
/// [`Initializer`] fairing and a connection is available within
/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
/// status `InternalServerError`. A [`Sentinel`] guards this condition, and so
/// this type of failure is unlikely to occur. A `None` error is returned.
/// * If a connection is not available within `connect_timeout` seconds or
/// another error occurs, the gaurd _fails_ with status `ServiceUnavailable`
/// and the error is returned in `Some`.
///
pub struct Connection<'a, D: Database>(&'a <D::Pool as Pool>::Connection);
impl<D: Database> Initializer<D> {
/// Returns a database initializer fairing for `D`.
///
/// This method should never need to be called manually. See the [crate
/// docs](crate) for usage information.
pub fn new() -> Self {
Self(None, std::marker::PhantomData)
}
/// Returns a database initializer fairing for `D` with name `name`.
///
/// This method should never need to be called manually. See the [crate
/// docs](crate) for usage information.
pub fn with_name(name: &'static str) -> Self {
Self(Some(name), std::marker::PhantomData)
}
}
impl<'a, D: Database> Connection<'a, D> {
/// Returns the internal connection value. See the [`Connection` Deref
/// column](crate#supported-drivers) for the expected type of this value.
pub fn into_inner(self) -> &'a <D::Pool as Pool>::Connection {
self.0
}
}
#[rocket::async_trait]
impl<D: Database> Fairing for Initializer<D> {
fn info(&self) -> Info {
Info {
name: self.0.unwrap_or_else(std::any::type_name::<Self>),
kind: Kind::Ignite,
}
}
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
let workers: usize = rocket.figment()
.extract_inner(rocket::Config::WORKERS)
.unwrap_or_else(|_| rocket::Config::default().workers);
let figment = rocket.figment()
.focus(&format!("databases.{}", D::NAME))
.merge(Serialized::default("max_connections", workers * 4))
.merge(Serialized::default("connect_timeout", 5));
match <D::Pool>::init(&figment).await {
Ok(pool) => Ok(rocket.manage(D::from(pool))),
Err(e) => {
error!("failed to initialize database: {}", e);
Err(rocket)
}
}
}
}
#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> {
type Error = Option<<D::Pool as Pool>::Error>;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match D::fetch(req.rocket()) {
Some(pool) => Outcome::Success(Connection(pool.borrow())),
None => Outcome::Failure((Status::InternalServerError, None)),
}
}
}
impl<D: Database> Sentinel for Connection<'_, D> {
fn abort(rocket: &Rocket<Ignite>) -> bool {
D::fetch(rocket).is_none()
}
}

View File

@ -0,0 +1,35 @@
use std::fmt;
/// A general error type for use by [`Pool`](crate::Pool#implementing)
/// implementors and returned by the [`Connection`](crate::Connection) request
/// guard.
#[derive(Debug)]
pub enum Error<A, B = A> {
/// An error that occured during database/pool initialization.
Init(A),
/// An error that ocurred while retrieving a connection from the pool.
Get(B),
/// A [`Figment`](crate::figment::Figment) configuration error.
Config(crate::figment::Error),
}
impl<A: fmt::Display, B: fmt::Display> fmt::Display for Error<A, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Init(e) => write!(f, "failed to initialize database: {}", e),
Error::Get(e) => write!(f, "failed to get db connection: {}", e),
Error::Config(e) => write!(f, "bad configuration: {}", e),
}
}
}
impl<A, B> std::error::Error for Error<A, B>
where A: fmt::Debug + fmt::Display, B: fmt::Debug + fmt::Display {}
impl<A, B> From<crate::figment::Error> for Error<A, B> {
fn from(e: crate::figment::Error) -> Self {
Self::Config(e)
}
}

View File

@ -0,0 +1,20 @@
//! SeaORM Rocket support crate.
#![deny(missing_docs)]
/// Re-export of the `figment` crate.
#[doc(inline)]
pub use rocket::figment;
pub use rocket;
mod database;
mod error;
mod pool;
mod config;
pub use self::database::{Connection, Database, Initializer};
pub use self::error::Error;
pub use self::pool::{Pool, MockPool};
pub use self::config::Config;
pub use sea_orm_rocket_codegen::*;

View File

@ -0,0 +1,70 @@
use rocket::figment::Figment;
/// Generic [`Database`](crate::Database) driver connection pool trait.
///
/// This trait provides a generic interface to various database pooling
/// implementations in the Rust ecosystem. It can be implemented by anyone.
///
/// This is adapted from the original `rocket_db_pools`. But on top we require
/// `Connection` itself to be `Sync`. Hence, instead of cloning or allocating
/// a new connection per request, here we only borrow a reference to the pool.
///
/// In SeaORM, only *when* you are about to execute a SQL statement will a
/// connection be acquired from the pool, and returned as soon as the query finishes.
/// This helps a bit with concurrency if the lifecycle of a request is long enough.
/// ```
#[rocket::async_trait]
pub trait Pool: Sized + Send + Sync + 'static {
/// The connection type managed by this pool.
type Connection;
/// The error type returned by [`Self::init()`].
type Error: std::error::Error;
/// Constructs a pool from a [Value](rocket::figment::value::Value).
///
/// It is up to each implementor of `Pool` to define its accepted
/// configuration value(s) via the `Config` associated type. Most
/// integrations provided in `sea_orm_rocket` use [`Config`], which
/// accepts a (required) `url` and an (optional) `pool_size`.
///
/// ## Errors
///
/// This method returns an error if the configuration is not compatible, or
/// if creating a pool failed due to an unavailable database server,
/// insufficient resources, or another database-specific error.
async fn init(figment: &Figment) -> Result<Self, Self::Error>;
/// Borrows a reference to the pool
fn borrow(&self) -> &Self::Connection;
}
#[derive(Debug)]
/// A mock object which impl `Pool`, for testing only
pub struct MockPool;
#[derive(Debug)]
pub struct MockPoolErr;
impl std::fmt::Display for MockPoolErr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for MockPoolErr {}
#[rocket::async_trait]
impl Pool for MockPool {
type Error = MockPoolErr;
type Connection = bool;
async fn init(_figment: &Figment) -> Result<Self, Self::Error> {
Ok(MockPool)
}
fn borrow(&self) -> &Self::Connection {
&true
}
}

View File

@ -1,168 +1,40 @@
use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder}; use crate::{
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError,
};
use futures::Stream;
use std::{future::Future, pin::Pin};
#[cfg_attr(not(feature = "mock"), derive(Clone))] #[async_trait::async_trait]
pub enum DatabaseConnection { pub trait ConnectionTrait<'a>: Sync {
#[cfg(feature = "sqlx-mysql")] type Stream: Stream<Item = Result<QueryResult, DbErr>>;
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
#[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
#[cfg(feature = "sqlx-sqlite")]
SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),
#[cfg(feature = "mock")]
MockDatabaseConnection(crate::MockDatabaseConnection),
Disconnected,
}
pub type DbConn = DatabaseConnection; fn get_database_backend(&self) -> DbBackend;
#[derive(Debug, Copy, Clone, PartialEq)] async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr>;
pub enum DatabaseBackend {
MySql,
Postgres,
Sqlite,
}
pub type DbBackend = DatabaseBackend; async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr>;
impl Default for DatabaseConnection { async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
fn default() -> Self {
Self::Disconnected
}
}
impl std::fmt::Debug for DatabaseConnection { fn stream(
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { &'a self,
write!( stmt: Statement,
f, ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>>;
"{}",
match self {
#[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")]
Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
#[cfg(feature = "mock")]
Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
Self::Disconnected => "Disconnected",
}
)
}
}
impl DatabaseConnection { async fn begin(&self) -> Result<DatabaseTransaction, DbErr>;
pub fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(),
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> { /// Execute the function inside a transaction.
match self { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
#[cfg(feature = "sqlx-mysql")] async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await,
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await,
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await,
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
#[cfg(feature = "mock")]
pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
match self {
DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn,
_ => panic!("not mock connection"),
}
}
#[cfg(not(feature = "mock"))]
pub fn as_mock_connection(&self) -> Option<bool> {
None
}
#[cfg(feature = "mock")]
pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap();
mocker.drain_transaction_log()
}
}
impl DbBackend {
pub fn is_prefix_of(self, base_url: &str) -> bool {
match self {
Self::Postgres => {
base_url.starts_with("postgres://") || base_url.starts_with("postgresql://")
}
Self::MySql => base_url.starts_with("mysql://"),
Self::Sqlite => base_url.starts_with("sqlite:"),
}
}
pub fn build<S>(&self, statement: &S) -> Statement
where where
S: StatementBuilder, F: for<'c> FnOnce(
{ &'c DatabaseTransaction,
statement.build(self) ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
} + Send,
T: Send,
E: std::error::Error + Send;
pub fn get_query_builder(&self) -> Box<dyn QueryBuilder> { fn is_mock_connection(&self) -> bool {
match self { false
Self::MySql => Box::new(MysqlQueryBuilder),
Self::Postgres => Box::new(PostgresQueryBuilder),
Self::Sqlite => Box::new(SqliteQueryBuilder),
}
}
}
#[cfg(test)]
mod tests {
use crate::DatabaseConnection;
#[test]
fn assert_database_connection_traits() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<DatabaseConnection>();
} }
} }

View File

@ -0,0 +1,261 @@
use crate::{
error::*, ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Statement,
StatementBuilder, TransactionError,
};
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
use std::{future::Future, pin::Pin};
#[cfg(feature = "sqlx-dep")]
use sqlx::pool::PoolConnection;
#[cfg(feature = "mock")]
use std::sync::Arc;
#[cfg_attr(not(feature = "mock"), derive(Clone))]
pub enum DatabaseConnection {
#[cfg(feature = "sqlx-mysql")]
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
#[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
#[cfg(feature = "sqlx-sqlite")]
SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),
#[cfg(feature = "mock")]
MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),
Disconnected,
}
pub type DbConn = DatabaseConnection;
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum DatabaseBackend {
MySql,
Postgres,
Sqlite,
}
pub type DbBackend = DatabaseBackend;
pub(crate) enum InnerConnection {
#[cfg(feature = "sqlx-mysql")]
MySql(PoolConnection<sqlx::MySql>),
#[cfg(feature = "sqlx-postgres")]
Postgres(PoolConnection<sqlx::Postgres>),
#[cfg(feature = "sqlx-sqlite")]
Sqlite(PoolConnection<sqlx::Sqlite>),
#[cfg(feature = "mock")]
Mock(std::sync::Arc<crate::MockDatabaseConnection>),
}
impl Default for DatabaseConnection {
fn default() -> Self {
Self::Disconnected
}
}
impl std::fmt::Debug for DatabaseConnection {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{}",
match self {
#[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")]
Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection",
#[cfg(feature = "mock")]
Self::MockDatabaseConnection(_) => "MockDatabaseConnection",
Self::Disconnected => "Disconnected",
}
)
}
}
#[async_trait::async_trait]
impl<'a> ConnectionTrait<'a> for DatabaseConnection {
type Stream = crate::QueryStream;
fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.get_database_backend(),
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
fn stream(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
crate::QueryStream::from((Arc::clone(conn), stmt))
}
DatabaseConnection::Disconnected => panic!("Disconnected"),
})
})
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
DatabaseTransaction::new_mock(Arc::clone(conn)).await
}
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback).await
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
let transaction = DatabaseTransaction::new_mock(Arc::clone(conn))
.await
.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
}
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_))
}
}
#[cfg(feature = "mock")]
impl DatabaseConnection {
pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
match self {
DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn,
_ => panic!("not mock connection"),
}
}
pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap();
mocker.drain_transaction_log()
}
}
impl DbBackend {
pub fn is_prefix_of(self, base_url: &str) -> bool {
match self {
Self::Postgres => {
base_url.starts_with("postgres://") || base_url.starts_with("postgresql://")
}
Self::MySql => base_url.starts_with("mysql://"),
Self::Sqlite => base_url.starts_with("sqlite:"),
}
}
pub fn build<S>(&self, statement: &S) -> Statement
where
S: StatementBuilder,
{
statement.build(self)
}
pub fn get_query_builder(&self) -> Box<dyn QueryBuilder> {
match self {
Self::MySql => Box::new(MysqlQueryBuilder),
Self::Postgres => Box::new(PostgresQueryBuilder),
Self::Sqlite => Box::new(SqliteQueryBuilder),
}
}
}
#[cfg(test)]
mod tests {
use crate::DatabaseConnection;
#[test]
fn assert_database_connection_traits() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<DatabaseConnection>();
}
}

View File

@ -1,14 +1,15 @@
use crate::{ use crate::{
error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden, error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden,
Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow, Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow,
Statement, Transaction, Statement,
}; };
use sea_query::{Value, ValueType}; use sea_query::{Value, ValueType, Values};
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
#[derive(Debug)] #[derive(Debug)]
pub struct MockDatabase { pub struct MockDatabase {
db_backend: DbBackend, db_backend: DbBackend,
transaction: Option<OpenTransaction>,
transaction_log: Vec<Transaction>, transaction_log: Vec<Transaction>,
exec_results: Vec<MockExecResult>, exec_results: Vec<MockExecResult>,
query_results: Vec<Vec<MockRow>>, query_results: Vec<Vec<MockRow>>,
@ -29,10 +30,22 @@ pub trait IntoMockRow {
fn into_mock_row(self) -> MockRow; fn into_mock_row(self) -> MockRow;
} }
#[derive(Debug)]
pub struct OpenTransaction {
stmts: Vec<Statement>,
transaction_depth: usize,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Transaction {
stmts: Vec<Statement>,
}
impl MockDatabase { impl MockDatabase {
pub fn new(db_backend: DbBackend) -> Self { pub fn new(db_backend: DbBackend) -> Self {
Self { Self {
db_backend, db_backend,
transaction: None,
transaction_log: Vec::new(), transaction_log: Vec::new(),
exec_results: Vec::new(), exec_results: Vec::new(),
query_results: Vec::new(), query_results: Vec::new(),
@ -40,7 +53,7 @@ impl MockDatabase {
} }
pub fn into_connection(self) -> DatabaseConnection { pub fn into_connection(self) -> DatabaseConnection {
DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self)))
} }
pub fn append_exec_results(mut self, mut vec: Vec<MockExecResult>) -> Self { pub fn append_exec_results(mut self, mut vec: Vec<MockExecResult>) -> Self {
@ -62,7 +75,11 @@ impl MockDatabase {
impl MockDatabaseTrait for MockDatabase { impl MockDatabaseTrait for MockDatabase {
fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> { fn execute(&mut self, counter: usize, statement: Statement) -> Result<ExecResult, DbErr> {
self.transaction_log.push(Transaction::one(statement)); if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
} else {
self.transaction_log.push(Transaction::one(statement));
}
if counter < self.exec_results.len() { if counter < self.exec_results.len() {
Ok(ExecResult { Ok(ExecResult {
result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])), result: ExecResultHolder::Mock(std::mem::take(&mut self.exec_results[counter])),
@ -73,7 +90,11 @@ impl MockDatabaseTrait for MockDatabase {
} }
fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> { fn query(&mut self, counter: usize, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
self.transaction_log.push(Transaction::one(statement)); if let Some(transaction) = &mut self.transaction {
transaction.push(statement);
} else {
self.transaction_log.push(Transaction::one(statement));
}
if counter < self.query_results.len() { if counter < self.query_results.len() {
Ok(std::mem::take(&mut self.query_results[counter]) Ok(std::mem::take(&mut self.query_results[counter])
.into_iter() .into_iter()
@ -86,6 +107,39 @@ impl MockDatabaseTrait for MockDatabase {
} }
} }
fn begin(&mut self) {
if self.transaction.is_some() {
self.transaction
.as_mut()
.unwrap()
.begin_nested(self.db_backend);
} else {
self.transaction = Some(OpenTransaction::init());
}
}
fn commit(&mut self) {
if self.transaction.is_some() {
if self.transaction.as_mut().unwrap().commit(self.db_backend) {
let transaction = self.transaction.take().unwrap();
self.transaction_log.push(transaction.into_transaction());
}
} else {
panic!("There is no open transaction to commit");
}
}
fn rollback(&mut self) {
if self.transaction.is_some() {
if self.transaction.as_mut().unwrap().rollback(self.db_backend) {
let transaction = self.transaction.take().unwrap();
self.transaction_log.push(transaction.into_transaction());
}
} else {
panic!("There is no open transaction to rollback");
}
}
fn drain_transaction_log(&mut self) -> Vec<Transaction> { fn drain_transaction_log(&mut self) -> Vec<Transaction> {
std::mem::take(&mut self.transaction_log) std::mem::take(&mut self.transaction_log)
} }
@ -100,7 +154,7 @@ impl MockRow {
where where
T: ValueType, T: ValueType,
{ {
Ok(self.values.get(col).unwrap().clone().unwrap()) T::try_from(self.values.get(col).unwrap().clone()).map_err(|e| DbErr::Query(e.to_string()))
} }
pub fn into_column_value_tuples(self) -> impl Iterator<Item = (String, Value)> { pub fn into_column_value_tuples(self) -> impl Iterator<Item = (String, Value)> {
@ -134,3 +188,372 @@ impl IntoMockRow for BTreeMap<&str, Value> {
} }
} }
} }
impl Transaction {
pub fn from_sql_and_values<I>(db_backend: DbBackend, sql: &str, values: I) -> Self
where
I: IntoIterator<Item = Value>,
{
Self::one(Statement::from_string_values_tuple(
db_backend,
(sql.to_string(), Values(values.into_iter().collect())),
))
}
/// Create a Transaction with one statement
pub fn one(stmt: Statement) -> Self {
Self { stmts: vec![stmt] }
}
/// Create a Transaction with many statements
pub fn many<I>(stmts: I) -> Self
where
I: IntoIterator<Item = Statement>,
{
Self {
stmts: stmts.into_iter().collect(),
}
}
/// Wrap each Statement as a single-statement Transaction
pub fn wrap<I>(stmts: I) -> Vec<Self>
where
I: IntoIterator<Item = Statement>,
{
stmts.into_iter().map(Self::one).collect()
}
}
impl OpenTransaction {
fn init() -> Self {
Self {
stmts: Vec::new(),
transaction_depth: 0,
}
}
fn begin_nested(&mut self, db_backend: DbBackend) {
self.transaction_depth += 1;
self.push(Statement::from_string(
db_backend,
format!("SAVEPOINT savepoint_{}", self.transaction_depth),
));
}
fn commit(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "COMMIT".to_owned()));
true
} else {
self.push(Statement::from_string(
db_backend,
format!("RELEASE SAVEPOINT savepoint_{}", self.transaction_depth),
));
self.transaction_depth -= 1;
false
}
}
fn rollback(&mut self, db_backend: DbBackend) -> bool {
if self.transaction_depth == 0 {
self.push(Statement::from_string(db_backend, "ROLLBACK".to_owned()));
true
} else {
self.push(Statement::from_string(
db_backend,
format!("ROLLBACK TO SAVEPOINT savepoint_{}", self.transaction_depth),
));
self.transaction_depth -= 1;
false
}
}
fn push(&mut self, stmt: Statement) {
self.stmts.push(stmt);
}
fn into_transaction(self) -> Transaction {
if self.transaction_depth != 0 {
panic!("There is uncommitted nested transaction.");
}
Transaction { stmts: self.stmts }
}
}
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
use crate::{
entity::*, tests_cfg::*, ConnectionTrait, DbBackend, DbErr, MockDatabase, Statement,
Transaction, TransactionError,
};
use pretty_assertions::assert_eq;
#[derive(Debug, PartialEq)]
pub struct MyErr(String);
impl std::error::Error for MyErr {}
impl std::fmt::Display for MyErr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
#[smol_potat::test]
async fn test_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _1 = cake::Entity::find().one(txn).await;
let _2 = fruit::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
let _ = cake::Entity::find().all(&db).await;
assert_eq!(
db.into_transaction_log(),
vec![
Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
]),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
vec![]
),
]
);
}
#[smol_potat::test]
async fn test_transaction_2() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
let result = db
.transaction::<_, (), MyErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
Err(MyErr("test".to_owned()))
})
})
.await;
match result {
Err(TransactionError::Transaction(err)) => {
assert_eq!(err, MyErr("test".to_owned()))
}
_ => panic!(),
}
assert_eq!(
db.into_transaction_log(),
vec![Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "ROLLBACK".to_owned()),
]),]
);
}
#[smol_potat::test]
async fn test_nested_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
assert_eq!(
db.into_transaction_log(),
vec![Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
]),]
);
}
#[smol_potat::test]
async fn test_nested_transaction_2() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();
db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;
txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().all(txn).await;
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
Ok(())
})
})
.await
.unwrap();
assert_eq!(
db.into_transaction_log(),
vec![Transaction::many(vec![
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake" LIMIT $1"#,
vec![1u64.into()]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_1".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![]
),
Statement::from_string(DbBackend::Postgres, "SAVEPOINT savepoint_2".to_owned()),
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "cake"."id", "cake"."name" FROM "cake""#,
vec![]
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_2".to_owned()
),
Statement::from_string(
DbBackend::Postgres,
"RELEASE SAVEPOINT savepoint_1".to_owned()
),
Statement::from_string(DbBackend::Postgres, "COMMIT".to_owned()),
]),]
);
}
#[smol_potat::test]
async fn test_stream_1() -> Result<(), DbErr> {
use futures::TryStreamExt;
let apple = fruit::Model {
id: 1,
name: "Apple".to_owned(),
cake_id: Some(1),
};
let orange = fruit::Model {
id: 2,
name: "orange".to_owned(),
cake_id: None,
};
let db = MockDatabase::new(DbBackend::Postgres)
.append_query_results(vec![vec![apple.clone(), orange.clone()]])
.into_connection();
let mut stream = fruit::Entity::find().stream(&db).await?;
assert_eq!(stream.try_next().await?, Some(apple));
assert_eq!(stream.try_next().await?, Some(orange));
assert_eq!(stream.try_next().await?, None);
Ok(())
}
#[smol_potat::test]
async fn test_stream_in_transaction() -> Result<(), DbErr> {
use futures::TryStreamExt;
let apple = fruit::Model {
id: 1,
name: "Apple".to_owned(),
cake_id: Some(1),
};
let orange = fruit::Model {
id: 2,
name: "orange".to_owned(),
cake_id: None,
};
let db = MockDatabase::new(DbBackend::Postgres)
.append_query_results(vec![vec![apple.clone(), orange.clone()]])
.into_connection();
let txn = db.begin().await?;
if let Ok(mut stream) = fruit::Entity::find().stream(&txn).await {
assert_eq!(stream.try_next().await?, Some(apple));
assert_eq!(stream.try_next().await?, Some(orange));
assert_eq!(stream.try_next().await?, None);
// stream will be dropped end of scope OR std::mem::drop(stream);
}
txn.commit().await?;
Ok(())
}
}

View File

@ -1,13 +1,17 @@
mod connection; mod connection;
mod db_connection;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
mod mock; mod mock;
mod statement; mod statement;
mod stream;
mod transaction; mod transaction;
pub use connection::*; pub use connection::*;
pub use db_connection::*;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
pub use mock::*; pub use mock::*;
pub use statement::*; pub use statement::*;
pub use stream::*;
pub use transaction::*; pub use transaction::*;
use crate::DbErr; use crate::DbErr;

View File

@ -0,0 +1,5 @@
mod query;
mod transaction;
pub use query::*;
pub use transaction::*;

View File

@ -0,0 +1,109 @@
use std::{pin::Pin, task::Poll};
#[cfg(feature = "mock")]
use std::sync::Arc;
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use futures::TryStreamExt;
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, Executor};
use crate::{DbErr, InnerConnection, QueryResult, Statement};
#[ouroboros::self_referencing]
pub struct QueryStream {
stmt: Statement,
conn: InnerConnection,
#[borrows(mut conn, stmt)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
}
#[cfg(feature = "sqlx-mysql")]
impl From<(PoolConnection<sqlx::MySql>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::MySql>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::MySql(conn))
}
}
#[cfg(feature = "sqlx-postgres")]
impl From<(PoolConnection<sqlx::Postgres>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Postgres>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Postgres(conn))
}
}
#[cfg(feature = "sqlx-sqlite")]
impl From<(PoolConnection<sqlx::Sqlite>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Sqlite>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Sqlite(conn))
}
}
#[cfg(feature = "mock")]
impl From<(Arc<crate::MockDatabaseConnection>, Statement)> for QueryStream {
fn from((conn, stmt): (Arc<crate::MockDatabaseConnection>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Mock(conn))
}
}
impl std::fmt::Debug for QueryStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "QueryStream")
}
}
impl QueryStream {
fn build(stmt: Statement, conn: InnerConnection) -> QueryStream {
QueryStreamBuilder {
stmt,
conn,
stream_builder: |conn, stmt| match conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
},
}
.build()
}
}
impl Stream for QueryStream {
type Item = Result<QueryResult, DbErr>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| stream.as_mut().poll_next(cx))
}
}

View File

@ -0,0 +1,93 @@
use std::{ops::DerefMut, pin::Pin, task::Poll};
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use futures::TryStreamExt;
#[cfg(feature = "sqlx-dep")]
use sqlx::Executor;
use futures::lock::MutexGuard;
use crate::{DbErr, InnerConnection, QueryResult, Statement};
#[ouroboros::self_referencing]
/// `TransactionStream` cannot be used in a `transaction` closure as it does not impl `Send`.
/// It seems to be a Rust limitation right now, and solution to work around this deemed to be extremely hard.
pub struct TransactionStream<'a> {
stmt: Statement,
conn: MutexGuard<'a, InnerConnection>,
#[borrows(mut conn, stmt)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
}
impl<'a> std::fmt::Debug for TransactionStream<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TransactionStream")
}
}
impl<'a> TransactionStream<'a> {
pub(crate) async fn build(
conn: MutexGuard<'a, InnerConnection>,
stmt: Statement,
) -> TransactionStream<'a> {
TransactionStreamAsyncBuilder {
stmt,
conn,
stream_builder: |conn, stmt| {
Box::pin(async move {
match conn.deref_mut() {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
}
})
},
}
.build()
.await
}
}
impl<'a> Stream for TransactionStream<'a> {
type Item = Result<QueryResult, DbErr>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| stream.as_mut().poll_next(cx))
}
}

View File

@ -1,42 +1,367 @@
use crate::{DbBackend, Statement}; use crate::{
use sea_query::{Value, Values}; debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult,
Statement, TransactionStream,
};
#[cfg(feature = "sqlx-dep")]
use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
use futures::lock::Mutex;
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, TransactionManager};
use std::{future::Future, pin::Pin, sync::Arc};
#[derive(Debug, Clone, PartialEq)] // a Transaction is just a sugar for a connection where START TRANSACTION has been executed
pub struct Transaction { pub struct DatabaseTransaction {
stmts: Vec<Statement>, conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
} }
impl Transaction { impl std::fmt::Debug for DatabaseTransaction {
pub fn from_sql_and_values<I>(db_backend: DbBackend, sql: &str, values: I) -> Self fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
where write!(f, "DatabaseTransaction")
I: IntoIterator<Item = Value>, }
{ }
Self::one(Statement::from_string_values_tuple(
db_backend, impl DatabaseTransaction {
(sql.to_string(), Values(values.into_iter().collect())), #[cfg(feature = "sqlx-mysql")]
)) pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql,
)
.await
} }
/// Create a Transaction with one statement #[cfg(feature = "sqlx-postgres")]
pub fn one(stmt: Statement) -> Self { pub(crate) async fn new_postgres(
Self { stmts: vec![stmt] } inner: PoolConnection<sqlx::Postgres>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres,
)
.await
} }
/// Create a Transaction with many statements #[cfg(feature = "sqlx-sqlite")]
pub fn many<I>(stmts: I) -> Self pub(crate) async fn new_sqlite(
inner: PoolConnection<sqlx::Sqlite>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite,
)
.await
}
#[cfg(feature = "mock")]
pub(crate) async fn new_mock(
inner: Arc<crate::MockDatabaseConnection>,
) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
}
async fn begin(
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.begin();
}
}
Ok(res)
}
pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
where where
I: IntoIterator<Item = Statement>, F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
T: Send,
E: std::error::Error + Send,
{ {
Self { let res = callback(&self).await.map_err(TransactionError::Transaction);
stmts: stmts.into_iter().collect(), if res.is_ok() {
self.commit().await.map_err(TransactionError::Connection)?;
} else {
self.rollback()
.await
.map_err(TransactionError::Connection)?;
}
res
}
pub async fn commit(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.commit();
}
}
Ok(())
}
pub async fn rollback(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.rollback();
}
}
Ok(())
}
// the rollback is queued and will be performed on next async operation, like returning the connection to the pool
fn start_rollback(&mut self) {
if self.open {
if let Some(mut conn) = self.conn.try_lock() {
match &mut *conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => {
c.rollback();
}
}
} else {
//this should never happen
panic!("Dropping a locked Transaction");
}
}
}
}
impl Drop for DatabaseTransaction {
fn drop(&mut self) {
self.start_rollback();
}
}
#[async_trait::async_trait]
impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
type Stream = TransactionStream<'a>;
fn get_database_backend(&self) -> DbBackend {
// this way we don't need to lock
self.backend
}
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query.execute(conn).await.map(Into::into)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query.execute(conn).await.map(Into::into)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query.execute(conn).await.map(Into::into)
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.execute(stmt),
};
#[cfg(feature = "sqlx-dep")]
_res.map_err(sqlx_error_to_exec_err)
}
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query.fetch_one(conn).await.map(|row| Some(row.into()))
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query.fetch_one(conn).await.map(|row| Some(row.into()))
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query.fetch_one(conn).await.map(|row| Some(row.into()))
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_one(stmt),
};
#[cfg(feature = "sqlx-dep")]
if let Err(sqlx::Error::RowNotFound) = _res {
Ok(None)
} else {
_res.map_err(sqlx_error_to_query_err)
} }
} }
/// Wrap each Statement as a single-statement Transaction async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
pub fn wrap<I>(stmts: I) -> Vec<Self> debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_all(stmt),
};
#[cfg(feature = "sqlx-dep")]
_res.map_err(sqlx_error_to_query_err)
}
fn stream(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> {
Box::pin(
async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) },
)
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await
}
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where where
I: IntoIterator<Item = Statement>, F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
T: Send,
E: std::error::Error + Send,
{ {
stmts.into_iter().map(Self::one).collect() let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
} }
} }
#[derive(Debug)]
pub enum TransactionError<E>
where
E: std::error::Error,
{
Connection(DbErr),
Transaction(E),
}
impl<E> std::fmt::Display for TransactionError<E>
where
E: std::error::Error,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
}
}
}
impl<E> std::error::Error for TransactionError<E> where E: std::error::Error {}

View File

@ -2,10 +2,14 @@ use crate::{
debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
Statement, Transaction, Statement, Transaction,
}; };
use std::fmt::Debug; use futures::Stream;
use std::sync::{ use std::{
atomic::{AtomicUsize, Ordering}, fmt::Debug,
Mutex, pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
}; };
#[derive(Debug)] #[derive(Debug)]
@ -22,6 +26,12 @@ pub trait MockDatabaseTrait: Send + Debug {
fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>; fn query(&mut self, counter: usize, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
fn begin(&mut self);
fn commit(&mut self);
fn rollback(&mut self);
fn drain_transaction_log(&mut self) -> Vec<Transaction>; fn drain_transaction_log(&mut self) -> Vec<Transaction>;
fn get_database_backend(&self) -> DbBackend; fn get_database_backend(&self) -> DbBackend;
@ -49,9 +59,9 @@ impl MockDatabaseConnector {
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
macro_rules! connect_mock_db { macro_rules! connect_mock_db {
( $syntax: expr ) => { ( $syntax: expr ) => {
Ok(DatabaseConnection::MockDatabaseConnection( Ok(DatabaseConnection::MockDatabaseConnection(Arc::new(
MockDatabaseConnection::new(MockDatabase::new($syntax)), MockDatabaseConnection::new(MockDatabase::new($syntax)),
)) )))
}; };
} }
@ -82,30 +92,52 @@ impl MockDatabaseConnection {
} }
} }
pub fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> { pub(crate) fn get_mocker_mutex(&self) -> &Mutex<Box<dyn MockDatabaseTrait>> {
&self.mocker &self.mocker
} }
pub async fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> { pub fn get_database_backend(&self) -> DbBackend {
self.mocker.lock().unwrap().get_database_backend()
}
pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
self.mocker.lock().unwrap().execute(counter, statement) self.mocker.lock().unwrap().execute(counter, statement)
} }
pub async fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> { pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
let result = self.mocker.lock().unwrap().query(counter, statement)?; let result = self.mocker.lock().unwrap().query(counter, statement)?;
Ok(result.into_iter().next()) Ok(result.into_iter().next())
} }
pub async fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> { pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
self.mocker.lock().unwrap().query(counter, statement) self.mocker.lock().unwrap().query(counter, statement)
} }
pub fn get_database_backend(&self) -> DbBackend { pub fn fetch(
self.mocker.lock().unwrap().get_database_backend() &self,
statement: &Statement,
) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> {
match self.query_all(statement.clone()) {
Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))),
Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())),
}
}
pub fn begin(&self) {
self.mocker.lock().unwrap().begin()
}
pub fn commit(&self) {
self.mocker.lock().unwrap().commit()
}
pub fn rollback(&self) {
self.mocker.lock().unwrap().rollback()
} }
} }

View File

@ -3,11 +3,11 @@ mod mock;
#[cfg(feature = "sqlx-dep")] #[cfg(feature = "sqlx-dep")]
mod sqlx_common; mod sqlx_common;
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
mod sqlx_mysql; pub(crate) mod sqlx_mysql;
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
mod sqlx_postgres; pub(crate) mod sqlx_postgres;
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
mod sqlx_sqlite; pub(crate) mod sqlx_sqlite;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
pub use mock::*; pub use mock::*;

View File

@ -1,3 +1,5 @@
use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow},
MySql, MySqlPool, MySql, MySqlPool,
@ -6,7 +8,10 @@ use sqlx::{
sea_query::sea_query_driver_mysql!(); sea_query::sea_query_driver_mysql!();
use sea_query_driver_mysql::bind_query; use sea_query_driver_mysql::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{
debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream,
Statement, TransactionError,
};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,7 +25,7 @@ pub struct SqlxMySqlPoolConnection {
impl SqlxMySqlConnector { impl SqlxMySqlConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::MySql.is_prefix_of(string) string.starts_with("mysql://")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
@ -91,6 +96,49 @@ impl SqlxMySqlPoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<MySqlRow> for QueryResult { impl From<MySqlRow> for QueryResult {
@ -109,7 +157,7 @@ impl From<MySqlQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,3 +1,5 @@
use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
postgres::{PgArguments, PgQueryResult, PgRow}, postgres::{PgArguments, PgQueryResult, PgRow},
PgPool, Postgres, PgPool, Postgres,
@ -6,7 +8,10 @@ use sqlx::{
sea_query::sea_query_driver_postgres!(); sea_query::sea_query_driver_postgres!();
use sea_query_driver_postgres::bind_query; use sea_query_driver_postgres::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{
debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream,
Statement, TransactionError,
};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,7 +25,7 @@ pub struct SqlxPostgresPoolConnection {
impl SqlxPostgresConnector { impl SqlxPostgresConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::Postgres.is_prefix_of(string) string.starts_with("postgres://")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
@ -91,6 +96,49 @@ impl SqlxPostgresPoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_postgres(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_postgres(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<PgRow> for QueryResult { impl From<PgRow> for QueryResult {
@ -109,7 +157,7 @@ impl From<PgQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,12 +1,17 @@
use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}, sqlite::{SqliteArguments, SqlitePoolOptions, SqliteQueryResult, SqliteRow},
Sqlite, SqlitePool, Sqlite, SqlitePool,
}; };
sea_query::sea_query_driver_sqlite!(); sea_query::sea_query_driver_sqlite!();
use sea_query_driver_sqlite::bind_query; use sea_query_driver_sqlite::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{
debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream,
Statement, TransactionError,
};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,11 +25,15 @@ pub struct SqlxSqlitePoolConnection {
impl SqlxSqliteConnector { impl SqlxSqliteConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::Sqlite.is_prefix_of(string) string.starts_with("sqlite:")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
if let Ok(pool) = SqlitePool::connect(string).await { if let Ok(pool) = SqlitePoolOptions::new()
.max_connections(1)
.connect(string)
.await
{
Ok(DatabaseConnection::SqlxSqlitePoolConnection( Ok(DatabaseConnection::SqlxSqlitePoolConnection(
SqlxSqlitePoolConnection { pool }, SqlxSqlitePoolConnection { pool },
)) ))
@ -91,6 +100,49 @@ impl SqlxSqlitePoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_sqlite(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_sqlite(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<SqliteRow> for QueryResult { impl From<SqliteRow> for QueryResult {
@ -109,7 +161,7 @@ impl From<SqliteQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,8 +1,8 @@
use crate::{ use crate::{
error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, error::*, ConnectionTrait, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, Value,
PrimaryKeyTrait, Value,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use sea_query::ValueTuple;
use std::fmt::Debug; use std::fmt::Debug;
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@ -10,7 +10,8 @@ pub struct ActiveValue<V>
where where
V: Into<Value>, V: Into<Value>,
{ {
value: Option<V>, // Don't want to call ActiveValue::unwrap() and cause panic
pub(self) value: Option<V>,
state: ActiveValueState, state: ActiveValueState,
} }
@ -67,33 +68,64 @@ pub trait ActiveModelTrait: Clone + Debug {
fn default() -> Self; fn default() -> Self;
async fn insert(self, db: &DatabaseConnection) -> Result<Self, DbErr> #[allow(clippy::question_mark)]
where fn get_primary_key_value(&self) -> Option<ValueTuple> {
Self: ActiveModelBehavior, let mut cols = <Self::Entity as EntityTrait>::PrimaryKey::iter();
<Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>, macro_rules! next {
{ () => {
let am = ActiveModelBehavior::before_save(self, true, db)?; if let Some(col) = cols.next() {
let res = <Self::Entity as EntityTrait>::insert(am).exec(db).await?; if let Some(val) = self.get(col.into_column()).value {
// Assume valid last_insert_id is not equals to Default::default() val
if res.last_insert_id } else {
!= <<Self::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType::default() return None;
{ }
let found = <Self::Entity as EntityTrait>::find_by_id(res.last_insert_id) } else {
.one(db) return None;
.await?; }
let am = match found { };
Some(model) => Ok(model.into_active_model()), }
None => Err(DbErr::Exec("Failed to find inserted item".to_owned())), match <Self::Entity as EntityTrait>::PrimaryKey::iter().count() {
}?; 1 => {
ActiveModelBehavior::after_save(am, true, db) let s1 = next!();
} else { Some(ValueTuple::One(s1))
Ok(Self::default()) }
2 => {
let s1 = next!();
let s2 = next!();
Some(ValueTuple::Two(s1, s2))
}
3 => {
let s1 = next!();
let s2 = next!();
let s3 = next!();
Some(ValueTuple::Three(s1, s2, s3))
}
_ => panic!("The arity cannot be larger than 3"),
} }
} }
async fn update(self, db: &DatabaseConnection) -> Result<Self, DbErr> async fn insert<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where where
Self: ActiveModelBehavior, Self: ActiveModelBehavior,
<Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>,
C: ConnectionTrait<'a>,
Self: 'a,
{
let am = ActiveModelBehavior::before_save(self, true, db)?;
let res = <Self::Entity as EntityTrait>::insert(am).exec(db).await?;
let found = <Self::Entity as EntityTrait>::find_by_id(res.last_insert_id)
.one(db)
.await?;
match found {
Some(model) => Ok(model.into_active_model()),
None => Err(DbErr::Exec("Failed to find inserted item".to_owned())),
}
}
async fn update<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where
C: ConnectionTrait<'a>,
Self: 'a,
{ {
let am = ActiveModelBehavior::before_save(self, false, db)?; let am = ActiveModelBehavior::before_save(self, false, db)?;
let am = Self::Entity::update(am).exec(db).await?; let am = Self::Entity::update(am).exec(db).await?;
@ -102,10 +134,11 @@ pub trait ActiveModelTrait: Clone + Debug {
/// Insert the model if primary key is unset, update otherwise. /// Insert the model if primary key is unset, update otherwise.
/// Only works if the entity has auto increment primary key. /// Only works if the entity has auto increment primary key.
async fn save(self, db: &DatabaseConnection) -> Result<Self, DbErr> async fn save<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where where
Self: ActiveModelBehavior, Self: ActiveModelBehavior + 'a,
<Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>, <Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>,
C: ConnectionTrait<'a>,
{ {
let mut am = self; let mut am = self;
let mut is_update = true; let mut is_update = true;
@ -125,9 +158,10 @@ pub trait ActiveModelTrait: Clone + Debug {
} }
/// Delete an active model by its primary key /// Delete an active model by its primary key
async fn delete(self, db: &DatabaseConnection) -> Result<DeleteResult, DbErr> async fn delete<'a, C>(self, db: &'a C) -> Result<DeleteResult, DbErr>
where where
Self: ActiveModelBehavior, Self: ActiveModelBehavior + 'a,
C: ConnectionTrait<'a>,
{ {
let am = ActiveModelBehavior::before_delete(self, db)?; let am = ActiveModelBehavior::before_delete(self, db)?;
let am_clone = am.clone(); let am_clone = am.clone();
@ -219,23 +253,23 @@ where
matches!(self.state, ActiveValueState::Unset) matches!(self.state, ActiveValueState::Unset)
} }
pub fn take(&mut self) -> V { pub fn take(&mut self) -> Option<V> {
self.state = ActiveValueState::Unset; self.state = ActiveValueState::Unset;
self.value.take().unwrap() self.value.take()
} }
pub fn unwrap(self) -> V { pub fn unwrap(self) -> V {
self.value.unwrap() self.value.unwrap()
} }
pub fn into_value(self) -> Value { pub fn into_value(self) -> Option<Value> {
self.value.unwrap().into() self.value.map(Into::into)
} }
pub fn into_wrapped_value(self) -> ActiveValue<Value> { pub fn into_wrapped_value(self) -> ActiveValue<Value> {
match self.state { match self.state {
ActiveValueState::Set => ActiveValue::set(self.into_value()), ActiveValueState::Set => ActiveValue::set(self.into_value().unwrap()),
ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value()), ActiveValueState::Unchanged => ActiveValue::unchanged(self.into_value().unwrap()),
ActiveValueState::Unset => ActiveValue::unset(), ActiveValueState::Unset => ActiveValue::unset(),
} }
} }

View File

@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName {
/// ///
/// ``` /// ```
/// # #[cfg(feature = "mock")] /// # #[cfg(feature = "mock")]
/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend};
/// # /// #
/// # let db = MockDatabase::new(DbBackend::Postgres) /// # let db = MockDatabase::new(DbBackend::Postgres)
/// # .append_exec_results(vec![ /// # .append_exec_results(vec![

View File

@ -1,16 +1,16 @@
use super::{ColumnTrait, IdenStatic, Iterable}; use super::{ColumnTrait, IdenStatic, Iterable};
use crate::{TryFromU64, TryGetableMany}; use crate::{TryFromU64, TryGetableMany};
use sea_query::IntoValueTuple; use sea_query::{FromValueTuple, IntoValueTuple};
use std::fmt::Debug; use std::fmt::Debug;
//LINT: composite primary key cannot auto increment //LINT: composite primary key cannot auto increment
pub trait PrimaryKeyTrait: IdenStatic + Iterable { pub trait PrimaryKeyTrait: IdenStatic + Iterable {
type ValueType: Sized type ValueType: Sized
+ Send + Send
+ Default
+ Debug + Debug
+ PartialEq + PartialEq
+ IntoValueTuple + IntoValueTuple
+ FromValueTuple
+ TryGetableMany + TryGetableMany
+ TryFromU64; + TryFromU64;

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, DatabaseConnection, DeleteMany, DeleteOne, EntityTrait, Statement, error::*, ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement,
}; };
use sea_query::DeleteStatement; use sea_query::DeleteStatement;
use std::future::Future; use std::future::Future;
@ -18,10 +18,10 @@ impl<'a, A: 'a> DeleteOne<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub fn exec( pub fn exec<C>(self, db: &'a C) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_
self, where
db: &'a DatabaseConnection, C: ConnectionTrait<'a>,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a { {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_delete_only(self.query, db) exec_delete_only(self.query, db)
} }
@ -31,10 +31,10 @@ impl<'a, E> DeleteMany<E>
where where
E: EntityTrait, E: EntityTrait,
{ {
pub fn exec( pub fn exec<C>(self, db: &'a C) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_
self, where
db: &'a DatabaseConnection, C: ConnectionTrait<'a>,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a { {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_delete_only(self.query, db) exec_delete_only(self.query, db)
} }
@ -45,24 +45,26 @@ impl Deleter {
Self { query } Self { query }
} }
pub fn exec( pub fn exec<'a, C>(self, db: &'a C) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_
self, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_ { {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_delete(builder.build(&self.query), db) exec_delete(builder.build(&self.query), db)
} }
} }
async fn exec_delete_only( async fn exec_delete_only<'a, C>(query: DeleteStatement, db: &'a C) -> Result<DeleteResult, DbErr>
query: DeleteStatement, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> Result<DeleteResult, DbErr> { {
Deleter::new(query).exec(db).await Deleter::new(query).exec(db).await
} }
// Only Statement impl Send async fn exec_delete<'a, C>(statement: Statement, db: &'a C) -> Result<DeleteResult, DbErr>
async fn exec_delete(statement: Statement, db: &DatabaseConnection) -> Result<DeleteResult, DbErr> { where
C: ConnectionTrait<'a>,
{
let result = db.execute(statement).await?; let result = db.execute(statement).await?;
Ok(DeleteResult { Ok(DeleteResult {
rows_affected: result.rows_affected(), rows_affected: result.rows_affected(),

View File

@ -1,15 +1,16 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert, error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, PrimaryKeyTrait,
PrimaryKeyTrait, Statement, TryFromU64, Statement, TryFromU64,
}; };
use sea_query::InsertStatement; use sea_query::{FromValueTuple, InsertStatement, ValueTuple};
use std::{future::Future, marker::PhantomData}; use std::{future::Future, marker::PhantomData};
#[derive(Clone, Debug)] #[derive(Debug)]
pub struct Inserter<A> pub struct Inserter<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
primary_key: Option<ValueTuple>,
query: InsertStatement, query: InsertStatement,
model: PhantomData<A>, model: PhantomData<A>,
} }
@ -27,14 +28,11 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
#[allow(unused_mut)] #[allow(unused_mut)]
pub fn exec<'a>( pub fn exec<'a, C>(self, db: &'a C) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + '_
self,
db: &'a DatabaseConnection,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a
where where
C: ConnectionTrait<'a>,
A: 'a, A: 'a,
{ {
// TODO: extract primary key's value from query
// so that self is dropped before entering await // so that self is dropped before entering await
let mut query = self.query; let mut query = self.query;
if db.get_database_backend() == DbBackend::Postgres { if db.get_database_backend() == DbBackend::Postgres {
@ -47,8 +45,7 @@ where
); );
} }
} }
Inserter::<A>::new(query).exec(db) Inserter::<A>::new(self.primary_key, query).exec(db)
// TODO: return primary key if extracted before, otherwise use InsertResult
} }
} }
@ -56,50 +53,55 @@ impl<A> Inserter<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub fn new(query: InsertStatement) -> Self { pub fn new(primary_key: Option<ValueTuple>, query: InsertStatement) -> Self {
Self { Self {
primary_key,
query, query,
model: PhantomData, model: PhantomData,
} }
} }
pub fn exec<'a>( pub fn exec<'a, C>(self, db: &'a C) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + '_
self,
db: &'a DatabaseConnection,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a
where where
C: ConnectionTrait<'a>,
A: 'a, A: 'a,
{ {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_insert(builder.build(&self.query), db) exec_insert(self.primary_key, builder.build(&self.query), db)
} }
} }
// Only Statement impl Send async fn exec_insert<'a, A, C>(
async fn exec_insert<A>( primary_key: Option<ValueTuple>,
statement: Statement, statement: Statement,
db: &DatabaseConnection, db: &'a C,
) -> Result<InsertResult<A>, DbErr> ) -> Result<InsertResult<A>, DbErr>
where where
C: ConnectionTrait<'a>,
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey; type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType; type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType;
let last_insert_id = match db.get_database_backend() { let last_insert_id_opt = match db.get_database_backend() {
DbBackend::Postgres => { DbBackend::Postgres => {
use crate::{sea_query::Iden, Iterable}; use crate::{sea_query::Iden, Iterable};
let cols = PrimaryKey::<A>::iter() let cols = PrimaryKey::<A>::iter()
.map(|col| col.to_string()) .map(|col| col.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let res = db.query_one(statement).await?.unwrap(); let res = db.query_one(statement).await?.unwrap();
res.try_get_many("", cols.as_ref()).unwrap_or_default() res.try_get_many("", cols.as_ref()).ok()
} }
_ => { _ => {
let last_insert_id = db.execute(statement).await?.last_insert_id(); let last_insert_id = db.execute(statement).await?.last_insert_id();
ValueTypeOf::<A>::try_from_u64(last_insert_id) ValueTypeOf::<A>::try_from_u64(last_insert_id).ok()
.ok()
.unwrap_or_default()
} }
}; };
let last_insert_id = match last_insert_id_opt {
Some(last_insert_id) => last_insert_id,
None => match primary_key {
Some(value_tuple) => FromValueTuple::from_value_tuple(value_tuple),
None => return Err(DbErr::Exec("Fail to unpack last_insert_id".to_owned())),
},
};
Ok(InsertResult { last_insert_id }) Ok(InsertResult { last_insert_id })
} }

View File

@ -1,4 +1,4 @@
use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; use crate::{error::*, ConnectionTrait, DbBackend, SelectorTrait};
use async_stream::stream; use async_stream::stream;
use futures::Stream; use futures::Stream;
use sea_query::{Alias, Expr, SelectStatement}; use sea_query::{Alias, Expr, SelectStatement};
@ -7,21 +7,23 @@ use std::{marker::PhantomData, pin::Pin};
pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>; pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Paginator<'db, S> pub struct Paginator<'db, C, S>
where where
C: ConnectionTrait<'db>,
S: SelectorTrait + 'db, S: SelectorTrait + 'db,
{ {
pub(crate) query: SelectStatement, pub(crate) query: SelectStatement,
pub(crate) page: usize, pub(crate) page: usize,
pub(crate) page_size: usize, pub(crate) page_size: usize,
pub(crate) db: &'db DatabaseConnection, pub(crate) db: &'db C,
pub(crate) selector: PhantomData<S>, pub(crate) selector: PhantomData<S>,
} }
// LINT: warn if paginator is used without an order by clause // LINT: warn if paginator is used without an order by clause
impl<'db, S> Paginator<'db, S> impl<'db, C, S> Paginator<'db, C, S>
where where
C: ConnectionTrait<'db>,
S: SelectorTrait + 'db, S: SelectorTrait + 'db,
{ {
/// Fetch a specific page; page index starts from zero /// Fetch a specific page; page index starts from zero
@ -155,7 +157,7 @@ where
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
mod tests { mod tests {
use crate::entity::prelude::*; use crate::entity::prelude::*;
use crate::tests_cfg::*; use crate::{tests_cfg::*, ConnectionTrait};
use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
use futures::TryStreamExt; use futures::TryStreamExt;
use sea_query::{Alias, Expr, SelectStatement, Value}; use sea_query::{Alias, Expr, SelectStatement, Value};

View File

@ -126,12 +126,12 @@ macro_rules! try_getable_unsigned {
( $type: ty ) => { ( $type: ty ) => {
impl TryGetable for $type { impl TryGetable for $type {
fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> {
let column = format!("{}{}", pre, col); let _column = format!("{}{}", pre, col);
match &res.row { match &res.row {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
QueryResultRow::SqlxMySql(row) => { QueryResultRow::SqlxMySql(row) => {
use sqlx::Row; use sqlx::Row;
row.try_get::<Option<$type>, _>(column.as_str()) row.try_get::<Option<$type>, _>(_column.as_str())
.map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))
.and_then(|opt| opt.ok_or(TryGetError::Null)) .and_then(|opt| opt.ok_or(TryGetError::Null))
} }
@ -142,13 +142,13 @@ macro_rules! try_getable_unsigned {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
QueryResultRow::SqlxSqlite(row) => { QueryResultRow::SqlxSqlite(row) => {
use sqlx::Row; use sqlx::Row;
row.try_get::<Option<$type>, _>(column.as_str()) row.try_get::<Option<$type>, _>(_column.as_str())
.map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))
.and_then(|opt| opt.ok_or(TryGetError::Null)) .and_then(|opt| opt.ok_or(TryGetError::Null))
} }
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
#[allow(unused_variables)] #[allow(unused_variables)]
QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| {
debug_print!("{:#?}", e.to_string()); debug_print!("{:#?}", e.to_string());
TryGetError::Null TryGetError::Null
}), }),
@ -162,12 +162,12 @@ macro_rules! try_getable_mysql {
( $type: ty ) => { ( $type: ty ) => {
impl TryGetable for $type { impl TryGetable for $type {
fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> {
let column = format!("{}{}", pre, col); let _column = format!("{}{}", pre, col);
match &res.row { match &res.row {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
QueryResultRow::SqlxMySql(row) => { QueryResultRow::SqlxMySql(row) => {
use sqlx::Row; use sqlx::Row;
row.try_get::<Option<$type>, _>(column.as_str()) row.try_get::<Option<$type>, _>(_column.as_str())
.map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))
.and_then(|opt| opt.ok_or(TryGetError::Null)) .and_then(|opt| opt.ok_or(TryGetError::Null))
} }
@ -181,7 +181,7 @@ macro_rules! try_getable_mysql {
} }
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
#[allow(unused_variables)] #[allow(unused_variables)]
QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| {
debug_print!("{:#?}", e.to_string()); debug_print!("{:#?}", e.to_string());
TryGetError::Null TryGetError::Null
}), }),
@ -195,7 +195,7 @@ macro_rules! try_getable_postgres {
( $type: ty ) => { ( $type: ty ) => {
impl TryGetable for $type { impl TryGetable for $type {
fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> { fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result<Self, TryGetError> {
let column = format!("{}{}", pre, col); let _column = format!("{}{}", pre, col);
match &res.row { match &res.row {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
QueryResultRow::SqlxMySql(_) => { QueryResultRow::SqlxMySql(_) => {
@ -204,7 +204,7 @@ macro_rules! try_getable_postgres {
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
QueryResultRow::SqlxPostgres(row) => { QueryResultRow::SqlxPostgres(row) => {
use sqlx::Row; use sqlx::Row;
row.try_get::<Option<$type>, _>(column.as_str()) row.try_get::<Option<$type>, _>(_column.as_str())
.map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e))) .map_err(|e| TryGetError::DbErr(crate::sqlx_error_to_query_err(e)))
.and_then(|opt| opt.ok_or(TryGetError::Null)) .and_then(|opt| opt.ok_or(TryGetError::Null))
} }
@ -214,7 +214,7 @@ macro_rules! try_getable_postgres {
} }
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
#[allow(unused_variables)] #[allow(unused_variables)]
QueryResultRow::Mock(row) => row.try_get(column.as_str()).map_err(|e| { QueryResultRow::Mock(row) => row.try_get(_column.as_str()).map_err(|e| {
debug_print!("{:#?}", e.to_string()); debug_print!("{:#?}", e.to_string());
TryGetError::Null TryGetError::Null
}), }),

View File

@ -1,10 +1,12 @@
use crate::{ use crate::{
error::*, DatabaseConnection, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, error::*, ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue,
ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo,
SelectTwoMany, Statement, TryGetableMany, SelectTwoMany, Statement, TryGetableMany,
}; };
use futures::{Stream, TryStreamExt};
use sea_query::SelectStatement; use sea_query::SelectStatement;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Selector<S> pub struct Selector<S>
@ -234,23 +236,45 @@ where
Selector::<SelectGetableValue<T, C>>::with_columns(self.query) Selector::<SelectGetableValue<T, C>>::with_columns(self.query)
} }
pub async fn one(self, db: &DatabaseConnection) -> Result<Option<E::Model>, DbErr> { pub async fn one<'a, C>(self, db: &C) -> Result<Option<E::Model>, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<E::Model>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<E::Model>, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().all(db).await self.into_model().all(db).await
} }
pub fn paginate( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
) -> Result<impl Stream<Item = Result<E::Model, DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub fn paginate<'a, C>(
self,
db: &'a C,
page_size: usize, page_size: usize,
) -> Paginator<'_, SelectModel<E::Model>> { ) -> Paginator<'a, C, SelectModel<E::Model>>
where
C: ConnectionTrait<'a>,
{
self.into_model().paginate(db, page_size) self.into_model().paginate(db, page_size)
} }
pub async fn count(self, db: &DatabaseConnection) -> Result<usize, DbErr> { pub async fn count<'a, C>(self, db: &'a C) -> Result<usize, DbErr>
where
C: ConnectionTrait<'a>,
{
self.paginate(db, 1).num_items().await self.paginate(db, 1).num_items().await
} }
} }
@ -279,29 +303,45 @@ where
} }
} }
pub async fn one( pub async fn one<'a, C>(self, db: &C) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr>
self, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr> { {
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all( pub async fn all<'a, C>(self, db: &C) -> Result<Vec<(E::Model, Option<F::Model>)>, DbErr>
self, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> Result<Vec<(E::Model, Option<F::Model>)>, DbErr> { {
self.into_model().all(db).await self.into_model().all(db).await
} }
pub fn paginate( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub fn paginate<'a, C>(
self,
db: &'a C,
page_size: usize, page_size: usize,
) -> Paginator<'_, SelectTwoModel<E::Model, F::Model>> { ) -> Paginator<'a, C, SelectTwoModel<E::Model, F::Model>>
where
C: ConnectionTrait<'a>,
{
self.into_model().paginate(db, page_size) self.into_model().paginate(db, page_size)
} }
pub async fn count(self, db: &DatabaseConnection) -> Result<usize, DbErr> { pub async fn count<'a, C>(self, db: &'a C) -> Result<usize, DbErr>
where
C: ConnectionTrait<'a>,
{
self.paginate(db, 1).num_items().await self.paginate(db, 1).num_items().await
} }
} }
@ -330,17 +370,27 @@ where
} }
} }
pub async fn one( pub async fn one<'a, C>(self, db: &C) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr>
self, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr> { {
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all( pub async fn stream<'a: 'b, 'b, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
) -> Result<Vec<(E::Model, Vec<F::Model>)>, DbErr> { ) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub async fn all<'a, C>(self, db: &C) -> Result<Vec<(E::Model, Vec<F::Model>)>, DbErr>
where
C: ConnectionTrait<'a>,
{
let rows = self.into_model().all(db).await?; let rows = self.into_model().all(db).await?;
Ok(consolidate_query_result::<E, F>(rows)) Ok(consolidate_query_result::<E, F>(rows))
} }
@ -375,7 +425,10 @@ where
} }
} }
pub async fn one(mut self, db: &DatabaseConnection) -> Result<Option<S::Item>, DbErr> { pub async fn one<'a, C>(mut self, db: &C) -> Result<Option<S::Item>, DbErr>
where
C: ConnectionTrait<'a>,
{
let builder = db.get_database_backend(); let builder = db.get_database_backend();
self.query.limit(1); self.query.limit(1);
let row = db.query_one(builder.build(&self.query)).await?; let row = db.query_one(builder.build(&self.query)).await?;
@ -385,7 +438,10 @@ where
} }
} }
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<S::Item>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<S::Item>, DbErr>
where
C: ConnectionTrait<'a>,
{
let builder = db.get_database_backend(); let builder = db.get_database_backend();
let rows = db.query_all(builder.build(&self.query)).await?; let rows = db.query_all(builder.build(&self.query)).await?;
let mut models = Vec::new(); let mut models = Vec::new();
@ -395,7 +451,25 @@ where
Ok(models) Ok(models)
} }
pub fn paginate(self, db: &DatabaseConnection, page_size: usize) -> Paginator<'_, S> { pub async fn stream<'a: 'b, 'b, C>(
self,
db: &'a C,
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b>>, DbErr>
where
C: ConnectionTrait<'a>,
S: 'b,
{
let builder = db.get_database_backend();
let stream = db.stream(builder.build(&self.query)).await?;
Ok(Box::pin(stream.and_then(|row| {
futures::future::ready(S::from_raw_query_result(row))
})))
}
pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S>
where
C: ConnectionTrait<'a>,
{
Paginator { Paginator {
query: self.query, query: self.query,
page: 0, page: 0,
@ -605,7 +679,10 @@ where
/// ),] /// ),]
/// ); /// );
/// ``` /// ```
pub async fn one(self, db: &DatabaseConnection) -> Result<Option<S::Item>, DbErr> { pub async fn one<'a, C>(self, db: &C) -> Result<Option<S::Item>, DbErr>
where
C: ConnectionTrait<'a>,
{
let row = db.query_one(self.stmt).await?; let row = db.query_one(self.stmt).await?;
match row { match row {
Some(row) => Ok(Some(S::from_raw_query_result(row)?)), Some(row) => Ok(Some(S::from_raw_query_result(row)?)),
@ -644,7 +721,10 @@ where
/// ),] /// ),]
/// ); /// );
/// ``` /// ```
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<S::Item>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<S::Item>, DbErr>
where
C: ConnectionTrait<'a>,
{
let rows = db.query_all(self.stmt).await?; let rows = db.query_all(self.stmt).await?;
let mut models = Vec::new(); let mut models = Vec::new();
for row in rows.into_iter() { for row in rows.into_iter() {

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Statement, UpdateMany, UpdateOne, error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne,
}; };
use sea_query::UpdateStatement; use sea_query::UpdateStatement;
use std::future::Future; use std::future::Future;
@ -7,9 +7,10 @@ use std::future::Future;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Updater { pub struct Updater {
query: UpdateStatement, query: UpdateStatement,
check_record_exists: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, PartialEq)]
pub struct UpdateResult { pub struct UpdateResult {
pub rows_affected: u64, pub rows_affected: u64,
} }
@ -18,9 +19,12 @@ impl<'a, A: 'a> UpdateOne<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub fn exec(self, db: &'a DatabaseConnection) -> impl Future<Output = Result<A, DbErr>> + 'a { pub async fn exec<'b, C>(self, db: &'b C) -> Result<A, DbErr>
where
C: ConnectionTrait<'b>,
{
// so that self is dropped before entering await // so that self is dropped before entering await
exec_update_and_return_original(self.query, self.model, db) exec_update_and_return_original(self.query, self.model, db).await
} }
} }
@ -28,10 +32,10 @@ impl<'a, E> UpdateMany<E>
where where
E: EntityTrait, E: EntityTrait,
{ {
pub fn exec( pub fn exec<C>(self, db: &'a C) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_
self, where
db: &'a DatabaseConnection, C: ConnectionTrait<'a>,
) -> impl Future<Output = Result<UpdateResult, DbErr>> + 'a { {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_update_only(self.query, db) exec_update_only(self.query, db)
} }
@ -39,41 +43,198 @@ where
impl Updater { impl Updater {
pub fn new(query: UpdateStatement) -> Self { pub fn new(query: UpdateStatement) -> Self {
Self { query } Self {
query,
check_record_exists: false,
}
} }
pub fn exec( pub fn check_record_exists(mut self) -> Self {
self, self.check_record_exists = true;
db: &DatabaseConnection, self
) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_ { }
pub fn exec<'a, C>(self, db: &'a C) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_
where
C: ConnectionTrait<'a>,
{
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_update(builder.build(&self.query), db) exec_update(builder.build(&self.query), db, self.check_record_exists)
} }
} }
async fn exec_update_only( async fn exec_update_only<'a, C>(query: UpdateStatement, db: &'a C) -> Result<UpdateResult, DbErr>
query: UpdateStatement, where
db: &DatabaseConnection, C: ConnectionTrait<'a>,
) -> Result<UpdateResult, DbErr> { {
Updater::new(query).exec(db).await Updater::new(query).exec(db).await
} }
async fn exec_update_and_return_original<A>( async fn exec_update_and_return_original<'a, A, C>(
query: UpdateStatement, query: UpdateStatement,
model: A, model: A,
db: &DatabaseConnection, db: &'a C,
) -> Result<A, DbErr> ) -> Result<A, DbErr>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
C: ConnectionTrait<'a>,
{ {
Updater::new(query).exec(db).await?; Updater::new(query).check_record_exists().exec(db).await?;
Ok(model) Ok(model)
} }
// Only Statement impl Send async fn exec_update<'a, C>(
async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result<UpdateResult, DbErr> { statement: Statement,
db: &'a C,
check_record_exists: bool,
) -> Result<UpdateResult, DbErr>
where
C: ConnectionTrait<'a>,
{
let result = db.execute(statement).await?; let result = db.execute(statement).await?;
if check_record_exists && result.rows_affected() == 0 {
return Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned(),
));
}
Ok(UpdateResult { Ok(UpdateResult {
rows_affected: result.rows_affected(), rows_affected: result.rows_affected(),
}) })
} }
#[cfg(test)]
mod tests {
use crate::{entity::prelude::*, tests_cfg::*, *};
use pretty_assertions::assert_eq;
use sea_query::Expr;
#[smol_potat::test]
async fn update_record_not_found_1() -> Result<(), DbErr> {
let db = MockDatabase::new(DbBackend::Postgres)
.append_exec_results(vec![
MockExecResult {
last_insert_id: 0,
rows_affected: 1,
},
MockExecResult {
last_insert_id: 0,
rows_affected: 0,
},
MockExecResult {
last_insert_id: 0,
rows_affected: 0,
},
MockExecResult {
last_insert_id: 0,
rows_affected: 0,
},
MockExecResult {
last_insert_id: 0,
rows_affected: 0,
},
])
.into_connection();
let model = cake::Model {
id: 1,
name: "New York Cheese".to_owned(),
};
assert_eq!(
cake::ActiveModel {
name: Set("Cheese Cake".to_owned()),
..model.into_active_model()
}
.update(&db)
.await?,
cake::Model {
id: 1,
name: "Cheese Cake".to_owned(),
}
.into_active_model()
);
let model = cake::Model {
id: 2,
name: "New York Cheese".to_owned(),
};
assert_eq!(
cake::ActiveModel {
name: Set("Cheese Cake".to_owned()),
..model.clone().into_active_model()
}
.update(&db)
.await,
Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned()
))
);
assert_eq!(
cake::Entity::update(cake::ActiveModel {
name: Set("Cheese Cake".to_owned()),
..model.clone().into_active_model()
})
.exec(&db)
.await,
Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned()
))
);
assert_eq!(
Update::one(cake::ActiveModel {
name: Set("Cheese Cake".to_owned()),
..model.into_active_model()
})
.exec(&db)
.await,
Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned()
))
);
assert_eq!(
Update::many(cake::Entity)
.col_expr(cake::Column::Name, Expr::value("Cheese Cake".to_owned()))
.filter(cake::Column::Id.eq(2))
.exec(&db)
.await,
Ok(UpdateResult { rows_affected: 0 })
);
assert_eq!(
db.into_transaction_log(),
vec![
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
vec!["Cheese Cake".into(), 1i32.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
vec!["Cheese Cake".into(), 2i32.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
vec!["Cheese Cake".into(), 2i32.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
vec!["Cheese Cake".into(), 2i32.into()]
),
Transaction::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#,
vec!["Cheese Cake".into(), 2i32.into()]
),
]
);
Ok(())
}
}

View File

@ -28,7 +28,7 @@
//! SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust. //! SeaORM is a relational ORM to help you build light weight and concurrent web services in Rust.
//! //!
//! [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index) //! [![Getting Started](https://img.shields.io/badge/Getting%20Started-brightgreen)](https://www.sea-ql.org/SeaORM/docs/index)
//! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/async-std) //! [![Usage Example](https://img.shields.io/badge/Usage%20Example-yellow)](https://github.com/SeaQL/sea-orm/tree/master/examples/basic)
//! [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example) //! [![Actix Example](https://img.shields.io/badge/Actix%20Example-blue)](https://github.com/SeaQL/sea-orm/tree/master/examples/actix_example)
//! [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example) //! [![Rocket Example](https://img.shields.io/badge/Rocket%20Example-orange)](https://github.com/SeaQL/sea-orm/tree/master/examples/rocket_example)
//! [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv) //! [![Discord](https://img.shields.io/discord/873880840487206962?label=Discord)](https://discord.com/invite/uCPdDXzbdv)

View File

@ -1,14 +1,18 @@
use crate::{ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, QueryTrait}; use crate::{
ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait,
QueryTrait,
};
use core::marker::PhantomData; use core::marker::PhantomData;
use sea_query::InsertStatement; use sea_query::{InsertStatement, ValueTuple};
#[derive(Clone, Debug)] #[derive(Debug)]
pub struct Insert<A> pub struct Insert<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub(crate) query: InsertStatement, pub(crate) query: InsertStatement,
pub(crate) columns: Vec<bool>, pub(crate) columns: Vec<bool>,
pub(crate) primary_key: Option<ValueTuple>,
pub(crate) model: PhantomData<A>, pub(crate) model: PhantomData<A>,
} }
@ -31,6 +35,7 @@ where
.into_table(A::Entity::default().table_ref()) .into_table(A::Entity::default().table_ref())
.to_owned(), .to_owned(),
columns: Vec::new(), columns: Vec::new(),
primary_key: None,
model: PhantomData, model: PhantomData,
} }
} }
@ -107,6 +112,12 @@ where
M: IntoActiveModel<A>, M: IntoActiveModel<A>,
{ {
let mut am: A = m.into_active_model(); let mut am: A = m.into_active_model();
self.primary_key =
if !<<A::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::auto_increment() {
am.get_primary_key_value()
} else {
None
};
let mut columns = Vec::new(); let mut columns = Vec::new();
let mut values = Vec::new(); let mut values = Vec::new();
let columns_empty = self.columns.is_empty(); let columns_empty = self.columns.is_empty();
@ -120,7 +131,7 @@ where
} }
if av_has_val { if av_has_val {
columns.push(col); columns.push(col);
values.push(av.into_value()); values.push(av.into_value().unwrap());
} }
} }
self.query.columns(columns); self.query.columns(columns);

View File

@ -8,6 +8,7 @@ mod json;
mod select; mod select;
mod traits; mod traits;
mod update; mod update;
mod util;
pub use combine::{SelectA, SelectB}; pub use combine::{SelectA, SelectB};
pub use delete::*; pub use delete::*;
@ -19,5 +20,6 @@ pub use json::*;
pub use select::*; pub use select::*;
pub use traits::*; pub use traits::*;
pub use update::*; pub use update::*;
pub use util::*;
pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; pub use crate::{ConnectionTrait, InsertResult, Statement, UpdateResult, Value, Values};

112
src/query/util.rs Normal file
View File

@ -0,0 +1,112 @@
use crate::{database::*, QueryTrait, Statement};
#[derive(Debug)]
pub struct DebugQuery<'a, Q, T> {
pub query: &'a Q,
pub value: T,
}
macro_rules! debug_query_build {
($impl_obj:ty, $db_expr:expr) => {
impl<'a, Q> DebugQuery<'a, Q, $impl_obj>
where
Q: QueryTrait,
{
pub fn build(&self) -> Statement {
let func = $db_expr;
let db_backend = func(self);
self.query.build(db_backend)
}
}
};
}
debug_query_build!(DbBackend, |x: &DebugQuery<_, DbBackend>| x.value);
debug_query_build!(&DbBackend, |x: &DebugQuery<_, &DbBackend>| *x.value);
debug_query_build!(DatabaseConnection, |x: &DebugQuery<
_,
DatabaseConnection,
>| x.value.get_database_backend());
debug_query_build!(&DatabaseConnection, |x: &DebugQuery<
_,
&DatabaseConnection,
>| x.value.get_database_backend());
/// Helper to get a `Statement` from an object that impl `QueryTrait`.
///
/// # Example
///
/// ```
/// # #[cfg(feature = "mock")]
/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend};
/// #
/// # let conn = MockDatabase::new(DbBackend::Postgres)
/// # .into_connection();
/// #
/// use sea_orm::{entity::*, query::*, tests_cfg::cake, debug_query_stmt};
///
/// let c = cake::Entity::insert(
/// cake::ActiveModel {
/// id: ActiveValue::set(1),
/// name: ActiveValue::set("Apple Pie".to_owned()),
/// });
///
/// let raw_sql = debug_query_stmt!(&c, &conn).to_string();
/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#);
///
/// let raw_sql = debug_query_stmt!(&c, conn).to_string();
/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#);
///
/// let raw_sql = debug_query_stmt!(&c, DbBackend::MySql).to_string();
/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#);
///
/// let raw_sql = debug_query_stmt!(&c, &DbBackend::MySql).to_string();
/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#);
///
/// ```
#[macro_export]
macro_rules! debug_query_stmt {
($query:expr,$value:expr) => {
$crate::DebugQuery {
query: $query,
value: $value,
}
.build();
};
}
/// Helper to get a raw SQL string from an object that impl `QueryTrait`.
///
/// # Example
///
/// ```
/// # #[cfg(feature = "mock")]
/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, DbBackend};
/// #
/// # let conn = MockDatabase::new(DbBackend::Postgres)
/// # .into_connection();
/// #
/// use sea_orm::{entity::*, query::*, tests_cfg::cake,debug_query};
///
/// let c = cake::Entity::insert(
/// cake::ActiveModel {
/// id: ActiveValue::set(1),
/// name: ActiveValue::set("Apple Pie".to_owned()),
/// });
///
/// let raw_sql = debug_query!(&c, &conn);
/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#);
///
/// let raw_sql = debug_query!(&c, conn);
/// assert_eq!(raw_sql, r#"INSERT INTO "cake" ("id", "name") VALUES (1, 'Apple Pie')"#);
///
/// let raw_sql = debug_query!(&c, DbBackend::Sqlite);
/// assert_eq!(raw_sql, r#"INSERT INTO `cake` (`id`, `name`) VALUES (1, 'Apple Pie')"#);
///
/// ```
#[macro_export]
macro_rules! debug_query {
($query:expr,$value:expr) => {
$crate::debug_query_stmt!($query, $value).to_string();
};
}

View File

@ -7,6 +7,7 @@ pub mod cake_filling_price;
pub mod entity_linked; pub mod entity_linked;
pub mod filling; pub mod filling;
pub mod fruit; pub mod fruit;
pub mod rust_keyword;
pub mod vendor; pub mod vendor;
pub use cake::Entity as Cake; pub use cake::Entity as Cake;
@ -15,4 +16,5 @@ pub use cake_filling::Entity as CakeFilling;
pub use cake_filling_price::Entity as CakeFillingPrice; pub use cake_filling_price::Entity as CakeFillingPrice;
pub use filling::Entity as Filling; pub use filling::Entity as Filling;
pub use fruit::Entity as Fruit; pub use fruit::Entity as Fruit;
pub use rust_keyword::Entity as RustKeyword;
pub use vendor::Entity as Vendor; pub use vendor::Entity as Vendor;

View File

@ -0,0 +1,141 @@
use crate as sea_orm;
use crate::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "rust_keyword")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub testing: i32,
pub rust: i32,
pub keywords: i32,
pub r#raw_identifier: i32,
pub r#as: i32,
pub r#async: i32,
pub r#await: i32,
pub r#break: i32,
pub r#const: i32,
pub r#continue: i32,
pub crate_: i32,
pub r#dyn: i32,
pub r#else: i32,
pub r#enum: i32,
pub r#extern: i32,
pub r#false: i32,
pub r#fn: i32,
pub r#for: i32,
pub r#if: i32,
pub r#impl: i32,
pub r#in: i32,
pub r#let: i32,
pub r#loop: i32,
pub r#match: i32,
pub r#mod: i32,
pub r#move: i32,
pub r#mut: i32,
pub r#pub: i32,
pub r#ref: i32,
pub r#return: i32,
pub self_: i32,
pub r#static: i32,
pub r#struct: i32,
pub r#trait: i32,
pub r#true: i32,
pub r#type: i32,
pub r#union: i32,
pub r#unsafe: i32,
pub r#use: i32,
pub r#where: i32,
pub r#while: i32,
pub r#abstract: i32,
pub r#become: i32,
pub r#box: i32,
pub r#do: i32,
pub r#final: i32,
pub r#macro: i32,
pub r#override: i32,
pub r#priv: i32,
pub r#try: i32,
pub r#typeof: i32,
pub r#unsized: i32,
pub r#virtual: i32,
pub r#yield: i32,
}
#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
_ => panic!("No RelationDef"),
}
}
}
impl ActiveModelBehavior for ActiveModel {}
#[cfg(test)]
mod tests {
use crate::tests_cfg::rust_keyword::*;
use sea_query::Iden;
#[test]
fn test_columns() {
assert_eq!(Column::Id.to_string().as_str(), "id");
assert_eq!(Column::Testing.to_string().as_str(), "testing");
assert_eq!(Column::Rust.to_string().as_str(), "rust");
assert_eq!(Column::Keywords.to_string().as_str(), "keywords");
assert_eq!(Column::RawIdentifier.to_string().as_str(), "raw_identifier");
assert_eq!(Column::As.to_string().as_str(), "as");
assert_eq!(Column::Async.to_string().as_str(), "async");
assert_eq!(Column::Await.to_string().as_str(), "await");
assert_eq!(Column::Break.to_string().as_str(), "break");
assert_eq!(Column::Const.to_string().as_str(), "const");
assert_eq!(Column::Continue.to_string().as_str(), "continue");
assert_eq!(Column::Dyn.to_string().as_str(), "dyn");
assert_eq!(Column::Crate.to_string().as_str(), "crate");
assert_eq!(Column::Else.to_string().as_str(), "else");
assert_eq!(Column::Enum.to_string().as_str(), "enum");
assert_eq!(Column::Extern.to_string().as_str(), "extern");
assert_eq!(Column::False.to_string().as_str(), "false");
assert_eq!(Column::Fn.to_string().as_str(), "fn");
assert_eq!(Column::For.to_string().as_str(), "for");
assert_eq!(Column::If.to_string().as_str(), "if");
assert_eq!(Column::Impl.to_string().as_str(), "impl");
assert_eq!(Column::In.to_string().as_str(), "in");
assert_eq!(Column::Let.to_string().as_str(), "let");
assert_eq!(Column::Loop.to_string().as_str(), "loop");
assert_eq!(Column::Match.to_string().as_str(), "match");
assert_eq!(Column::Mod.to_string().as_str(), "mod");
assert_eq!(Column::Move.to_string().as_str(), "move");
assert_eq!(Column::Mut.to_string().as_str(), "mut");
assert_eq!(Column::Pub.to_string().as_str(), "pub");
assert_eq!(Column::Ref.to_string().as_str(), "ref");
assert_eq!(Column::Return.to_string().as_str(), "return");
assert_eq!(Column::Self_.to_string().as_str(), "self");
assert_eq!(Column::Static.to_string().as_str(), "static");
assert_eq!(Column::Struct.to_string().as_str(), "struct");
assert_eq!(Column::Trait.to_string().as_str(), "trait");
assert_eq!(Column::True.to_string().as_str(), "true");
assert_eq!(Column::Type.to_string().as_str(), "type");
assert_eq!(Column::Union.to_string().as_str(), "union");
assert_eq!(Column::Unsafe.to_string().as_str(), "unsafe");
assert_eq!(Column::Use.to_string().as_str(), "use");
assert_eq!(Column::Where.to_string().as_str(), "where");
assert_eq!(Column::While.to_string().as_str(), "while");
assert_eq!(Column::Abstract.to_string().as_str(), "abstract");
assert_eq!(Column::Become.to_string().as_str(), "become");
assert_eq!(Column::Box.to_string().as_str(), "box");
assert_eq!(Column::Do.to_string().as_str(), "do");
assert_eq!(Column::Final.to_string().as_str(), "final");
assert_eq!(Column::Macro.to_string().as_str(), "macro");
assert_eq!(Column::Override.to_string().as_str(), "override");
assert_eq!(Column::Priv.to_string().as_str(), "priv");
assert_eq!(Column::Try.to_string().as_str(), "try");
assert_eq!(Column::Typeof.to_string().as_str(), "typeof");
assert_eq!(Column::Unsized.to_string().as_str(), "unsized");
assert_eq!(Column::Virtual.to_string().as_str(), "virtual");
assert_eq!(Column::Yield.to_string().as_str(), "yield");
}
}

View File

@ -1,6 +1,6 @@
pub mod common; pub mod common;
pub use sea_orm::{entity::*, error::*, sea_query, tests_cfg::*, Database, DbConn}; pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Database, DbConn};
// cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic // cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic
#[sea_orm_macros::test] #[sea_orm_macros::test]

View File

@ -10,8 +10,8 @@ pub struct Model {
pub key: String, pub key: String,
pub value: String, pub value: String,
pub bytes: Vec<u8>, pub bytes: Vec<u8>,
pub date: Date, pub date: Option<Date>,
pub time: Time, pub time: Option<Time>,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@ -1,4 +1,4 @@
use sea_orm::{Database, DatabaseBackend, DatabaseConnection, Statement}; use sea_orm::{ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, Statement};
pub mod schema; pub mod schema;
pub use schema::*; pub use schema::*;

View File

@ -1,6 +1,8 @@
pub use super::super::bakery_chain::*; pub use super::super::bakery_chain::*;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use sea_orm::{error::*, sea_query, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; use sea_orm::{
error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema,
};
use sea_query::{ use sea_query::{
Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement, Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement,
}; };
@ -287,8 +289,8 @@ pub async fn create_metadata_table(db: &DbConn) -> Result<ExecResult, DbErr> {
.col(ColumnDef::new(metadata::Column::Key).string().not_null()) .col(ColumnDef::new(metadata::Column::Key).string().not_null())
.col(ColumnDef::new(metadata::Column::Value).string().not_null()) .col(ColumnDef::new(metadata::Column::Value).string().not_null())
.col(ColumnDef::new(metadata::Column::Bytes).binary().not_null()) .col(ColumnDef::new(metadata::Column::Bytes).binary().not_null())
.col(ColumnDef::new(metadata::Column::Date).date().not_null()) .col(ColumnDef::new(metadata::Column::Date).date())
.col(ColumnDef::new(metadata::Column::Time).time().not_null()) .col(ColumnDef::new(metadata::Column::Time).time())
.to_owned(); .to_owned();
create_table(db, &stmt, Metadata).await create_table(db, &stmt, Metadata).await

View File

@ -58,11 +58,7 @@ pub async fn test_create_cake(db: &DbConn) {
.expect("could not insert cake_baker"); .expect("could not insert cake_baker");
assert_eq!( assert_eq!(
cake_baker_res.last_insert_id, cake_baker_res.last_insert_id,
if cfg!(feature = "sqlx-postgres") { (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
} else {
Default::default()
}
); );
assert!(cake.is_some()); assert!(cake.is_some());

View File

@ -57,11 +57,7 @@ pub async fn test_create_lineitem(db: &DbConn) {
.expect("could not insert cake_baker"); .expect("could not insert cake_baker");
assert_eq!( assert_eq!(
cake_baker_res.last_insert_id, cake_baker_res.last_insert_id,
if cfg!(feature = "sqlx-postgres") { (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
} else {
Default::default()
}
); );
// Customer // Customer

View File

@ -57,11 +57,7 @@ pub async fn test_create_order(db: &DbConn) {
.expect("could not insert cake_baker"); .expect("could not insert cake_baker");
assert_eq!( assert_eq!(
cake_baker_res.last_insert_id, cake_baker_res.last_insert_id,
if cfg!(feature = "sqlx-postgres") { (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
} else {
Default::default()
}
); );
// Customer // Customer

View File

@ -1,5 +1,6 @@
pub use super::*; pub use super::*;
use rust_decimal_macros::dec; use rust_decimal_macros::dec;
use sea_orm::DbErr;
use uuid::Uuid; use uuid::Uuid;
pub async fn test_update_cake(db: &DbConn) { pub async fn test_update_cake(db: &DbConn) {
@ -119,10 +120,14 @@ pub async fn test_update_deleted_customer(db: &DbConn) {
..Default::default() ..Default::default()
}; };
let _customer_update_res: customer::ActiveModel = customer let customer_update_res = customer.update(db).await;
.update(db)
.await assert_eq!(
.expect("could not update customer"); customer_update_res,
Err(DbErr::RecordNotFound(
"None of the database rows are affected".to_owned()
))
);
assert_eq!(Customer::find().count(db).await.unwrap(), init_n_customers); assert_eq!(Customer::find().count(db).await.unwrap(), init_n_customers);

View File

@ -26,8 +26,8 @@ pub async fn crud_in_parallel(db: &DatabaseConnection) -> Result<(), DbErr> {
key: "markup".to_owned(), key: "markup".to_owned(),
value: "1.18".to_owned(), value: "1.18".to_owned(),
bytes: vec![1, 2, 3], bytes: vec![1, 2, 3],
date: Date::from_ymd(2021, 9, 27), date: Some(Date::from_ymd(2021, 9, 27)),
time: Time::from_hms(11, 32, 55), time: Some(Time::from_hms(11, 32, 55)),
}, },
metadata::Model { metadata::Model {
uuid: Uuid::new_v4(), uuid: Uuid::new_v4(),
@ -35,8 +35,8 @@ pub async fn crud_in_parallel(db: &DatabaseConnection) -> Result<(), DbErr> {
key: "exchange_rate".to_owned(), key: "exchange_rate".to_owned(),
value: "0.78".to_owned(), value: "0.78".to_owned(),
bytes: vec![1, 2, 3], bytes: vec![1, 2, 3],
date: Date::from_ymd(2021, 9, 27), date: Some(Date::from_ymd(2021, 9, 27)),
time: Time::from_hms(11, 32, 55), time: Some(Time::from_hms(11, 32, 55)),
}, },
metadata::Model { metadata::Model {
uuid: Uuid::new_v4(), uuid: Uuid::new_v4(),
@ -44,8 +44,8 @@ pub async fn crud_in_parallel(db: &DatabaseConnection) -> Result<(), DbErr> {
key: "service_charge".to_owned(), key: "service_charge".to_owned(),
value: "1.1".to_owned(), value: "1.1".to_owned(),
bytes: vec![1, 2, 3], bytes: vec![1, 2, 3],
date: Date::from_ymd(2021, 9, 27), date: None,
time: Time::from_hms(11, 32, 55), time: None,
}, },
]; ];

View File

@ -2,7 +2,7 @@ pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext}; pub use common::{bakery_chain::*, setup::*, TestContext};
pub use sea_orm::entity::*; pub use sea_orm::entity::*;
pub use sea_orm::QueryFilter; pub use sea_orm::{ConnectionTrait, QueryFilter};
// Run the test locally: // Run the test locally:
// DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests // DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests

View File

@ -84,11 +84,7 @@ async fn init_setup(db: &DatabaseConnection) {
.expect("could not insert cake_baker"); .expect("could not insert cake_baker");
assert_eq!( assert_eq!(
cake_baker_res.last_insert_id, cake_baker_res.last_insert_id,
if cfg!(feature = "sqlx-postgres") { (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
} else {
Default::default()
}
); );
let customer_kate = customer::ActiveModel { let customer_kate = customer::ActiveModel {
@ -179,7 +175,7 @@ async fn find_baker_least_sales(db: &DatabaseConnection) -> Option<baker::Model>
let mut results: Vec<LeastSalesBakerResult> = select let mut results: Vec<LeastSalesBakerResult> = select
.into_model::<SelectResult>() .into_model::<SelectResult>()
.all(&db) .all(db)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
@ -225,11 +221,7 @@ async fn create_cake(db: &DatabaseConnection, baker: baker::Model) -> Option<cak
.expect("could not insert cake_baker"); .expect("could not insert cake_baker");
assert_eq!( assert_eq!(
cake_baker_res.last_insert_id, cake_baker_res.last_insert_id,
if cfg!(feature = "sqlx-postgres") { (cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
} else {
Default::default()
}
); );
Cake::find_by_id(cake_insert_res.last_insert_id) Cake::find_by_id(cake_insert_res.last_insert_id)

38
tests/stream_tests.rs Normal file
View File

@ -0,0 +1,38 @@
pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext};
pub use sea_orm::entity::*;
pub use sea_orm::{ConnectionTrait, DbErr, QueryFilter};
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn stream() -> Result<(), DbErr> {
use futures::StreamExt;
let ctx = TestContext::new("stream").await;
let bakery = bakery::ActiveModel {
name: Set("SeaSide Bakery".to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(&ctx.db)
.await?;
let result = Bakery::find_by_id(bakery.id.clone().unwrap())
.stream(&ctx.db)
.await?
.next()
.await
.unwrap()?;
assert_eq!(result.id, bakery.id.unwrap());
ctx.delete().await;
Ok(())
}

348
tests/transaction_tests.rs Normal file
View File

@ -0,0 +1,348 @@
pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext};
pub use sea_orm::entity::*;
pub use sea_orm::{ConnectionTrait, QueryFilter};
use sea_orm::{DatabaseTransaction, DbErr};
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction() {
let ctx = TestContext::new("transaction_test").await;
ctx.db
.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("SeaSide Bakery".to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(txn)
.await?;
let _ = bakery::ActiveModel {
name: Set("Top Bakery".to_owned()),
profit_margin: Set(15.0),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 2);
Ok(())
})
})
.await
.unwrap();
ctx.delete().await;
}
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction_with_reference() {
let ctx = TestContext::new("transaction_with_reference_test").await;
let name1 = "SeaSide Bakery";
let name2 = "Top Bakery";
let search_name = "Bakery";
ctx.db
.transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name))
.await
.unwrap();
ctx.delete().await;
}
fn _transaction_with_reference<'a>(
txn: &'a DatabaseTransaction,
name1: &'a str,
name2: &'a str,
search_name: &'a str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), DbErr>> + Send + 'a>> {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set(name1.to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(txn)
.await?;
let _ = bakery::ActiveModel {
name: Set(name2.to_owned()),
profit_margin: Set(15.0),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains(search_name))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 2);
Ok(())
})
}
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction_nested() {
let ctx = TestContext::new("transaction_nested_test").await;
ctx.db
.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("SeaSide Bakery".to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(txn)
.await?;
let _ = bakery::ActiveModel {
name: Set("Top Bakery".to_owned()),
profit_margin: Set(15.0),
..Default::default()
}
.save(txn)
.await?;
// Try nested transaction committed
txn.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Nested Bakery".to_owned()),
profit_margin: Set(88.88),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 3);
// Try nested-nested transaction rollbacked
let is_err = txn
.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Rock n Roll Bakery".to_owned()),
profit_margin: Set(28.8),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 4);
if true {
Err(DbErr::Query("Force Rollback!".to_owned()))
} else {
Ok(())
}
})
})
.await
.is_err();
assert!(is_err);
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 3);
// Try nested-nested transaction committed
txn.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Rock n Roll Bakery".to_owned()),
profit_margin: Set(28.8),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 4);
Ok(())
})
})
.await
.unwrap();
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 4);
Ok(())
})
})
.await
.unwrap();
// Try nested transaction rollbacked
let is_err = txn
.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Rock n Roll Bakery".to_owned()),
profit_margin: Set(28.8),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 5);
// Try nested-nested transaction committed
txn.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Rock n Roll Bakery".to_owned()),
profit_margin: Set(28.8),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 6);
Ok(())
})
})
.await
.unwrap();
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 6);
// Try nested-nested transaction rollbacked
let is_err = txn
.transaction::<_, _, DbErr>(|txn| {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("Rock n Roll Bakery".to_owned()),
profit_margin: Set(28.8),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 7);
if true {
Err(DbErr::Query("Force Rollback!".to_owned()))
} else {
Ok(())
}
})
})
.await
.is_err();
assert!(is_err);
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 6);
if true {
Err(DbErr::Query("Force Rollback!".to_owned()))
} else {
Ok(())
}
})
})
.await
.is_err();
assert!(is_err);
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 4);
Ok(())
})
})
.await
.unwrap();
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(&ctx.db)
.await
.unwrap();
assert_eq!(bakeries.len(), 4);
ctx.delete().await;
}

View File

@ -1,7 +1,7 @@
pub mod common; pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext}; pub use common::{bakery_chain::*, setup::*, TestContext};
use sea_orm::{entity::prelude::*, DatabaseConnection, IntoActiveModel}; use sea_orm::{entity::prelude::*, DatabaseConnection, IntoActiveModel, Set};
#[sea_orm_macros::test] #[sea_orm_macros::test]
#[cfg(any( #[cfg(any(
@ -24,8 +24,8 @@ pub async fn create_metadata(db: &DatabaseConnection) -> Result<(), DbErr> {
key: "markup".to_owned(), key: "markup".to_owned(),
value: "1.18".to_owned(), value: "1.18".to_owned(),
bytes: vec![1, 2, 3], bytes: vec![1, 2, 3],
date: Date::from_ymd(2021, 9, 27), date: Some(Date::from_ymd(2021, 9, 27)),
time: Time::from_hms(11, 32, 55), time: Some(Time::from_hms(11, 32, 55)),
}; };
let res = Metadata::insert(metadata.clone().into_active_model()) let res = Metadata::insert(metadata.clone().into_active_model())
@ -34,13 +34,21 @@ pub async fn create_metadata(db: &DatabaseConnection) -> Result<(), DbErr> {
assert_eq!(Metadata::find().one(db).await?, Some(metadata.clone())); assert_eq!(Metadata::find().one(db).await?, Some(metadata.clone()));
assert_eq!(res.last_insert_id, metadata.uuid);
let update_res = Metadata::update(metadata::ActiveModel {
value: Set("0.22".to_owned()),
..metadata.clone().into_active_model()
})
.filter(metadata::Column::Uuid.eq(Uuid::default()))
.exec(db)
.await;
assert_eq!( assert_eq!(
res.last_insert_id, update_res,
if cfg!(feature = "sqlx-postgres") { Err(DbErr::RecordNotFound(
metadata.uuid "None of the database rows are affected".to_owned()
} else { ))
Default::default()
}
); );
Ok(()) Ok(())