use sea_query::Values; use std::{future::Future, pin::Pin, sync::Arc}; use sqlx::{ mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow}, pool::PoolConnection, Connection, Executor, MySql, MySqlPool, }; use sea_query_binder::SqlxValues; use tracing::instrument; use crate::{ debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; /// Defines the [sqlx::mysql] connector #[derive(Debug)] pub struct SqlxMySqlConnector; /// Defines a sqlx MySQL pool #[derive(Clone)] pub struct SqlxMySqlPoolConnection { pub(crate) pool: MySqlPool, metric_callback: Option, } impl std::fmt::Debug for SqlxMySqlPoolConnection { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool) } } impl SqlxMySqlConnector { /// Check if the URI provided corresponds to `mysql://` for a MySQL database pub fn accepts(string: &str) -> bool { string.starts_with("mysql://") && string.parse::().is_ok() } /// Add configuration options for the MySQL database #[instrument(level = "trace")] pub async fn connect(options: ConnectOptions) -> Result { let mut opt = options .url .parse::() .map_err(sqlx_error_to_conn_err)?; use sqlx::ConnectOptions; if !options.sqlx_logging { opt = opt.disable_statement_logging(); } else { opt = opt.log_statements(options.sqlx_logging_level); } match options.pool_options().connect_with(opt).await { Ok(pool) => Ok(DatabaseConnection::SqlxMySqlPoolConnection( SqlxMySqlPoolConnection { pool, metric_callback: None, }, )), Err(e) => Err(sqlx_error_to_conn_err(e)), } } } impl SqlxMySqlConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection { DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool, metric_callback: None, }) } } impl SqlxMySqlPoolConnection { /// Execute a [Statement] on a MySQL backend #[instrument(level = "trace")] pub async fn execute(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); let query = sqlx_query(&stmt); let mut conn = self.pool.acquire().await.map_err(conn_acquire_err)?; crate::metric::metric!(self.metric_callback, &stmt, { match query.execute(&mut *conn).await { Ok(res) => Ok(res.into()), Err(err) => Err(sqlx_error_to_exec_err(err)), } }) } /// Execute an unprepared SQL statement on a MySQL backend #[instrument(level = "trace")] pub async fn execute_unprepared(&self, sql: &str) -> Result { debug_print!("{}", sql); let conn = &mut self.pool.acquire().await.map_err(conn_acquire_err)?; match conn.execute(sql).await { Ok(res) => Ok(res.into()), Err(err) => Err(sqlx_error_to_exec_err(err)), } } /// Get one result from a SQL query. Returns [Option::None] if no match was found #[instrument(level = "trace")] pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); let query = sqlx_query(&stmt); let mut conn = self.pool.acquire().await.map_err(conn_acquire_err)?; crate::metric::metric!(self.metric_callback, &stmt, { match query.fetch_one(&mut *conn).await { Ok(row) => Ok(Some(row.into())), Err(err) => match err { sqlx::Error::RowNotFound => Ok(None), _ => Err(sqlx_error_to_query_err(err)), }, } }) } /// Get the results of a query returning them as a Vec<[QueryResult]> #[instrument(level = "trace")] pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); let query = sqlx_query(&stmt); let mut conn = self.pool.acquire().await.map_err(conn_acquire_err)?; crate::metric::metric!(self.metric_callback, &stmt, { match query.fetch_all(&mut *conn).await { Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()), Err(err) => Err(sqlx_error_to_query_err(err)), } }) } /// Stream the results of executing a SQL query #[instrument(level = "trace")] pub async fn stream(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); let conn = self.pool.acquire().await.map_err(conn_acquire_err)?; Ok(QueryStream::from(( conn, stmt, self.metric_callback.clone(), ))) } /// Bundle a set of SQL statements that execute together. #[instrument(level = "trace")] pub async fn begin( &self, isolation_level: Option, access_mode: Option, ) -> Result { let conn = self.pool.acquire().await.map_err(conn_acquire_err)?; DatabaseTransaction::new_mysql( conn, self.metric_callback.clone(), isolation_level, access_mode, ) .await } /// Create a MySQL transaction #[instrument(level = "trace", skip(callback))] pub async fn transaction( &self, callback: F, isolation_level: Option, access_mode: Option, ) -> Result> where F: for<'b> FnOnce( &'b DatabaseTransaction, ) -> Pin> + Send + 'b>> + Send, T: Send, E: std::error::Error + Send, { let conn = self.pool.acquire().await.map_err(conn_acquire_err)?; let transaction = DatabaseTransaction::new_mysql( conn, self.metric_callback.clone(), isolation_level, access_mode, ) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } pub(crate) fn set_metric_callback(&mut self, callback: F) where F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static, { self.metric_callback = Some(Arc::new(callback)); } /// Checks if a connection to the database is still valid. pub async fn ping(&self) -> Result<(), DbErr> { let conn = &mut self.pool.acquire().await.map_err(conn_acquire_err)?; match conn.ping().await { Ok(_) => Ok(()), Err(err) => Err(sqlx_error_to_conn_err(err)), } } /// Explicitly close the MySQL connection pub async fn close(self) -> Result<(), DbErr> { self.pool.close().await; Ok(()) } } impl From for QueryResult { fn from(row: MySqlRow) -> QueryResult { QueryResult { row: QueryResultRow::SqlxMySql(row), } } } impl From for ExecResult { fn from(result: MySqlQueryResult) -> ExecResult { ExecResult { result: ExecResultHolder::SqlxMySql(result), } } } pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> { let values = stmt .values .as_ref() .map_or(Values(Vec::new()), |values| values.clone()); sqlx::query_with(&stmt.sql, SqlxValues(values)) } pub(crate) async fn set_transaction_config( conn: &mut PoolConnection, isolation_level: Option, access_mode: Option, ) -> Result<(), DbErr> { if let Some(isolation_level) = isolation_level { let stmt = Statement { sql: format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"), values: None, db_backend: DbBackend::MySql, }; let query = sqlx_query(&stmt); conn.execute(query).await.map_err(sqlx_error_to_exec_err)?; } if let Some(access_mode) = access_mode { let stmt = Statement { sql: format!("SET TRANSACTION {access_mode}"), values: None, db_backend: DbBackend::MySql, }; let query = sqlx_query(&stmt); conn.execute(query).await.map_err(sqlx_error_to_exec_err)?; } Ok(()) }