diff --git a/examples/rocket_example/Cargo.toml b/examples/rocket_example/Cargo.toml index 22dc9a6b..c896fe2d 100644 --- a/examples/rocket_example/Cargo.toml +++ b/examples/rocket_example/Cargo.toml @@ -15,15 +15,21 @@ futures-util = { version = "^0.3" } rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "json", ] } -rocket_db_pools = { git = "https://github.com/SergioBenitez/Rocket.git" } rocket_dyn_templates = { git = "https://github.com/SergioBenitez/Rocket.git", features = [ "tera", ] } -# remove `path = ""` in your own project -sea-orm = { path = "../../", version = "^0.2.3", features = ["macros"], default-features = false } serde_json = { version = "^1" } +[dependencies.sea-orm] +path = "../../" # remove this line in your own project +version = "^0.2.3" +features = ["macros", "runtime-tokio-native-tls"] +default-features = false + +[dependencies.sea-orm-rocket] +path = "../../sea-orm-rocket/lib" + [features] default = ["sqlx-postgres"] -sqlx-mysql = ["sea-orm/sqlx-mysql", "rocket_db_pools/sqlx_mysql"] -sqlx-postgres = ["sea-orm/sqlx-postgres", "rocket_db_pools/sqlx_postgres"] +sqlx-mysql = ["sea-orm/sqlx-mysql"] +sqlx-postgres = ["sea-orm/sqlx-postgres"] diff --git a/examples/rocket_example/Rocket.toml b/examples/rocket_example/Rocket.toml index b7fcc12a..fc294bd2 100644 --- a/examples/rocket_example/Rocket.toml +++ b/examples/rocket_example/Rocket.toml @@ -1,7 +1,7 @@ [default] template_dir = "templates/" -[default.databases.rocket_example] +[default.databases.sea_orm] # Mysql # make sure to enable "sqlx-mysql" feature in Cargo.toml, i.e default = ["sqlx-mysql"] # url = "mysql://root:@localhost/rocket_example" diff --git a/examples/rocket_example/src/main.rs b/examples/rocket_example/src/main.rs index 43f227c6..e2ef9254 100644 --- a/examples/rocket_example/src/main.rs +++ b/examples/rocket_example/src/main.rs @@ -7,21 +7,17 @@ use rocket::fs::{relative, FileServer}; use rocket::request::FlashMessage; use rocket::response::{Flash, Redirect}; use rocket::{Build, Request, Rocket}; -use rocket_db_pools::{sqlx, Connection, Database}; use rocket_dyn_templates::{context, Template}; use sea_orm::{entity::*, query::*}; +use sea_orm_rocket::{Connection, Database}; mod pool; -use pool::RocketDbPool; +use pool::Db; mod setup; -#[derive(Database, Debug)] -#[database("rocket_example")] -struct Db(RocketDbPool); - -type Result> = std::result::Result; +type Result> = std::result::Result; mod post; pub use post::Entity as Post; @@ -34,7 +30,9 @@ async fn new() -> Template { } #[post("/", data = "")] -async fn create(conn: Connection, post_form: Form) -> Flash { +async fn create(conn: Connection<'_, Db>, post_form: Form) -> Flash { + let db = conn.into_inner(); + let form = post_form.into_inner(); post::ActiveModel { @@ -42,7 +40,7 @@ async fn create(conn: Connection, post_form: Form) -> Flash, post_form: Form) -> Flash", data = "")] -async fn update(conn: Connection, id: i32, post_form: Form) -> Flash { +async fn update(conn: Connection<'_, Db>, id: i32, post_form: Form) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .unwrap() .unwrap() @@ -65,7 +65,7 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> title: Set(form.title.to_owned()), text: Set(form.text.to_owned()), } - .save(&*conn) + .save(db) .await .expect("could not edit post"); @@ -74,11 +74,13 @@ async fn update(conn: Connection, id: i32, post_form: Form) -> #[get("/?&")] async fn list( - conn: Connection, + conn: Connection<'_, Db>, posts_per_page: Option, page: Option, flash: Option>, ) -> Template { + let db = conn.into_inner(); + // Set page number and items per page let page = page.unwrap_or(1); let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE); @@ -89,7 +91,7 @@ async fn list( // Setup paginator let paginator = Post::find() .order_by_asc(post::Column::Id) - .paginate(&*conn, posts_per_page); + .paginate(db, posts_per_page); let num_pages = paginator.num_pages().await.ok().unwrap(); // Fetch paginated posts @@ -111,9 +113,11 @@ async fn list( } #[get("/")] -async fn edit(conn: Connection, id: i32) -> Template { +async fn edit(conn: Connection<'_, Db>, id: i32) -> Template { + let db = conn.into_inner(); + let post: Option = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .expect("could not find post"); @@ -126,22 +130,26 @@ async fn edit(conn: Connection, id: i32) -> Template { } #[delete("/")] -async fn delete(conn: Connection, id: i32) -> Flash { +async fn delete(conn: Connection<'_, Db>, id: i32) -> Flash { + let db = conn.into_inner(); + let post: post::ActiveModel = Post::find_by_id(id) - .one(&*conn) + .one(db) .await .unwrap() .unwrap() .into(); - post.delete(&*conn).await.unwrap(); + post.delete(db).await.unwrap(); Flash::success(Redirect::to("/"), "Post successfully deleted.") } #[delete("/")] -async fn destroy(conn: Connection) -> Result<()> { - Post::delete_many().exec(&*conn).await.unwrap(); +async fn destroy(conn: Connection<'_, Db>) -> Result<()> { + let db = conn.into_inner(); + + Post::delete_many().exec(db).await.unwrap(); Ok(()) } diff --git a/examples/rocket_example/src/pool.rs b/examples/rocket_example/src/pool.rs index c4140c1f..931a4712 100644 --- a/examples/rocket_example/src/pool.rs +++ b/examples/rocket_example/src/pool.rs @@ -1,13 +1,17 @@ use async_trait::async_trait; -use rocket_db_pools::{rocket::figment::Figment, Config}; +use sea_orm_rocket::{rocket::figment::Figment, Config, Database}; + +#[derive(Database, Debug)] +#[database("sea_orm")] +pub struct Db(SeaOrmPool); #[derive(Debug)] -pub struct RocketDbPool { +pub struct SeaOrmPool { pub conn: sea_orm::DatabaseConnection, } #[async_trait] -impl rocket_db_pools::Pool for RocketDbPool { +impl sea_orm_rocket::Pool for SeaOrmPool { type Error = sea_orm::DbErr; type Connection = sea_orm::DatabaseConnection; @@ -16,10 +20,10 @@ impl rocket_db_pools::Pool for RocketDbPool { let config = figment.extract::().unwrap(); let conn = sea_orm::Database::connect(&config.url).await.unwrap(); - Ok(RocketDbPool { conn }) + Ok(SeaOrmPool { conn }) } - async fn get(&self) -> Result { - Ok(self.conn.clone()) + fn borrow(&self) -> &Self::Connection { + &self.conn } } diff --git a/sea-orm-rocket/Cargo.toml b/sea-orm-rocket/Cargo.toml new file mode 100644 index 00000000..4975f8e1 --- /dev/null +++ b/sea-orm-rocket/Cargo.toml @@ -0,0 +1,2 @@ +[workspace] +members = ["codegen", "lib"] \ No newline at end of file diff --git a/sea-orm-rocket/README.md b/sea-orm-rocket/README.md new file mode 100644 index 00000000..f8d94b21 --- /dev/null +++ b/sea-orm-rocket/README.md @@ -0,0 +1 @@ +# SeaORM Rocket support crate. \ No newline at end of file diff --git a/sea-orm-rocket/codegen/Cargo.toml b/sea-orm-rocket/codegen/Cargo.toml new file mode 100644 index 00000000..75656487 --- /dev/null +++ b/sea-orm-rocket/codegen/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "sea-orm-rocket-codegen" +version = "0.1.0-rc" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "Procedural macros for sea_orm_rocket." +repository = "https://github.com/SergioBenitez/Rocket/contrib/db_pools" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +devise = "0.3" +quote = "1" + +[dev-dependencies] +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", default-features = false } +trybuild = "1.0" +version_check = "0.9" diff --git a/sea-orm-rocket/codegen/src/database.rs b/sea-orm-rocket/codegen/src/database.rs new file mode 100644 index 00000000..a6f5a981 --- /dev/null +++ b/sea-orm-rocket/codegen/src/database.rs @@ -0,0 +1,110 @@ +use proc_macro::TokenStream; + +use devise::{DeriveGenerator, FromMeta, MapperBuild, Support, ValidatorBuild}; +use devise::proc_macro2_diagnostics::SpanDiagnosticExt; +use devise::syn::{self, spanned::Spanned}; + +const ONE_DATABASE_ATTR: &str = "missing `#[database(\"name\")]` attribute"; +const ONE_UNNAMED_FIELD: &str = "struct must have exactly one unnamed field"; + +#[derive(Debug, FromMeta)] +struct DatabaseAttribute { + #[meta(naked)] + name: String, +} + +pub fn derive_database(input: TokenStream) -> TokenStream { + DeriveGenerator::build_for(input, quote!(impl sea_orm_rocket::Database)) + .support(Support::TupleStruct) + .validator(ValidatorBuild::new() + .struct_validate(|_, s| { + if s.fields.len() == 1 { + Ok(()) + } else { + Err(s.span().error(ONE_UNNAMED_FIELD)) + } + }) + ) + .outer_mapper(MapperBuild::new() + .struct_map(|_, s| { + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + let decorated_type = &s.ident; + let db_ty = quote_spanned!(decorated_type.span() => + <#decorated_type as sea_orm_rocket::Database> + ); + + quote_spanned! { decorated_type.span() => + impl From<#pool_type> for #decorated_type { + fn from(pool: #pool_type) -> Self { + Self(pool) + } + } + + impl std::ops::Deref for #decorated_type { + type Target = #pool_type; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl std::ops::DerefMut for #decorated_type { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #[rocket::async_trait] + impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type { + type Error = (); + + async fn from_request( + req: &'r rocket::request::Request<'_> + ) -> rocket::request::Outcome { + match #db_ty::fetch(req.rocket()) { + Some(db) => rocket::outcome::Outcome::Success(db), + None => rocket::outcome::Outcome::Failure(( + rocket::http::Status::InternalServerError, ())) + } + } + } + + impl rocket::Sentinel for &#decorated_type { + fn abort(rocket: &rocket::Rocket) -> bool { + #db_ty::fetch(rocket).is_none() + } + } + } + }) + ) + .outer_mapper(quote!(#[rocket::async_trait])) + .inner_mapper(MapperBuild::new() + .try_struct_map(|_, s| { + let db_name = DatabaseAttribute::one_from_attrs("database", &s.attrs)? + .map(|attr| attr.name) + .ok_or_else(|| s.span().error(ONE_DATABASE_ATTR))?; + + let fairing_name = format!("'{}' Database Pool", db_name); + + let pool_type = match &s.fields { + syn::Fields::Unnamed(f) => &f.unnamed[0].ty, + _ => unreachable!("Support::TupleStruct"), + }; + + Ok(quote_spanned! { pool_type.span() => + type Pool = #pool_type; + + const NAME: &'static str = #db_name; + + fn init() -> sea_orm_rocket::Initializer { + sea_orm_rocket::Initializer::with_name(#fairing_name) + } + }) + }) + ) + .to_tokens() +} diff --git a/sea-orm-rocket/codegen/src/lib.rs b/sea-orm-rocket/codegen/src/lib.rs new file mode 100644 index 00000000..7cb32f94 --- /dev/null +++ b/sea-orm-rocket/codegen/src/lib.rs @@ -0,0 +1,52 @@ +#![recursion_limit="256"] +#![warn(rust_2018_idioms)] + +//! # `sea_orm_rocket` - Code Generation +//! +//! Implements the code generation portion of the `sea_orm_rocket` crate. This +//! is an implementation detail. This create should never be depended on +//! directly. + +#[macro_use] extern crate quote; + +mod database; + +/// Automatic derive for the [`Database`] trait. +/// +/// The derive generates an implementation of [`Database`] as follows: +/// +/// * [`Database::NAME`] is set to the value in the `#[database("name")]` +/// attribute. +/// +/// This names the database, providing an anchor to configure the database via +/// `Rocket.toml` or any other configuration source. Specifically, the +/// configuration in `databases.name` is used to configure the driver. +/// +/// * [`Database::Pool`] is set to the wrapped type: `PoolType` above. The type +/// must implement [`Pool`]. +/// +/// To meet the required [`Database`] supertrait bounds, this derive also +/// generates implementations for: +/// +/// * `From` +/// +/// * `Deref` +/// +/// * `DerefMut` +/// +/// * `FromRequest<'_> for &Db` +/// +/// * `Sentinel for &Db` +/// +/// The `Deref` impls enable accessing the database pool directly from +/// references `&Db` or `&mut Db`. To force a dereference to the underlying +/// type, use `&db.0` or `&**db` or their `&mut` variants. +/// +/// [`Database`]: ../sea_orm_rocket/trait.Database.html +/// [`Database::NAME`]: ../sea_orm_rocket/trait.Database.html#associatedconstant.NAME +/// [`Database::Pool`]: ../sea_orm_rocket/trait.Database.html#associatedtype.Pool +/// [`Pool`]: ../sea_orm_rocket/trait.Pool.html +#[proc_macro_derive(Database, attributes(database))] +pub fn derive_database(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + crate::database::derive_database(input) +} diff --git a/sea-orm-rocket/lib/Cargo.toml b/sea-orm-rocket/lib/Cargo.toml new file mode 100644 index 00000000..3a586fe1 --- /dev/null +++ b/sea-orm-rocket/lib/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "sea-orm-rocket" +version = "0.1.0" +authors = ["Sergio Benitez ", "Jeb Rosen "] +description = "SeaORM Rocket support crate" +repository = "https://github.com/SeaQL/sea-orm" +readme = "../README.md" +keywords = ["rocket", "framework", "database", "pools"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[package.metadata.docs.rs] +all-features = true + +[dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +version = "0.5.0-rc.1" +default-features = false + +[dependencies.sea-orm-rocket-codegen] +path = "../codegen" +version = "0.1.0-rc" + +[dev-dependencies.rocket] +git = "https://github.com/SergioBenitez/Rocket.git" +default-features = false +features = ["json"] diff --git a/sea-orm-rocket/lib/src/config.rs b/sea-orm-rocket/lib/src/config.rs new file mode 100644 index 00000000..b30c2ce5 --- /dev/null +++ b/sea-orm-rocket/lib/src/config.rs @@ -0,0 +1,83 @@ +use rocket::serde::{Deserialize, Serialize}; + +/// Base configuration for all database drivers. +/// +/// A dictionary matching this structure is extracted from the active +/// [`Figment`](crate::figment::Figment), scoped to `databases.name`, where +/// `name` is the name of the database, by the +/// [`Initializer`](crate::Initializer) fairing on ignition and used to +/// configure the relevant database and database pool. +/// +/// With the default provider, these parameters are typically configured in a +/// `Rocket.toml` file: +/// +/// ```toml +/// [default.databases.db_name] +/// url = "/path/to/db.sqlite" +/// +/// # only `url` is required. `Initializer` provides defaults for the rest. +/// min_connections = 64 +/// max_connections = 1024 +/// connect_timeout = 5 +/// idle_timeout = 120 +/// ``` +/// +/// Alternatively, a custom provider can be used. For example, a custom `Figment` +/// with a global `databases.name` configuration: +/// +/// ```rust +/// # use rocket::launch; +/// #[launch] +/// fn rocket() -> _ { +/// let figment = rocket::Config::figment() +/// .merge(("databases.name", sea_orm_rocket::Config { +/// url: "db:specific@config&url".into(), +/// min_connections: None, +/// max_connections: 1024, +/// connect_timeout: 3, +/// idle_timeout: None, +/// })); +/// +/// rocket::custom(figment) +/// } +/// ``` +/// +/// For general information on configuration in Rocket, see [`rocket::config`]. +/// For higher-level details on configuring a database, see the [crate-level +/// docs](crate#configuration). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(crate = "rocket::serde")] +pub struct Config { + /// Database-specific connection and configuration URL. + /// + /// The format of the URL is database specific; consult your database's + /// documentation. + pub url: String, + /// Minimum number of connections to maintain in the pool. + /// + /// **Note:** `deadpool` drivers do not support and thus ignore this value. + /// + /// _Default:_ `None`. + pub min_connections: Option, + /// Maximum number of connections to maintain in the pool. + /// + /// _Default:_ `workers * 4`. + pub max_connections: usize, + /// Number of seconds to wait for a connection before timing out. + /// + /// If the timeout elapses before a connection can be made or retrieved from + /// a pool, an error is returned. + /// + /// _Default:_ `5`. + pub connect_timeout: u64, + /// Maximum number of seconds to keep a connection alive for. + /// + /// After a connection is established, it is maintained in a pool for + /// efficient connection retrieval. When an `idle_timeout` is set, that + /// connection will be closed after the timeout elapses. If an + /// `idle_timeout` is not specified, the behavior is driver specific but + /// typically defaults to keeping a connection active indefinitely. + /// + /// _Default:_ `None`. + pub idle_timeout: Option, +} diff --git a/sea-orm-rocket/lib/src/database.rs b/sea-orm-rocket/lib/src/database.rs new file mode 100644 index 00000000..6eb98385 --- /dev/null +++ b/sea-orm-rocket/lib/src/database.rs @@ -0,0 +1,249 @@ +use std::marker::PhantomData; +use std::ops::{DerefMut}; + +use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel}; +use rocket::fairing::{self, Fairing, Info, Kind}; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::http::Status; + +use rocket::yansi::Paint; +use rocket::figment::providers::Serialized; + +use crate::Pool; + +/// Derivable trait which ties a database [`Pool`] with a configuration name. +/// +/// This trait should rarely, if ever, be implemented manually. Instead, it +/// should be derived: +/// +/// ```ignore +/// use sea_orm_rocket::{Database}; +/// # use sea_orm_rocket::MockPool as SeaOrmPool; +/// +/// #[derive(Database, Debug)] +/// #[database("sea_orm")] +/// struct Db(SeaOrmPool); +/// +/// #[launch] +/// fn rocket() -> _ { +/// rocket::build().attach(Db::init()) +/// } +/// ``` +/// +/// See the [`Database` derive](derive@crate::Database) for details. +pub trait Database: From + DerefMut + Send + Sync + 'static { + /// The [`Pool`] type of connections to this database. + /// + /// When `Database` is derived, this takes the value of the `Inner` type in + /// `struct Db(Inner)`. + type Pool: Pool; + + /// The configuration name for this database. + /// + /// When `Database` is derived, this takes the value `"name"` in the + /// `#[database("name")]` attribute. + const NAME: &'static str; + + /// Returns a fairing that initializes the database and its connection pool. + /// + /// # Example + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build().attach(Db::init()) + /// } + /// # } + /// ``` + fn init() -> Initializer { + Initializer::new() + } + + /// Returns a reference to the initialized database in `rocket`. The + /// initializer fairing returned by `init()` must have already executed for + /// `Option` to be `Some`. This is guaranteed to be the case if the fairing + /// is attached and either: + /// + /// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the + /// application is running. This is always the case in request guards + /// and liftoff fairings, + /// * _or_ Rocket is in the [`Build`](rocket::Build) or + /// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has + /// already been run. This is the case in all fairing callbacks + /// corresponding to fairings attached _after_ the `Initializer` + /// fairing. + /// + /// # Example + /// + /// Run database migrations in an ignite fairing. It is imperative that the + /// migration fairing be registered _after_ the `init()` fairing. + /// + /// ```rust + /// # mod _inner { + /// # use rocket::launch; + /// use rocket::{Rocket, Build}; + /// use rocket::fairing::{self, AdHoc}; + /// + /// use sea_orm_rocket::{Database}; + /// # use sea_orm_rocket::MockPool as SeaOrmPool; + /// + /// #[derive(Database)] + /// #[database("sea_orm")] + /// struct Db(SeaOrmPool); + /// + /// async fn run_migrations(rocket: Rocket) -> fairing::Result { + /// if let Some(db) = Db::fetch(&rocket) { + /// // run migrations using `db`. get the inner type with &db.0. + /// Ok(rocket) + /// } else { + /// Err(rocket) + /// } + /// } + /// + /// #[launch] + /// fn rocket() -> _ { + /// rocket::build() + /// .attach(Db::init()) + /// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations)) + /// } + /// # } + /// ``` + fn fetch(rocket: &Rocket

) -> Option<&Self> { + if let Some(db) = rocket.state() { + return Some(db); + } + + let dbtype = std::any::type_name::(); + let fairing = Paint::default(format!("{}::init()", dbtype)).bold(); + error!("Attempted to fetch unattached database `{}`.", Paint::default(dbtype).bold()); + info_!("`{}` fairing must be attached prior to using this database.", fairing); + None + } +} + +/// A [`Fairing`] which initializes a [`Database`] and its connection pool. +/// +/// A value of this type can be created for any type `D` that implements +/// [`Database`] via the [`Database::init()`] method on the type. Normally, a +/// value of this type _never_ needs to be constructed directly. This +/// documentation exists purely as a reference. +/// +/// This fairing initializes a database pool. Specifically, it: +/// +/// 1. Reads the configuration at `database.db_name`, where `db_name` is +/// [`Database::NAME`]. +/// +/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment. +/// +/// 3. Calls [`Pool::init()`]. +/// +/// 4. Stores the database instance in managed storage, retrievable via +/// [`Database::fetch()`]. +/// +/// The name of the fairing itself is `Initializer`, with `D` replaced with +/// the type name `D` unless a name is explicitly provided via +/// [`Self::with_name()`]. +pub struct Initializer(Option<&'static str>, PhantomData D>); + +/// A request guard which retrieves a single connection to a [`Database`]. +/// +/// For a database type of `Db`, a request guard of `Connection` retrieves a +/// single connection to `Db`. +/// +/// The request guard succeeds if the database was initialized by the +/// [`Initializer`] fairing and a connection is available within +/// [`connect_timeout`](crate::Config::connect_timeout) seconds. +/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with +/// status `InternalServerError`. A [`Sentinel`] guards this condition, and so +/// this type of failure is unlikely to occur. A `None` error is returned. +/// * If a connection is not available within `connect_timeout` seconds or +/// another error occurs, the gaurd _fails_ with status `ServiceUnavailable` +/// and the error is returned in `Some`. +/// +pub struct Connection<'a, D: Database>(&'a ::Connection); + +impl Initializer { + /// Returns a database initializer fairing for `D`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn new() -> Self { + Self(None, std::marker::PhantomData) + } + + /// Returns a database initializer fairing for `D` with name `name`. + /// + /// This method should never need to be called manually. See the [crate + /// docs](crate) for usage information. + pub fn with_name(name: &'static str) -> Self { + Self(Some(name), std::marker::PhantomData) + } +} + +impl<'a, D: Database> Connection<'a, D> { + /// Returns the internal connection value. See the [`Connection` Deref + /// column](crate#supported-drivers) for the expected type of this value. + /// + /// Note that `Connection` derefs to the internal connection type, so + /// using this method is likely unnecessary. See [deref](Connection#deref) + /// for examples. + pub fn into_inner(self) -> &'a ::Connection { + self.0 + } +} + +#[rocket::async_trait] +impl Fairing for Initializer { + fn info(&self) -> Info { + Info { + name: self.0.unwrap_or_else(std::any::type_name::), + kind: Kind::Ignite, + } + } + + async fn on_ignite(&self, rocket: Rocket) -> fairing::Result { + let workers: usize = rocket.figment() + .extract_inner(rocket::Config::WORKERS) + .unwrap_or_else(|_| rocket::Config::default().workers); + + let figment = rocket.figment() + .focus(&format!("databases.{}", D::NAME)) + .merge(Serialized::default("max_connections", workers * 4)) + .merge(Serialized::default("connect_timeout", 5)); + + match ::init(&figment).await { + Ok(pool) => Ok(rocket.manage(D::from(pool))), + Err(e) => { + error!("failed to initialize database: {}", e); + Err(rocket) + } + } + } +} + +#[rocket::async_trait] +impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> { + type Error = Option<::Error>; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + match D::fetch(req.rocket()) { + Some(pool) => Outcome::Success(Connection(pool.borrow())), + None => Outcome::Failure((Status::InternalServerError, None)), + } + } +} + +impl Sentinel for Connection<'_, D> { + fn abort(rocket: &Rocket) -> bool { + D::fetch(rocket).is_none() + } +} diff --git a/sea-orm-rocket/lib/src/error.rs b/sea-orm-rocket/lib/src/error.rs new file mode 100644 index 00000000..69bae106 --- /dev/null +++ b/sea-orm-rocket/lib/src/error.rs @@ -0,0 +1,35 @@ +use std::fmt; + +/// A general error type for use by [`Pool`](crate::Pool#implementing) +/// implementors and returned by the [`Connection`](crate::Connection) request +/// guard. +#[derive(Debug)] +pub enum Error { + /// An error that occured during database/pool initialization. + Init(A), + + /// An error that ocurred while retrieving a connection from the pool. + Get(B), + + /// A [`Figment`](crate::figment::Figment) configuration error. + Config(crate::figment::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Init(e) => write!(f, "failed to initialize database: {}", e), + Error::Get(e) => write!(f, "failed to get db connection: {}", e), + Error::Config(e) => write!(f, "bad configuration: {}", e), + } + } +} + +impl std::error::Error for Error + where A: fmt::Debug + fmt::Display, B: fmt::Debug + fmt::Display {} + +impl From for Error { + fn from(e: crate::figment::Error) -> Self { + Self::Config(e) + } +} diff --git a/sea-orm-rocket/lib/src/lib.rs b/sea-orm-rocket/lib/src/lib.rs new file mode 100644 index 00000000..4b98cd0b --- /dev/null +++ b/sea-orm-rocket/lib/src/lib.rs @@ -0,0 +1,20 @@ +//! SeaORM Rocket support crate. +#![deny(missing_docs)] + +/// Re-export of the `figment` crate. +#[doc(inline)] +pub use rocket::figment; + +pub use rocket; + +mod database; +mod error; +mod pool; +mod config; + +pub use self::database::{Connection, Database, Initializer}; +pub use self::error::Error; +pub use self::pool::{Pool, MockPool}; +pub use self::config::Config; + +pub use sea_orm_rocket_codegen::*; diff --git a/sea-orm-rocket/lib/src/pool.rs b/sea-orm-rocket/lib/src/pool.rs new file mode 100644 index 00000000..9dd8b75a --- /dev/null +++ b/sea-orm-rocket/lib/src/pool.rs @@ -0,0 +1,63 @@ +use rocket::figment::Figment; + +/// Generic [`Database`](crate::Database) driver connection pool trait. +/// +/// This trait provides a generic interface to various database pooling +/// implementations in the Rust ecosystem. It can be implemented by anyone, but +/// this crate provides implementations for common drivers. +/// ``` +#[rocket::async_trait] +pub trait Pool: Sized + Send + Sync + 'static { + /// The connection type managed by this pool, returned by [`Self::get()`]. + type Connection; + + /// The error type returned by [`Self::init()`] and [`Self::get()`]. + type Error: std::error::Error; + + /// Constructs a pool from a [Value](rocket::figment::value::Value). + /// + /// It is up to each implementor of `Pool` to define its accepted + /// configuration value(s) via the `Config` associated type. Most + /// integrations provided in `sea_orm_rocket` use [`Config`], which + /// accepts a (required) `url` and an (optional) `pool_size`. + /// + /// ## Errors + /// + /// This method returns an error if the configuration is not compatible, or + /// if creating a pool failed due to an unavailable database server, + /// insufficient resources, or another database-specific error. + async fn init(figment: &Figment) -> Result; + + /// Borrow the inner connection + fn borrow(&self) -> &Self::Connection; +} + +#[derive(Debug)] +/// A mock object which impl `Pool`, for testing only +pub struct MockPool; + +#[derive(Debug)] +pub struct MockPoolErr; + +impl std::fmt::Display for MockPoolErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for MockPoolErr {} + +#[rocket::async_trait] +impl Pool for MockPool { + type Error = MockPoolErr; + + type Connection = bool; + + async fn init(_figment: &Figment) -> Result { + Ok(MockPool) + } + + fn borrow(&self) -> &Self::Connection { + &true + } +}