From 87d4f09931df6210e8745a2235ac1cc86303c053 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 25 Aug 2021 23:05:30 +0800 Subject: [PATCH] Add rocket-mysql driver --- Cargo.toml | 3 ++ examples/rocket_example/Cargo.toml | 19 ++------- examples/rocket_example/src/sqlx/setup.rs | 48 +++++++++++------------ src/driver/mod.rs | 5 +++ src/driver/rocket_mysql.rs | 39 ++++++++++++++++++ src/entity/column.rs | 2 +- 6 files changed, 75 insertions(+), 41 deletions(-) create mode 100644 src/driver/rocket_mysql.rs diff --git a/Cargo.toml b/Cargo.toml index ca61c819..bc90db55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,8 @@ sqlx-core = { version = "^0.5", optional = true } sqlx-macros = { version = "^0.5", optional = true } serde_json = { version = "^1", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["json"], optional = true } +rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["sqlx_mysql"], optional = true } [dev-dependencies] smol = { version = "^1.2" } @@ -91,3 +93,4 @@ runtime-actix-rustls = ["sqlx/runtime-actix-rustls", "runtime-actix"] runtime-tokio = [] runtime-tokio-native-tls = ["sqlx/runtime-tokio-native-tls", "runtime-tokio"] runtime-tokio-rustls = ["sqlx/runtime-tokio-rustls", "runtime-tokio"] +rocket-mysql = ["rocket", "rocket_db_pools"] diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index d7c05921..feb77d43 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -5,12 +5,9 @@ edition = "2018" publish = false [workspace] [dependencies] -rocket = { path = "../../../Rocket/core/lib", features = ["json"] } -# rocket = { git = "https://github.com/SergioBenitez/Rocket.git", branch = "master", features = [ -# "json", -# ] } -# async-std = { version = "^1.9", features = ["attributes"] } -sea-orm = { path = "../../", features = ["sqlx-all"] } +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["json"] } +rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["sqlx_mysql"] } +sea-orm = { path = "../../", features = ["sqlx-all", "rocket-mysql"] } sea-query = { version = "^0.12.8" } serde_json = { version = "^1" } @@ -22,13 +19,3 @@ futures-util = { version = "^0.3" } version = "0.5.1" default-features = false features = ["macros", "offline", "migrate"] - -# [dependencies.rocket_db_pools] -# git = "https://github.com/SergioBenitez/Rocket" -# branch = "master" -# features = ["sea_orm"] -[dependencies.rocket_db_pools] -path = "../../../Rocket/contrib/db_pools/lib" -# git = "https://github.com/samsamai/Rocket.git" -# branch = "ss/seaorm-contrib" -features = ["seaorm_sqlx_sqlite"] diff --git a/examples/rocket_example/src/sqlx/setup.rs b/examples/rocket_example/src/sqlx/setup.rs index b7ba2df3..fac405be 100644 --- a/examples/rocket_example/src/sqlx/setup.rs +++ b/examples/rocket_example/src/sqlx/setup.rs @@ -6,32 +6,32 @@ use sea_orm::sea_query::{ColumnDef, ForeignKey, ForeignKeyAction, Index, TableCr pub use super::post::*; async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { - let builder = db.get_database_backend(); - db.execute(builder.build(stmt)).await + let builder = db.get_database_backend(); + db.execute(builder.build(stmt)).await } pub async fn create_post_table(db: &DbConn) -> Result { - let stmt = sea_query::Table::create() - .table(super::post::Entity) - .if_not_exists() - .col( - ColumnDef::new(super::post::Column::Id) - .integer() - .not_null() - .auto_increment() - .primary_key(), - ) - .col( - ColumnDef::new(super::post::Column::Title) - .string() - .not_null(), - ) - .col( - ColumnDef::new(super::post::Column::Text) - .string() - .not_null(), - ) - .to_owned(); + let stmt = sea_query::Table::create() + .table(super::post::Entity) + .if_not_exists() + .col( + ColumnDef::new(super::post::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(super::post::Column::Title) + .string() + .not_null(), + ) + .col( + ColumnDef::new(super::post::Column::Text) + .string() + .not_null(), + ) + .to_owned(); - create_table(db, &stmt).await + create_table(db, &stmt).await } diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 6f6cfb64..5786d299 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -9,8 +9,13 @@ mod sqlx_postgres; #[cfg(feature = "sqlx-sqlite")] mod sqlx_sqlite; +#[cfg(feature = "rocket-mysql")] +mod rocket_mysql; + #[cfg(feature = "mock")] pub use mock::*; +#[cfg(feature = "rocket-mysql")] +pub use rocket_mysql::*; #[cfg(feature = "sqlx-dep")] pub use sqlx_common::*; #[cfg(feature = "sqlx-mysql")] diff --git a/src/driver/rocket_mysql.rs b/src/driver/rocket_mysql.rs new file mode 100644 index 00000000..ef90f1ee --- /dev/null +++ b/src/driver/rocket_mysql.rs @@ -0,0 +1,39 @@ +use rocket::figment::Figment; +use rocket_db_pools::{Config, Error}; + +#[rocket::async_trait] +impl rocket_db_pools::Pool for crate::Database { + type Error = crate::DbErr; + + type Connection = crate::DatabaseConnection; + + async fn init(figment: &Figment) -> Result { + // let config = figment.extract::()?; + // let mut opts = config.url.parse::>().map_err(Error::Init)?; + // opts.disable_statement_logging(); + // specialize(&mut opts, &config); + + // sqlx::pool::PoolOptions::new() + // .max_connections(config.max_connections as u32) + // .connect_timeout(Duration::from_secs(config.connect_timeout)) + // .idle_timeout(config.idle_timeout.map(Duration::from_secs)) + // .min_connections(config.min_connections.unwrap_or_default()) + // .connect_with(opts) + // .await + // .map_err(Error::Init) + Ok(crate::Database {}) + } + + async fn get(&self) -> Result { + // self.acquire().await.map_err(Error::Get) + // let con = crate::Database::connect("sqlite::memory:").await; + + // Ok(crate::Database::connect("sqlite::memory:").await.unwrap()) + // "mysql://root:@localhost" + Ok( + crate::Database::connect("mysql://root:@localhost/rocket_example") + .await + .unwrap(), + ) + } +} diff --git a/src/entity/column.rs b/src/entity/column.rs index 16546057..611950f5 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -1,6 +1,6 @@ -use std::str::FromStr; use crate::{EntityName, IdenStatic, Iterable}; use sea_query::{DynIden, Expr, SeaRc, SelectStatement, SimpleExpr, Value}; +use std::str::FromStr; #[derive(Debug, Clone)] pub struct ColumnDef {