diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index b9249d99..3a090b6f 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -166,7 +166,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { - crate::QueryStream::from((Arc::clone(conn), stmt)) + crate::QueryStream::from((Arc::clone(conn), stmt, None)) } DatabaseConnection::Disconnected => panic!("Disconnected"), }) @@ -184,7 +184,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { - DatabaseTransaction::new_mock(Arc::clone(conn)).await + DatabaseTransaction::new_mock(Arc::clone(conn), None).await } DatabaseConnection::Disconnected => panic!("Disconnected"), } @@ -213,7 +213,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { - let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)) + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn), None) .await .map_err(TransactionError::Connection)?; transaction.run(_callback).await @@ -245,6 +245,24 @@ impl DatabaseConnection { } } +impl DatabaseConnection { + /// Sets a callback to metric this connection + pub fn set_metric_callback(&mut self, callback: F) + where + F: Into, + { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.set_metric_callback(callback), + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.set_metric_callback(callback), + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.set_metric_callback(callback), + _ => {}, + } + } +} + impl DbBackend { /// Check if the URI is the same as the specified database backend. /// Returns true if they match. diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 144a6b83..2147a8dc 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -21,36 +21,37 @@ use crate::{DbErr, InnerConnection, QueryResult, Statement}; pub struct QueryStream { stmt: Statement, conn: InnerConnection, - #[borrows(mut conn, stmt)] + metric_callback: Option, + #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] stream: Pin> + 'this>>, } #[cfg(feature = "sqlx-mysql")] -impl From<(PoolConnection, Statement)> for QueryStream { - fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, InnerConnection::MySql(conn)) +impl From<(PoolConnection, Statement, Option)> for QueryStream { + fn from((conn, stmt, metric_callback): (PoolConnection, Statement, Option)) -> Self { + QueryStream::build(stmt, InnerConnection::MySql(conn), metric_callback) } } #[cfg(feature = "sqlx-postgres")] -impl From<(PoolConnection, Statement)> for QueryStream { - fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, InnerConnection::Postgres(conn)) +impl From<(PoolConnection, Statement, Option)> for QueryStream { + fn from((conn, stmt, metric_callback): (PoolConnection, Statement, Option)) -> Self { + QueryStream::build(stmt, InnerConnection::Postgres(conn), metric_callback) } } #[cfg(feature = "sqlx-sqlite")] -impl From<(PoolConnection, Statement)> for QueryStream { - fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, InnerConnection::Sqlite(conn)) +impl From<(PoolConnection, Statement, Option)> for QueryStream { + fn from((conn, stmt, metric_callback): (PoolConnection, Statement, Option)) -> Self { + QueryStream::build(stmt, InnerConnection::Sqlite(conn), metric_callback) } } #[cfg(feature = "mock")] -impl From<(Arc, Statement)> for QueryStream { - fn from((conn, stmt): (Arc, Statement)) -> Self { - QueryStream::build(stmt, InnerConnection::Mock(conn)) +impl From<(Arc, Statement, Option)> for QueryStream { + fn from((conn, stmt, metric_callback): (Arc, Statement, Option)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn), metric_callback) } } @@ -61,12 +62,13 @@ impl std::fmt::Debug for QueryStream { } impl QueryStream { - #[instrument(level = "trace")] - fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { + #[instrument(level = "trace", skip(metric_callback))] + fn build(stmt: Statement, conn: InnerConnection, metric_callback: Option) -> QueryStream { QueryStreamBuilder { stmt, conn, - stream_builder: |conn, stmt| { + metric_callback, + stream_builder: |conn, stmt, metric_callback| { match conn { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { @@ -77,7 +79,7 @@ impl QueryStream { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, @@ -95,7 +97,7 @@ impl QueryStream { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, @@ -113,7 +115,7 @@ impl QueryStream { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index d2643f63..fe1212d8 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -21,7 +21,8 @@ use crate::{DbErr, InnerConnection, QueryResult, Statement}; pub struct TransactionStream<'a> { stmt: Statement, conn: MutexGuard<'a, InnerConnection>, - #[borrows(mut conn, stmt)] + metric_callback: Option, + #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] stream: Pin> + 'this>>, } @@ -33,15 +34,17 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> { } impl<'a> TransactionStream<'a> { - #[instrument(level = "trace")] + #[instrument(level = "trace", skip(metric_callback))] pub(crate) async fn build( conn: MutexGuard<'a, InnerConnection>, stmt: Statement, + metric_callback: Option, ) -> TransactionStream<'a> { TransactionStreamAsyncBuilder { stmt, conn, - stream_builder: |conn, stmt| { + metric_callback, + stream_builder: |conn, stmt, metric_callback| { Box::pin(async move { match conn.deref_mut() { #[cfg(feature = "sqlx-mysql")] @@ -53,7 +56,7 @@ impl<'a> TransactionStream<'a> { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ) as Pin>>>; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, @@ -71,7 +74,7 @@ impl<'a> TransactionStream<'a> { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ) as Pin>>>; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, @@ -89,7 +92,7 @@ impl<'a> TransactionStream<'a> { .map_ok(Into::into) .map_err(crate::sqlx_error_to_query_err), ) as Pin>>>; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: stmt, diff --git a/src/database/transaction.rs b/src/database/transaction.rs index e26067db..e6dbf83a 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -17,6 +17,7 @@ pub struct DatabaseTransaction { conn: Arc>, backend: DbBackend, open: bool, + metric_callback: Option, } impl std::fmt::Debug for DatabaseTransaction { @@ -29,10 +30,12 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] pub(crate) async fn new_mysql( inner: PoolConnection, + metric_callback: Option, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql, + metric_callback, ) .await } @@ -40,10 +43,12 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-postgres")] pub(crate) async fn new_postgres( inner: PoolConnection, + metric_callback: Option, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres, + metric_callback, ) .await } @@ -51,10 +56,12 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-sqlite")] pub(crate) async fn new_sqlite( inner: PoolConnection, + metric_callback: Option, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite, + metric_callback, ) .await } @@ -62,20 +69,27 @@ impl DatabaseTransaction { #[cfg(feature = "mock")] pub(crate) async fn new_mock( inner: Arc, + metric_callback: Option, ) -> Result { let backend = inner.get_database_backend(); - Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await + Self::begin( + Arc::new(Mutex::new(InnerConnection::Mock(inner))), + backend, + metric_callback, + ).await } - #[instrument(level = "trace")] + #[instrument(level = "trace", skip(metric_callback))] async fn begin( conn: Arc>, backend: DbBackend, + metric_callback: Option, ) -> Result { let res = DatabaseTransaction { conn, backend, open: true, + metric_callback, }; match *res.conn.lock().await { #[cfg(feature = "sqlx-mysql")] @@ -245,7 +259,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); let _start = std::time::SystemTime::now(); let res = query.execute(conn).await.map(Into::into); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -259,7 +273,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); let _start = std::time::SystemTime::now(); let res = query.execute(conn).await.map(Into::into); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -273,7 +287,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); let _start = std::time::SystemTime::now(); let res = query.execute(conn).await.map(Into::into); - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -362,13 +376,17 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { stmt: Statement, ) -> Pin> + 'a>> { Box::pin( - async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) }, + async move { Ok(crate::TransactionStream::build(self.conn.lock().await, stmt, self.metric_callback.clone()).await) }, ) } #[instrument(level = "trace")] async fn begin(&self) -> Result { - DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await + DatabaseTransaction::begin( + Arc::clone(&self.conn), + self.backend, + self.metric_callback.clone() + ).await } /// Execute the function inside a transaction. diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 2b2efd68..7216c226 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -21,9 +21,16 @@ use super::sqlx_common::*; pub struct SqlxMySqlConnector; /// Defines a sqlx MySQL pool -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SqlxMySqlPoolConnection { pool: MySqlPool, + metric_callback: Option, +} + +impl std::fmt::Debug for SqlxMySqlPoolConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool) + } } impl SqlxMySqlConnector { @@ -45,7 +52,7 @@ impl SqlxMySqlConnector { } match options.pool_options().connect_with(opt).await { Ok(pool) => Ok(DatabaseConnection::SqlxMySqlPoolConnection( - SqlxMySqlPoolConnection { pool }, + SqlxMySqlPoolConnection { pool, metric_callback: None }, )), Err(e) => Err(sqlx_error_to_conn_err(e)), } @@ -55,7 +62,7 @@ impl SqlxMySqlConnector { impl SqlxMySqlConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection { - DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool }) + DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool, metric_callback: None }) } } @@ -72,7 +79,7 @@ impl SqlxMySqlPoolConnection { Ok(res) => Ok(res.into()), Err(err) => Err(sqlx_error_to_exec_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -102,7 +109,7 @@ impl SqlxMySqlPoolConnection { _ => Err(DbErr::Query(err.to_string())), }, }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -129,7 +136,7 @@ impl SqlxMySqlPoolConnection { Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()), Err(err) => Err(sqlx_error_to_query_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -150,7 +157,7 @@ impl SqlxMySqlPoolConnection { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { - Ok(QueryStream::from((conn, stmt))) + Ok(QueryStream::from((conn, stmt, self.metric_callback.clone()))) } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -162,7 +169,7 @@ impl SqlxMySqlPoolConnection { #[instrument(level = "trace")] pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_mysql(conn).await + DatabaseTransaction::new_mysql(conn, self.metric_callback.clone()).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -182,7 +189,7 @@ impl SqlxMySqlPoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_mysql(conn) + let transaction = DatabaseTransaction::new_mysql(conn, self.metric_callback.clone()) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -192,6 +199,13 @@ impl SqlxMySqlPoolConnection { ))) } } + + pub(crate) fn set_metric_callback(&mut self, callback: F) + where + F: Into, + { + self.metric_callback = Some(callback.into()); + } } impl From for QueryResult { diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 121e4e72..5f90e515 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -21,9 +21,16 @@ use super::sqlx_common::*; pub struct SqlxPostgresConnector; /// Defines a sqlx PostgreSQL pool -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SqlxPostgresPoolConnection { pool: PgPool, + metric_callback: Option, +} + +impl std::fmt::Debug for SqlxPostgresPoolConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SqlxPostgresPoolConnection {{ pool: {:?} }}", self.pool) + } } impl SqlxPostgresConnector { @@ -45,7 +52,7 @@ impl SqlxPostgresConnector { } match options.pool_options().connect_with(opt).await { Ok(pool) => Ok(DatabaseConnection::SqlxPostgresPoolConnection( - SqlxPostgresPoolConnection { pool }, + SqlxPostgresPoolConnection { pool, metric_callback: None }, )), Err(e) => Err(sqlx_error_to_conn_err(e)), } @@ -55,7 +62,7 @@ impl SqlxPostgresConnector { impl SqlxPostgresConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] pub fn from_sqlx_postgres_pool(pool: PgPool) -> DatabaseConnection { - DatabaseConnection::SqlxPostgresPoolConnection(SqlxPostgresPoolConnection { pool }) + DatabaseConnection::SqlxPostgresPoolConnection(SqlxPostgresPoolConnection { pool, metric_callback: None }) } } @@ -72,7 +79,7 @@ impl SqlxPostgresPoolConnection { Ok(res) => Ok(res.into()), Err(err) => Err(sqlx_error_to_exec_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -102,7 +109,7 @@ impl SqlxPostgresPoolConnection { _ => Err(DbErr::Query(err.to_string())), }, }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -129,7 +136,7 @@ impl SqlxPostgresPoolConnection { Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()), Err(err) => Err(sqlx_error_to_query_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -150,7 +157,7 @@ impl SqlxPostgresPoolConnection { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { - Ok(QueryStream::from((conn, stmt))) + Ok(QueryStream::from((conn, stmt, self.metric_callback.clone()))) } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -162,7 +169,7 @@ impl SqlxPostgresPoolConnection { #[instrument(level = "trace")] pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_postgres(conn).await + DatabaseTransaction::new_postgres(conn, self.metric_callback.clone()).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -182,7 +189,7 @@ impl SqlxPostgresPoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_postgres(conn) + let transaction = DatabaseTransaction::new_postgres(conn, self.metric_callback.clone()) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -192,6 +199,13 @@ impl SqlxPostgresPoolConnection { ))) } } + + pub(crate) fn set_metric_callback(&mut self, callback: F) + where + F: Into, + { + self.metric_callback = Some(callback.into()); + } } impl From for QueryResult { diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 9cfd9a72..20f4954e 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -21,9 +21,16 @@ use super::sqlx_common::*; pub struct SqlxSqliteConnector; /// Defines a sqlx SQLite pool -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SqlxSqlitePoolConnection { pool: SqlitePool, + metric_callback: Option, +} + +impl std::fmt::Debug for SqlxSqlitePoolConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool) + } } impl SqlxSqliteConnector { @@ -49,7 +56,7 @@ impl SqlxSqliteConnector { } match options.pool_options().connect_with(opt).await { Ok(pool) => Ok(DatabaseConnection::SqlxSqlitePoolConnection( - SqlxSqlitePoolConnection { pool }, + SqlxSqlitePoolConnection { pool, metric_callback: None }, )), Err(e) => Err(sqlx_error_to_conn_err(e)), } @@ -59,7 +66,7 @@ impl SqlxSqliteConnector { impl SqlxSqliteConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection { - DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool }) + DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool, metric_callback: None }) } } @@ -76,7 +83,7 @@ impl SqlxSqlitePoolConnection { Ok(res) => Ok(res.into()), Err(err) => Err(sqlx_error_to_exec_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -106,7 +113,7 @@ impl SqlxSqlitePoolConnection { _ => Err(DbErr::Query(err.to_string())), }, }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -133,7 +140,7 @@ impl SqlxSqlitePoolConnection { Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()), Err(err) => Err(sqlx_error_to_query_err(err)), }; - if let Some(callback) = crate::metric::get_callback() { + if let Some(callback) = self.metric_callback.as_deref() { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: &stmt, @@ -154,7 +161,7 @@ impl SqlxSqlitePoolConnection { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { - Ok(QueryStream::from((conn, stmt))) + Ok(QueryStream::from((conn, stmt, self.metric_callback.clone()))) } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -166,7 +173,7 @@ impl SqlxSqlitePoolConnection { #[instrument(level = "trace")] pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_sqlite(conn).await + DatabaseTransaction::new_sqlite(conn, self.metric_callback.clone()).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -186,7 +193,7 @@ impl SqlxSqlitePoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_sqlite(conn) + let transaction = DatabaseTransaction::new_sqlite(conn, self.metric_callback.clone()) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -196,6 +203,13 @@ impl SqlxSqlitePoolConnection { ))) } } + + pub(crate) fn set_metric_callback(&mut self, callback: F) + where + F: Into, + { + self.metric_callback = Some(callback.into()); + } } impl From for QueryResult { diff --git a/src/metric.rs b/src/metric.rs index 4eab0e74..4ea62ef2 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -1,10 +1,6 @@ -use std::time::Duration; +use std::{time::Duration, sync::Arc}; -use once_cell::sync::OnceCell; - -type Callback = Box) + Send + Sync>; - -static METRIC: OnceCell = OnceCell::new(); +pub(crate) type Callback = Arc) + Send + Sync>; #[derive(Debug)] /// Query execution infos @@ -14,15 +10,3 @@ pub struct Info<'a> { /// Query data pub statement: &'a crate::Statement, } - -/// Sets a new metric callback, returning it if already set -pub fn set_callback(callback: F) -> Result<(), Callback> -where - F: Fn(&Info<'_>) + Send + Sync + 'static, -{ - METRIC.set(Box::new(callback)) -} - -pub(crate) fn get_callback() -> Option<&'static Callback> { - METRIC.get() -}