sea-orm/src/database/transaction.rs
2022-11-06 12:55:21 +08:00

464 lines
16 KiB
Rust

use crate::{
debug_print, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult,
Statement, StreamTrait, TransactionStream, TransactionTrait,
};
#[cfg(feature = "sqlx-dep")]
use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
use futures::lock::Mutex;
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, TransactionManager};
use std::{future::Future, pin::Pin, sync::Arc};
use tracing::instrument;
// a Transaction is just a sugar for a connection where START TRANSACTION has been executed
/// Defines a database transaction, whether it is an open transaction and the type of
/// backend to use
pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
metric_callback: Option<crate::metric::Callback>,
}
impl std::fmt::Debug for DatabaseTransaction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "DatabaseTransaction")
}
}
impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>,
metric_callback: Option<crate::metric::Callback>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql,
metric_callback,
)
.await
}
#[cfg(feature = "sqlx-postgres")]
pub(crate) async fn new_postgres(
inner: PoolConnection<sqlx::Postgres>,
metric_callback: Option<crate::metric::Callback>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres,
metric_callback,
)
.await
}
#[cfg(feature = "sqlx-sqlite")]
pub(crate) async fn new_sqlite(
inner: PoolConnection<sqlx::Sqlite>,
metric_callback: Option<crate::metric::Callback>,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite,
metric_callback,
)
.await
}
#[cfg(feature = "mock")]
pub(crate) async fn new_mock(
inner: Arc<crate::MockDatabaseConnection>,
metric_callback: Option<crate::metric::Callback>,
) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::begin(
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
backend,
metric_callback,
)
.await
}
#[instrument(level = "trace", skip(metric_callback))]
#[allow(unreachable_code)]
async fn begin(
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
metric_callback: Option<crate::metric::Callback>,
) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
metric_callback,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.begin();
}
}
Ok(res)
}
/// Runs a transaction to completion returning an rolling back the transaction on
/// encountering an error if it fails
#[instrument(level = "trace", skip(callback))]
pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
let res = callback(&self).await.map_err(TransactionError::Transaction);
if res.is_ok() {
self.commit().await.map_err(TransactionError::Connection)?;
} else {
self.rollback()
.await
.map_err(TransactionError::Connection)?;
}
res
}
/// Commit a transaction atomically
#[instrument(level = "trace")]
#[allow(unreachable_code)]
pub async fn commit(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.commit();
}
}
Ok(())
}
/// rolls back a transaction in case error are encountered during the operation
#[instrument(level = "trace")]
#[allow(unreachable_code)]
pub async fn rollback(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c)
.await
.map_err(sqlx_error_to_query_err)?
}
#[cfg(feature = "mock")]
InnerConnection::Mock(ref mut c) => {
c.rollback();
}
}
Ok(())
}
// the rollback is queued and will be performed on next async operation, like returning the connection to the pool
#[instrument(level = "trace")]
fn start_rollback(&mut self) {
if self.open {
if let Some(mut conn) = self.conn.try_lock() {
match &mut *conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => {
c.rollback();
}
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
} else {
//this should never happen
panic!("Dropping a locked Transaction");
}
}
}
#[cfg(feature = "sqlx-dep")]
fn map_err_ignore_not_found<T: std::fmt::Debug>(
err: Result<Option<T>, sqlx::Error>,
) -> Result<Option<T>, DbErr> {
if let Err(sqlx::Error::RowNotFound) = err {
Ok(None)
} else {
err.map_err(|e| sqlx_error_to_query_err(e))
}
}
}
impl Drop for DatabaseTransaction {
fn drop(&mut self) {
self.start_rollback();
}
}
#[async_trait::async_trait]
impl ConnectionTrait for DatabaseTransaction {
fn get_database_backend(&self) -> DbBackend {
// this way we don't need to lock
self.backend
}
#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", stmt);
match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query.execute(conn).await.map(Into::into)
})
.map_err(sqlx_error_to_exec_err)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query.execute(conn).await.map(Into::into)
})
.map_err(sqlx_error_to_exec_err)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query.execute(conn).await.map(Into::into)
})
.map_err(sqlx_error_to_exec_err)
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.execute(stmt),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
debug_print!("{}", stmt);
match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
Self::map_err_ignore_not_found(
query.fetch_one(conn).await.map(|row| Some(row.into())),
)
})
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
Self::map_err_ignore_not_found(
query.fetch_one(conn).await.map(|row| Some(row.into())),
)
})
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
Self::map_err_ignore_not_found(
query.fetch_one(conn).await.map(|row| Some(row.into())),
)
})
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_one(stmt),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
#[instrument(level = "trace")]
#[allow(unused_variables)]
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
debug_print!("{}", stmt);
match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
.map_err(sqlx_error_to_query_err)
})
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
.map_err(sqlx_error_to_query_err)
})
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
crate::metric::metric!(self.metric_callback, &stmt, {
query
.fetch_all(conn)
.await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
.map_err(sqlx_error_to_query_err)
})
}
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_all(stmt),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
}
impl StreamTrait for DatabaseTransaction {
type Stream<'a> = TransactionStream<'a>;
#[instrument(level = "trace")]
fn stream<'a>(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
Box::pin(async move {
let conn = self.conn.lock().await;
Ok(crate::TransactionStream::build(
conn,
stmt,
self.metric_callback.clone(),
))
})
}
}
#[async_trait::async_trait]
impl TransactionTrait for DatabaseTransaction {
#[instrument(level = "trace")]
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(
Arc::clone(&self.conn),
self.backend,
self.metric_callback.clone(),
)
.await
}
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
#[instrument(level = "trace", skip(_callback))]
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
T: Send,
E: std::error::Error + Send,
{
let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
}
}
/// Defines errors for handling transaction failures
#[derive(Debug)]
pub enum TransactionError<E>
where
E: std::error::Error,
{
/// A Database connection error
Connection(DbErr),
/// An error occurring when doing database transactions
Transaction(E),
}
impl<E> std::fmt::Display for TransactionError<E>
where
E: std::error::Error,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
}
}
}
impl<E> std::error::Error for TransactionError<E> where E: std::error::Error {}