From c673017b975e3cf9e3127d6719b8fc97a140f5f3 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Wed, 13 Oct 2021 02:25:39 +0800 Subject: [PATCH] ConnectOptions --- src/database/mod.rs | 111 ++++++++++++++++++++++++++++++++---- src/driver/sqlx_mysql.rs | 17 ++++-- src/driver/sqlx_postgres.rs | 17 ++++-- src/driver/sqlx_sqlite.rs | 20 ++++--- 4 files changed, 136 insertions(+), 29 deletions(-) diff --git a/src/database/mod.rs b/src/database/mod.rs index a1dfea93..33a6b7c5 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + mod connection; mod db_connection; #[cfg(feature = "mock")] @@ -19,27 +21,116 @@ use crate::DbErr; #[derive(Debug, Default)] pub struct Database; +#[derive(Debug)] +pub struct ConnectOptions { + pub(crate) url: String, + pub(crate) max_connections: Option, + pub(crate) min_connections: Option, + pub(crate) connect_timeout: Option, + pub(crate) idle_timeout: Option, +} + impl Database { - pub async fn connect(string: &str) -> Result { + pub async fn connect(opt: C) -> Result + where + C: Into, + { + let opt: ConnectOptions = opt.into(); + #[cfg(feature = "sqlx-mysql")] - if DbBackend::MySql.is_prefix_of(string) { - return crate::SqlxMySqlConnector::connect(string).await; + if DbBackend::MySql.is_prefix_of(&opt.url) { + return crate::SqlxMySqlConnector::connect(opt).await; } #[cfg(feature = "sqlx-postgres")] - if DbBackend::Postgres.is_prefix_of(string) { - return crate::SqlxPostgresConnector::connect(string).await; + if DbBackend::Postgres.is_prefix_of(&opt.url) { + return crate::SqlxPostgresConnector::connect(opt).await; } #[cfg(feature = "sqlx-sqlite")] - if DbBackend::Sqlite.is_prefix_of(string) { - return crate::SqlxSqliteConnector::connect(string).await; + if DbBackend::Sqlite.is_prefix_of(&opt.url) { + return crate::SqlxSqliteConnector::connect(opt).await; } #[cfg(feature = "mock")] - if crate::MockDatabaseConnector::accepts(string) { - return crate::MockDatabaseConnector::connect(string).await; + if crate::MockDatabaseConnector::accepts(&opt.url) { + return crate::MockDatabaseConnector::connect(&opt.url).await; } Err(DbErr::Conn(format!( "The connection string '{}' has no supporting driver.", - string + opt.url ))) } } + +impl From<&str> for ConnectOptions { + fn from(string: &str) -> ConnectOptions { + ConnectOptions::from_str(string) + } +} + +impl From<&String> for ConnectOptions { + fn from(string: &String) -> ConnectOptions { + ConnectOptions::from_str(string.as_str()) + } +} + +impl From for ConnectOptions { + fn from(string: String) -> ConnectOptions { + ConnectOptions::new(string) + } +} + +impl ConnectOptions { + pub fn new(url: String) -> Self { + Self { + url, + max_connections: None, + min_connections: None, + connect_timeout: None, + idle_timeout: None, + } + } + + fn from_str(url: &str) -> Self { + Self::new(url.to_owned()) + } + + #[cfg(feature = "sqlx-dep")] + pub fn pool_options(self) -> sqlx::pool::PoolOptions + where + DB: sqlx::Database, + { + let mut opt = sqlx::pool::PoolOptions::new(); + if let Some(max_connections) = self.max_connections { + opt = opt.max_connections(max_connections); + } + if let Some(min_connections) = self.min_connections { + opt = opt.min_connections(min_connections); + } + if let Some(connect_timeout) = self.connect_timeout { + opt = opt.connect_timeout(connect_timeout); + } + if let Some(idle_timeout) = self.idle_timeout { + opt = opt.idle_timeout(Some(idle_timeout)); + } + opt + } + + /// Set the maximum number of connections of the pool + pub fn max_connections(&mut self, value: u32) { + self.max_connections = Some(value); + } + + /// Set the minimum number of connections of the pool + pub fn min_connections(&mut self, value: u32) { + self.min_connections = Some(value); + } + + /// Set the timeout duration when acquiring a connection + pub fn connect_timeout(&mut self, value: Duration) { + self.connect_timeout = Some(value); + } + + /// Set the idle duration before closing a connection + pub fn idle_timeout(&mut self, value: Duration) { + self.idle_timeout = Some(value); + } +} diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 6b6f9507..08989a25 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,7 +1,7 @@ use std::{future::Future, pin::Pin}; use sqlx::{ - mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}, + mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow}, MySql, MySqlPool, }; @@ -9,8 +9,8 @@ sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; use crate::{ - debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, - Statement, TransactionError, + debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, + QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -25,11 +25,16 @@ pub struct SqlxMySqlPoolConnection { impl SqlxMySqlConnector { pub fn accepts(string: &str) -> bool { - string.starts_with("mysql://") + string.starts_with("mysql://") && string.parse::().is_ok() } - pub async fn connect(string: &str) -> Result { - if let Ok(pool) = MySqlPool::connect(string).await { + pub async fn connect(options: ConnectOptions) -> Result { + let opt = options + .url + .parse::() + .map_err(|e| DbErr::Conn(e.to_string()))?; + // opt.disable_statement_logging(); + if let Ok(pool) = options.pool_options().connect_with(opt).await { Ok(DatabaseConnection::SqlxMySqlPoolConnection( SqlxMySqlPoolConnection { pool }, )) diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 13cb51cd..a52a7682 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,7 +1,7 @@ use std::{future::Future, pin::Pin}; use sqlx::{ - postgres::{PgArguments, PgQueryResult, PgRow}, + postgres::{PgArguments, PgConnectOptions, PgQueryResult, PgRow}, PgPool, Postgres, }; @@ -9,8 +9,8 @@ sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; use crate::{ - debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, - Statement, TransactionError, + debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, + QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -25,11 +25,16 @@ pub struct SqlxPostgresPoolConnection { impl SqlxPostgresConnector { pub fn accepts(string: &str) -> bool { - string.starts_with("postgres://") + string.starts_with("postgres://") && string.parse::().is_ok() } - pub async fn connect(string: &str) -> Result { - if let Ok(pool) = PgPool::connect(string).await { + pub async fn connect(options: ConnectOptions) -> Result { + let opt = options + .url + .parse::() + .map_err(|e| DbErr::Conn(e.to_string()))?; + // opt.disable_statement_logging(); + if let Ok(pool) = options.pool_options().connect_with(opt).await { Ok(DatabaseConnection::SqlxPostgresPoolConnection( SqlxPostgresPoolConnection { pool }, )) diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index ea2e05e3..41824350 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,7 +1,7 @@ use std::{future::Future, pin::Pin}; use sqlx::{ - sqlite::{SqliteArguments, SqlitePoolOptions, SqliteQueryResult, SqliteRow}, + sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow}, Sqlite, SqlitePool, }; @@ -9,8 +9,8 @@ sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; use crate::{ - debug_print, error::*, executor::*, DatabaseConnection, DatabaseTransaction, QueryStream, - Statement, TransactionError, + debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, + QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -25,13 +25,19 @@ pub struct SqlxSqlitePoolConnection { impl SqlxSqliteConnector { pub fn accepts(string: &str) -> bool { - string.starts_with("sqlite:") + string.starts_with("sqlite:") && string.parse::().is_ok() } - pub async fn connect(string: &str) -> Result { - if let Ok(pool) = SqlitePoolOptions::new() + pub async fn connect(options: ConnectOptions) -> Result { + let opt = options + .url + .parse::() + .map_err(|e| DbErr::Conn(e.to_string()))?; + // opt.disable_statement_logging(); + if let Ok(pool) = options + .pool_options() .max_connections(1) - .connect(string) + .connect_with(opt) .await { Ok(DatabaseConnection::SqlxSqlitePoolConnection(