Transaction 3

This commit is contained in:
Marco Napetti 2021-10-04 10:42:37 +08:00 committed by Chris Tsang
parent 700a0206e7
commit 02ebc9745c
31 changed files with 1093 additions and 161 deletions

View File

@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] }
serde_json = { version = "^1", optional = true } serde_json = { version = "^1", optional = true }
sqlx = { version = "^0.5", optional = true } sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
[dev-dependencies] [dev-dependencies]
smol = { version = "^1.2" } smol = { version = "^1.2" }

View File

@ -42,7 +42,7 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
text: Set(form.text.to_owned()), text: Set(form.text.to_owned()),
..Default::default() ..Default::default()
} }
.save(&conn) .save(&*conn)
.await .await
.expect("could not insert post"); .expect("could not insert post");
@ -52,7 +52,7 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
#[post("/<id>", data = "<post_form>")] #[post("/<id>", data = "<post_form>")]
async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) -> Flash<Redirect> { async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) -> Flash<Redirect> {
let post: post::ActiveModel = Post::find_by_id(id) let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn) .one(&*conn)
.await .await
.unwrap() .unwrap()
.unwrap() .unwrap()
@ -65,7 +65,7 @@ async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) ->
title: Set(form.title.to_owned()), title: Set(form.title.to_owned()),
text: Set(form.text.to_owned()), text: Set(form.text.to_owned()),
} }
.save(&conn) .save(&*conn)
.await .await
.expect("could not edit post"); .expect("could not edit post");
@ -89,7 +89,7 @@ async fn list(
// Setup paginator // Setup paginator
let paginator = Post::find() let paginator = Post::find()
.order_by_asc(post::Column::Id) .order_by_asc(post::Column::Id)
.paginate(&conn, posts_per_page); .paginate(&*conn, posts_per_page);
let num_pages = paginator.num_pages().await.ok().unwrap(); let num_pages = paginator.num_pages().await.ok().unwrap();
// Fetch paginated posts // Fetch paginated posts
@ -113,7 +113,7 @@ async fn list(
#[get("/<id>")] #[get("/<id>")]
async fn edit(conn: Connection<Db>, id: i32) -> Template { async fn edit(conn: Connection<Db>, id: i32) -> Template {
let post: Option<post::Model> = Post::find_by_id(id) let post: Option<post::Model> = Post::find_by_id(id)
.one(&conn) .one(&*conn)
.await .await
.expect("could not find post"); .expect("could not find post");
@ -128,20 +128,20 @@ async fn edit(conn: Connection<Db>, id: i32) -> Template {
#[delete("/<id>")] #[delete("/<id>")]
async fn delete(conn: Connection<Db>, id: i32) -> Flash<Redirect> { async fn delete(conn: Connection<Db>, id: i32) -> Flash<Redirect> {
let post: post::ActiveModel = Post::find_by_id(id) let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn) .one(&*conn)
.await .await
.unwrap() .unwrap()
.unwrap() .unwrap()
.into(); .into();
post.delete(&conn).await.unwrap(); post.delete(&*conn).await.unwrap();
Flash::success(Redirect::to("/"), "Post successfully deleted.") Flash::success(Redirect::to("/"), "Post successfully deleted.")
} }
#[delete("/")] #[delete("/")]
async fn destroy(conn: Connection<Db>) -> Result<()> { async fn destroy(conn: Connection<Db>) -> Result<()> {
Post::delete_many().exec(&conn).await.unwrap(); Post::delete_many().exec(&*conn).await.unwrap();
Ok(()) Ok(())
} }

View File

@ -1,5 +1,5 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, DbConn, ExecResult}; use sea_orm::{query::*, error::*, sea_query, DbConn, ExecResult};
async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> { async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();

View File

@ -1,4 +1,5 @@
use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder}; use std::{future::Future, pin::Pin, sync::Arc};
use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*};
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
#[cfg_attr(not(feature = "mock"), derive(Clone))] #[cfg_attr(not(feature = "mock"), derive(Clone))]
@ -10,7 +11,7 @@ pub enum DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
MockDatabaseConnection(crate::MockDatabaseConnection), MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),
Disconnected, Disconnected,
} }
@ -51,8 +52,11 @@ impl std::fmt::Debug for DatabaseConnection {
} }
} }
impl DatabaseConnection { #[async_trait::async_trait]
pub fn get_database_backend(&self) -> DbBackend { impl<'a> ConnectionTrait<'a> for DatabaseConnection {
type Stream = crate::QueryStream;
fn get_database_backend(&self) -> DbBackend {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
@ -66,7 +70,7 @@ impl DatabaseConnection {
} }
} }
pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> { async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
@ -75,12 +79,12 @@ impl DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
} }
} }
pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> { async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
@ -89,12 +93,12 @@ impl DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
} }
} }
pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> { async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self { match self {
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
@ -103,12 +107,76 @@ impl DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
} }
} }
fn stream(&'a self, stmt: Statement) -> Pin<Box<dyn Future<Output=Result<Self::Stream, DbErr>> + 'a>> {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)),
DatabaseConnection::Disconnected => panic!("Disconnected"),
})
})
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await,
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send,
T: Send,
E: std::error::Error + Send,
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => {
let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(_callback).await
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
match self {
DatabaseConnection::MockDatabaseConnection(_) => true,
_ => false,
}
}
}
#[cfg(feature = "mock")]
impl DatabaseConnection {
pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection { pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
match self { match self {
DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn, DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn,
@ -116,12 +184,6 @@ impl DatabaseConnection {
} }
} }
#[cfg(not(feature = "mock"))]
pub fn as_mock_connection(&self) -> Option<bool> {
None
}
#[cfg(feature = "mock")]
pub fn into_transaction_log(self) -> Vec<crate::Transaction> { pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap(); let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap();
mocker.drain_transaction_log() mocker.drain_transaction_log()

View File

@ -0,0 +1,45 @@
use std::{future::Future, pin::Pin, sync::Arc};
use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError};
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use sqlx::pool::PoolConnection;
pub(crate) enum InnerConnection {
#[cfg(feature = "sqlx-mysql")]
MySql(PoolConnection<sqlx::MySql>),
#[cfg(feature = "sqlx-postgres")]
Postgres(PoolConnection<sqlx::Postgres>),
#[cfg(feature = "sqlx-sqlite")]
Sqlite(PoolConnection<sqlx::Sqlite>),
#[cfg(feature = "mock")]
Mock(Arc<MockDatabaseConnection>),
}
#[async_trait::async_trait]
pub trait ConnectionTrait<'a>: Sync {
type Stream: Stream<Item=Result<QueryResult, DbErr>>;
fn get_database_backend(&self) -> DbBackend;
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr>;
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr>;
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;
fn stream(&'a self, stmt: Statement) -> Pin<Box<dyn Future<Output=Result<Self::Stream, DbErr>> + 'a>>;
async fn begin(&self) -> Result<DatabaseTransaction, DbErr>;
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send,
T: Send,
E: std::error::Error + Send;
fn is_mock_connection(&self) -> bool {
false
}
}

View File

@ -0,0 +1,308 @@
use std::{sync::Arc, future::Future, pin::Pin};
use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print};
use futures::lock::Mutex;
#[cfg(feature = "sqlx-dep")]
use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err};
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, TransactionManager};
// a Transaction is just a sugar for a connection where START TRANSACTION has been executed
pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
}
impl std::fmt::Debug for DatabaseTransaction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "DatabaseTransaction")
}
}
impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql(inner: PoolConnection<sqlx::MySql>) -> Result<DatabaseTransaction, DbErr> {
Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await
}
#[cfg(feature = "sqlx-postgres")]
pub(crate) async fn new_postgres(inner: PoolConnection<sqlx::Postgres>) -> Result<DatabaseTransaction, DbErr> {
Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await
}
#[cfg(feature = "sqlx-sqlite")]
pub(crate) async fn new_sqlite(inner: PoolConnection<sqlx::Sqlite>) -> Result<DatabaseTransaction, DbErr> {
Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await
}
#[cfg(feature = "mock")]
pub(crate) async fn new_mock(inner: Arc<crate::MockDatabaseConnection>) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
}
async fn build(conn: Arc<Mutex<InnerConnection>>, backend: DbBackend) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)?
},
// should we do something for mocked connections?
#[cfg(feature = "mock")]
InnerConnection::Mock(_) => {},
}
Ok(res)
}
pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>> + Send,
T: Send,
E: std::error::Error + Send,
{
let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e));
if res.is_ok() {
self.commit().await.map_err(|e| TransactionError::Connection(e))?;
}
else {
self.rollback().await.map_err(|e| TransactionError::Connection(e))?;
}
res
}
pub async fn commit(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)?
},
//Should we do something for mocked connections?
#[cfg(feature = "mock")]
InnerConnection::Mock(_) => {},
}
Ok(())
}
pub async fn rollback(mut self) -> Result<(), DbErr> {
self.open = false;
match *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(ref mut c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(ref mut c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)?
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(ref mut c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)?
},
//Should we do something for mocked connections?
#[cfg(feature = "mock")]
InnerConnection::Mock(_) => {},
}
Ok(())
}
// the rollback is queued and will be performed on next async operation, like returning the connection to the pool
fn start_rollback(&mut self) {
if self.open {
if let Some(mut conn) = self.conn.try_lock() {
match &mut *conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
<sqlx::MySql as sqlx::Database>::TransactionManager::start_rollback(c);
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
<sqlx::Postgres as sqlx::Database>::TransactionManager::start_rollback(c);
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
<sqlx::Sqlite as sqlx::Database>::TransactionManager::start_rollback(c);
},
//Should we do something for mocked connections?
#[cfg(feature = "mock")]
InnerConnection::Mock(_) => {},
}
}
else {
//this should never happen
panic!("Dropping a locked Transaction");
}
}
}
}
impl Drop for DatabaseTransaction {
fn drop(&mut self) {
self.start_rollback();
}
}
#[async_trait::async_trait]
impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
type Stream = TransactionStream<'a>;
fn get_database_backend(&self) -> DbBackend {
// this way we don't need to lock
self.backend
}
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query.execute(conn).await
.map(Into::into)
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query.execute(conn).await
.map(Into::into)
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query.execute(conn).await
.map(Into::into)
},
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.execute(stmt),
};
#[cfg(feature = "sqlx-dep")]
_res.map_err(sqlx_error_to_exec_err)
}
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query.fetch_one(conn).await
.map(|row| Some(row.into()))
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query.fetch_one(conn).await
.map(|row| Some(row.into()))
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query.fetch_one(conn).await
.map(|row| Some(row.into()))
},
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_one(stmt),
};
#[cfg(feature = "sqlx-dep")]
if let Err(sqlx::Error::RowNotFound) = _res {
Ok(None)
}
else {
_res.map_err(sqlx_error_to_query_err)
}
}
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
debug_print!("{}", stmt);
let _res = match &mut *self.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(conn) => {
let query = crate::driver::sqlx_mysql::sqlx_query(&stmt);
query.fetch_all(conn).await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(conn) => {
let query = crate::driver::sqlx_postgres::sqlx_query(&stmt);
query.fetch_all(conn).await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(conn) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt);
query.fetch_all(conn).await
.map(|rows| rows.into_iter().map(|r| r.into()).collect())
},
#[cfg(feature = "mock")]
InnerConnection::Mock(conn) => return conn.query_all(stmt),
};
#[cfg(feature = "sqlx-dep")]
_res.map_err(sqlx_error_to_query_err)
}
fn stream(&'a self, stmt: Statement) -> Pin<Box<dyn Future<Output=Result<Self::Stream, DbErr>> + 'a>> {
Box::pin(async move {
Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await)
})
}
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await
}
/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send,
T: Send,
E: std::error::Error + Send,
{
let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(_callback).await
}
}
#[derive(Debug)]
pub enum TransactionError<E>
where E: std::error::Error {
Connection(DbErr),
Transaction(E),
}
impl<E> std::fmt::Display for TransactionError<E>
where E: std::error::Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransactionError::Connection(e) => std::fmt::Display::fmt(e, f),
TransactionError::Transaction(e) => std::fmt::Display::fmt(e, f),
}
}
}
impl<E> std::error::Error for TransactionError<E>
where E: std::error::Error {}

View File

@ -4,7 +4,7 @@ use crate::{
Statement, Transaction, Statement, Transaction,
}; };
use sea_query::{Value, ValueType}; use sea_query::{Value, ValueType};
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
#[derive(Debug)] #[derive(Debug)]
pub struct MockDatabase { pub struct MockDatabase {
@ -40,7 +40,7 @@ impl MockDatabase {
} }
pub fn into_connection(self) -> DatabaseConnection { pub fn into_connection(self) -> DatabaseConnection {
DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self)))
} }
pub fn append_exec_results(mut self, mut vec: Vec<MockExecResult>) -> Self { pub fn append_exec_results(mut self, mut vec: Vec<MockExecResult>) -> Self {
@ -100,7 +100,8 @@ impl MockRow {
where where
T: ValueType, T: ValueType,
{ {
Ok(self.values.get(col).unwrap().clone().unwrap()) T::try_from(self.values.get(col).unwrap().clone())
.map_err(|e| DbErr::Query(e.to_string()))
} }
pub fn into_column_value_tuples(self) -> impl Iterator<Item = (String, Value)> { pub fn into_column_value_tuples(self) -> impl Iterator<Item = (String, Value)> {

View File

@ -3,12 +3,18 @@ mod connection;
mod mock; mod mock;
mod statement; mod statement;
mod transaction; mod transaction;
mod db_connection;
mod db_transaction;
mod stream;
pub use connection::*; pub use connection::*;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
pub use mock::*; pub use mock::*;
pub use statement::*; pub use statement::*;
pub use transaction::*; pub use transaction::*;
pub use db_connection::*;
pub use db_transaction::*;
pub use stream::*;
use crate::DbErr; use crate::DbErr;

View File

@ -0,0 +1,5 @@
mod query;
mod transaction;
pub use query::*;
pub use transaction::*;

View File

@ -0,0 +1,108 @@
use std::{pin::Pin, task::Poll, sync::Arc};
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use futures::TryStreamExt;
#[cfg(feature = "sqlx-dep")]
use sqlx::{pool::PoolConnection, Executor};
use crate::{DbErr, InnerConnection, QueryResult, Statement};
#[ouroboros::self_referencing]
pub struct QueryStream {
stmt: Statement,
conn: InnerConnection,
#[borrows(mut conn, stmt)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
}
#[cfg(feature = "sqlx-mysql")]
impl From<(PoolConnection<sqlx::MySql>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::MySql>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::MySql(conn))
}
}
#[cfg(feature = "sqlx-postgres")]
impl From<(PoolConnection<sqlx::Postgres>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Postgres>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Postgres(conn))
}
}
#[cfg(feature = "sqlx-sqlite")]
impl From<(PoolConnection<sqlx::Sqlite>, Statement)> for QueryStream {
fn from((conn, stmt): (PoolConnection<sqlx::Sqlite>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Sqlite(conn))
}
}
#[cfg(feature = "mock")]
impl From<(Arc<crate::MockDatabaseConnection>, Statement)> for QueryStream {
fn from((conn, stmt): (Arc<crate::MockDatabaseConnection>, Statement)) -> Self {
QueryStream::build(stmt, InnerConnection::Mock(conn))
}
}
impl std::fmt::Debug for QueryStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "QueryStream")
}
}
impl QueryStream {
fn build(stmt: Statement, conn: InnerConnection) -> QueryStream {
QueryStreamBuilder {
stmt,
conn,
stream_builder: |conn, stmt| {
match conn {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
)
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
)
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
)
},
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => {
c.fetch(stmt)
},
}
},
}.build()
}
}
impl Stream for QueryStream {
type Item = Result<QueryResult, DbErr>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| {
stream.as_mut().poll_next(cx)
})
}
}

View File

@ -0,0 +1,82 @@
use std::{ops::DerefMut, pin::Pin, task::Poll};
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use futures::TryStreamExt;
#[cfg(feature = "sqlx-dep")]
use sqlx::Executor;
use futures::lock::MutexGuard;
use crate::{DbErr, InnerConnection, QueryResult, Statement};
#[ouroboros::self_referencing]
pub struct TransactionStream<'a> {
stmt: Statement,
conn: MutexGuard<'a, InnerConnection>,
#[borrows(mut conn, stmt)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
}
impl<'a> std::fmt::Debug for TransactionStream<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "TransactionStream")
}
}
impl<'a> TransactionStream<'a> {
pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> {
TransactionStreamAsyncBuilder {
stmt,
conn,
stream_builder: |conn, stmt| Box::pin(async move {
match conn.deref_mut() {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
) as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
},
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
) as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
},
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err)
) as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
},
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => {
c.fetch(stmt)
},
}
}),
}.build().await
}
}
impl<'a> Stream for TransactionStream<'a> {
type Item = Result<QueryResult, DbErr>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| {
stream.as_mut().poll_next(cx)
})
}
}

View File

@ -2,11 +2,11 @@ use crate::{
debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult,
Statement, Transaction, Statement, Transaction,
}; };
use std::fmt::Debug; use std::{fmt::Debug, pin::Pin, sync::{Arc,
use std::sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Mutex, Mutex,
}; }};
use futures::Stream;
#[derive(Debug)] #[derive(Debug)]
pub struct MockDatabaseConnector; pub struct MockDatabaseConnector;
@ -50,7 +50,7 @@ impl MockDatabaseConnector {
macro_rules! connect_mock_db { macro_rules! connect_mock_db {
( $syntax: expr ) => { ( $syntax: expr ) => {
Ok(DatabaseConnection::MockDatabaseConnection( Ok(DatabaseConnection::MockDatabaseConnection(
MockDatabaseConnection::new(MockDatabase::new($syntax)), Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))),
)) ))
}; };
} }
@ -86,25 +86,32 @@ impl MockDatabaseConnection {
&self.mocker &self.mocker
} }
pub async fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> { pub fn execute(&self, statement: Statement) -> Result<ExecResult, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
self.mocker.lock().unwrap().execute(counter, statement) self.mocker.lock().unwrap().execute(counter, statement)
} }
pub async fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> { pub fn query_one(&self, statement: Statement) -> Result<Option<QueryResult>, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
let result = self.mocker.lock().unwrap().query(counter, statement)?; let result = self.mocker.lock().unwrap().query(counter, statement)?;
Ok(result.into_iter().next()) Ok(result.into_iter().next())
} }
pub async fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> { pub fn query_all(&self, statement: Statement) -> Result<Vec<QueryResult>, DbErr> {
debug_print!("{}", statement); debug_print!("{}", statement);
let counter = self.counter.fetch_add(1, Ordering::SeqCst); let counter = self.counter.fetch_add(1, Ordering::SeqCst);
self.mocker.lock().unwrap().query(counter, statement) self.mocker.lock().unwrap().query(counter, statement)
} }
pub fn fetch(&self, statement: &Statement) -> Pin<Box<dyn Stream<Item=Result<QueryResult, DbErr>>>> {
match self.query_all(statement.clone()) {
Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))),
Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())),
}
}
pub fn get_database_backend(&self) -> DbBackend { pub fn get_database_backend(&self) -> DbBackend {
self.mocker.lock().unwrap().get_database_backend() self.mocker.lock().unwrap().get_database_backend()
} }

View File

@ -3,11 +3,11 @@ mod mock;
#[cfg(feature = "sqlx-dep")] #[cfg(feature = "sqlx-dep")]
mod sqlx_common; mod sqlx_common;
#[cfg(feature = "sqlx-mysql")] #[cfg(feature = "sqlx-mysql")]
mod sqlx_mysql; pub(crate) mod sqlx_mysql;
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
mod sqlx_postgres; pub(crate) mod sqlx_postgres;
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
mod sqlx_sqlite; pub(crate) mod sqlx_sqlite;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
pub use mock::*; pub use mock::*;

View File

@ -1,12 +1,11 @@
use sqlx::{ use std::{future::Future, pin::Pin};
mysql::{MySqlArguments, MySqlQueryResult, MySqlRow},
MySql, MySqlPool, use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}};
};
sea_query::sea_query_driver_mysql!(); sea_query::sea_query_driver_mysql!();
use sea_query_driver_mysql::bind_query; use sea_query_driver_mysql::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,7 +19,7 @@ pub struct SqlxMySqlPoolConnection {
impl SqlxMySqlConnector { impl SqlxMySqlConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::MySql.is_prefix_of(string) string.starts_with("mysql://")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
@ -91,6 +90,44 @@ impl SqlxMySqlPoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>> + Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<MySqlRow> for QueryResult { impl From<MySqlRow> for QueryResult {
@ -109,7 +146,7 @@ impl From<MySqlQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySqlArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,12 +1,11 @@
use sqlx::{ use std::{future::Future, pin::Pin};
postgres::{PgArguments, PgQueryResult, PgRow},
PgPool, Postgres, use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}};
};
sea_query::sea_query_driver_postgres!(); sea_query::sea_query_driver_postgres!();
use sea_query_driver_postgres::bind_query; use sea_query_driver_postgres::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,7 +19,7 @@ pub struct SqlxPostgresPoolConnection {
impl SqlxPostgresConnector { impl SqlxPostgresConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::Postgres.is_prefix_of(string) string.starts_with("postgres://")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
@ -91,6 +90,44 @@ impl SqlxPostgresPoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_postgres(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>> + Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<PgRow> for QueryResult { impl From<PgRow> for QueryResult {
@ -109,7 +146,7 @@ impl From<PgQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,12 +1,11 @@
use sqlx::{ use std::{future::Future, pin::Pin};
sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow},
Sqlite, SqlitePool, use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}};
};
sea_query::sea_query_driver_sqlite!(); sea_query::sea_query_driver_sqlite!();
use sea_query_driver_sqlite::bind_query; use sea_query_driver_sqlite::bind_query;
use crate::{debug_print, error::*, executor::*, DatabaseConnection, DbBackend, Statement}; use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*};
use super::sqlx_common::*; use super::sqlx_common::*;
@ -20,7 +19,7 @@ pub struct SqlxSqlitePoolConnection {
impl SqlxSqliteConnector { impl SqlxSqliteConnector {
pub fn accepts(string: &str) -> bool { pub fn accepts(string: &str) -> bool {
DbBackend::Sqlite.is_prefix_of(string) string.starts_with("sqlite:")
} }
pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> { pub async fn connect(string: &str) -> Result<DatabaseConnection, DbErr> {
@ -91,6 +90,44 @@ impl SqlxSqlitePoolConnection {
)) ))
} }
} }
pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
debug_print!("{}", stmt);
if let Ok(conn) = self.pool.acquire().await {
Ok(QueryStream::from((conn, stmt)))
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_sqlite(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
))
}
}
pub async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>> + Send,
T: Send,
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
} else {
Err(TransactionError::Connection(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
)))
}
}
} }
impl From<SqliteRow> for QueryResult { impl From<SqliteRow> for QueryResult {
@ -109,7 +146,7 @@ impl From<SqliteQueryResult> for ExecResult {
} }
} }
fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> { pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqliteArguments> {
let mut query = sqlx::query(&stmt.sql); let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values { if let Some(values) = &stmt.values {
query = bind_query(query, values); query = bind_query(query, values);

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
error::*, DatabaseConnection, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn, error::*, ConnectionTrait, DeleteResult, EntityTrait, Iterable, PrimaryKeyToColumn,
PrimaryKeyTrait, Value, PrimaryKeyTrait, Value,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -67,9 +67,11 @@ pub trait ActiveModelTrait: Clone + Debug {
fn default() -> Self; fn default() -> Self;
async fn insert(self, db: &DatabaseConnection) -> Result<Self, DbErr> async fn insert<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where where
<Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>, <Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>,
C: ConnectionTrait<'a>,
Self: 'a,
{ {
let am = self; let am = self;
let exec = <Self::Entity as EntityTrait>::insert(am).exec(db); let exec = <Self::Entity as EntityTrait>::insert(am).exec(db);
@ -90,17 +92,22 @@ pub trait ActiveModelTrait: Clone + Debug {
} }
} }
async fn update(self, db: &DatabaseConnection) -> Result<Self, DbErr> { async fn update<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where
C: ConnectionTrait<'a>,
Self: 'a,
{
let exec = Self::Entity::update(self).exec(db); let exec = Self::Entity::update(self).exec(db);
exec.await exec.await
} }
/// Insert the model if primary key is unset, update otherwise. /// Insert the model if primary key is unset, update otherwise.
/// Only works if the entity has auto increment primary key. /// Only works if the entity has auto increment primary key.
async fn save(self, db: &DatabaseConnection) -> Result<Self, DbErr> async fn save<'a, C>(self, db: &'a C) -> Result<Self, DbErr>
where where
Self: ActiveModelBehavior, Self: ActiveModelBehavior + 'a,
<Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>, <Self::Entity as EntityTrait>::Model: IntoActiveModel<Self>,
C: ConnectionTrait<'a>,
{ {
let mut am = self; let mut am = self;
am = ActiveModelBehavior::before_save(am); am = ActiveModelBehavior::before_save(am);
@ -122,9 +129,10 @@ pub trait ActiveModelTrait: Clone + Debug {
} }
/// Delete an active model by its primary key /// Delete an active model by its primary key
async fn delete(self, db: &DatabaseConnection) -> Result<DeleteResult, DbErr> async fn delete<'a, C>(self, db: &'a C) -> Result<DeleteResult, DbErr>
where where
Self: ActiveModelBehavior, Self: ActiveModelBehavior + 'a,
C: ConnectionTrait<'a>,
{ {
let mut am = self; let mut am = self;
am = ActiveModelBehavior::before_delete(am); am = ActiveModelBehavior::before_delete(am);

View File

@ -510,7 +510,7 @@ pub trait EntityTrait: EntityName {
/// ///
/// ``` /// ```
/// # #[cfg(feature = "mock")] /// # #[cfg(feature = "mock")]
/// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend}; /// # use sea_orm::{entity::*, error::*, query::*, tests_cfg::*, MockDatabase, MockExecResult, Transaction, DbBackend};
/// # /// #
/// # let db = MockDatabase::new(DbBackend::Postgres) /// # let db = MockDatabase::new(DbBackend::Postgres)
/// # .append_exec_results(vec![ /// # .append_exec_results(vec![

View File

@ -1,6 +1,4 @@
use crate::{ use crate::{ActiveModelTrait, ConnectionTrait, DeleteMany, DeleteOne, EntityTrait, Statement, error::*};
error::*, ActiveModelTrait, DatabaseConnection, DeleteMany, DeleteOne, EntityTrait, Statement,
};
use sea_query::DeleteStatement; use sea_query::DeleteStatement;
use std::future::Future; use std::future::Future;
@ -18,10 +16,11 @@ impl<'a, A: 'a> DeleteOne<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub fn exec( pub fn exec<C>(
self, self,
db: &'a DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a { ) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a
where C: ConnectionTrait<'a> {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_delete_only(self.query, db) exec_delete_only(self.query, db)
} }
@ -31,10 +30,11 @@ impl<'a, E> DeleteMany<E>
where where
E: EntityTrait, E: EntityTrait,
{ {
pub fn exec( pub fn exec<C>(
self, self,
db: &'a DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a { ) -> impl Future<Output = Result<DeleteResult, DbErr>> + 'a
where C: ConnectionTrait<'a> {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_delete_only(self.query, db) exec_delete_only(self.query, db)
} }
@ -45,24 +45,27 @@ impl Deleter {
Self { query } Self { query }
} }
pub fn exec( pub fn exec<'a, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_ { ) -> impl Future<Output = Result<DeleteResult, DbErr>> + '_
where C: ConnectionTrait<'a> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_delete(builder.build(&self.query), db) exec_delete(builder.build(&self.query), db)
} }
} }
async fn exec_delete_only( async fn exec_delete_only<'a, C>(
query: DeleteStatement, query: DeleteStatement,
db: &DatabaseConnection, db: &'a C,
) -> Result<DeleteResult, DbErr> { ) -> Result<DeleteResult, DbErr>
where C: ConnectionTrait<'a> {
Deleter::new(query).exec(db).await Deleter::new(query).exec(db).await
} }
// Only Statement impl Send // Only Statement impl Send
async fn exec_delete(statement: Statement, db: &DatabaseConnection) -> Result<DeleteResult, DbErr> { async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result<DeleteResult, DbErr>
where C: ConnectionTrait<'a> {
let result = db.execute(statement).await?; let result = db.execute(statement).await?;
Ok(DeleteResult { Ok(DeleteResult {
rows_affected: result.rows_affected(), rows_affected: result.rows_affected(),

View File

@ -1,9 +1,6 @@
use crate::{ use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*};
error::*, ActiveModelTrait, DatabaseConnection, DbBackend, EntityTrait, Insert,
PrimaryKeyTrait, Statement, TryFromU64,
};
use sea_query::InsertStatement; use sea_query::InsertStatement;
use std::{future::Future, marker::PhantomData}; use std::marker::PhantomData;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Inserter<A> pub struct Inserter<A>
@ -27,11 +24,12 @@ where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
#[allow(unused_mut)] #[allow(unused_mut)]
pub fn exec<'a>( pub async fn exec<'a, C>(
self, self,
db: &'a DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a ) -> Result<InsertResult<A>, DbErr>
where where
C: ConnectionTrait<'a>,
A: 'a, A: 'a,
{ {
// TODO: extract primary key's value from query // TODO: extract primary key's value from query
@ -47,7 +45,7 @@ where
); );
} }
} }
Inserter::<A>::new(query).exec(db) Inserter::<A>::new(query).exec(db).await
// TODO: return primary key if extracted before, otherwise use InsertResult // TODO: return primary key if extracted before, otherwise use InsertResult
} }
} }
@ -63,24 +61,26 @@ where
} }
} }
pub fn exec<'a>( pub async fn exec<'a, C>(
self, self,
db: &'a DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a ) -> Result<InsertResult<A>, DbErr>
where where
C: ConnectionTrait<'a>,
A: 'a, A: 'a,
{ {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_insert(builder.build(&self.query), db) exec_insert(builder.build(&self.query), db).await
} }
} }
// Only Statement impl Send // Only Statement impl Send
async fn exec_insert<A>( async fn exec_insert<'a, A, C>(
statement: Statement, statement: Statement,
db: &DatabaseConnection, db: &C,
) -> Result<InsertResult<A>, DbErr> ) -> Result<InsertResult<A>, DbErr>
where where
C: ConnectionTrait<'a>,
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey; type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
@ -93,13 +93,13 @@ where
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let res = db.query_one(statement).await?.unwrap(); let res = db.query_one(statement).await?.unwrap();
res.try_get_many("", cols.as_ref()).unwrap_or_default() res.try_get_many("", cols.as_ref()).unwrap_or_default()
} },
_ => { _ => {
let last_insert_id = db.execute(statement).await?.last_insert_id(); let last_insert_id = db.execute(statement).await?.last_insert_id();
ValueTypeOf::<A>::try_from_u64(last_insert_id) ValueTypeOf::<A>::try_from_u64(last_insert_id)
.ok() .ok()
.unwrap_or_default() .unwrap_or_default()
} },
}; };
Ok(InsertResult { last_insert_id }) Ok(InsertResult { last_insert_id })
} }

View File

@ -1,4 +1,4 @@
use crate::{error::*, DatabaseConnection, DbBackend, SelectorTrait}; use crate::{ConnectionTrait, SelectorTrait, error::*};
use async_stream::stream; use async_stream::stream;
use futures::Stream; use futures::Stream;
use sea_query::{Alias, Expr, SelectStatement}; use sea_query::{Alias, Expr, SelectStatement};
@ -7,21 +7,23 @@ use std::{marker::PhantomData, pin::Pin};
pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>; pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Paginator<'db, S> pub struct Paginator<'db, C, S>
where where
C: ConnectionTrait<'db>,
S: SelectorTrait + 'db, S: SelectorTrait + 'db,
{ {
pub(crate) query: SelectStatement, pub(crate) query: SelectStatement,
pub(crate) page: usize, pub(crate) page: usize,
pub(crate) page_size: usize, pub(crate) page_size: usize,
pub(crate) db: &'db DatabaseConnection, pub(crate) db: &'db C,
pub(crate) selector: PhantomData<S>, pub(crate) selector: PhantomData<S>,
} }
// LINT: warn if paginator is used without an order by clause // LINT: warn if paginator is used without an order by clause
impl<'db, S> Paginator<'db, S> impl<'db, C, S> Paginator<'db, C, S>
where where
C: ConnectionTrait<'db>,
S: SelectorTrait + 'db, S: SelectorTrait + 'db,
{ {
/// Fetch a specific page; page index starts from zero /// Fetch a specific page; page index starts from zero
@ -155,7 +157,7 @@ where
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
mod tests { mod tests {
use crate::entity::prelude::*; use crate::entity::prelude::*;
use crate::tests_cfg::*; use crate::{ConnectionTrait, tests_cfg::*};
use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
use futures::TryStreamExt; use futures::TryStreamExt;
use sea_query::{Alias, Expr, SelectStatement, Value}; use sea_query::{Alias, Expr, SelectStatement, Value};

View File

@ -1,8 +1,8 @@
use crate::{ #[cfg(feature = "sqlx-dep")]
error::*, DatabaseConnection, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, use std::pin::Pin;
ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*};
SelectTwoMany, Statement, TryGetableMany, #[cfg(feature = "sqlx-dep")]
}; use futures::{Stream, TryStreamExt};
use sea_query::SelectStatement; use sea_query::SelectStatement;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -235,23 +235,35 @@ where
Selector::<SelectGetableValue<T, C>>::with_columns(self.query) Selector::<SelectGetableValue<T, C>>::with_columns(self.query)
} }
pub async fn one(self, db: &DatabaseConnection) -> Result<Option<E::Model>, DbErr> { pub async fn one<'a, C>(self, db: &C) -> Result<Option<E::Model>, DbErr>
where C: ConnectionTrait<'a> {
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<E::Model>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<E::Model>, DbErr>
where C: ConnectionTrait<'a> {
self.into_model().all(db).await self.into_model().all(db).await
} }
pub fn paginate( #[cfg(feature = "sqlx-dep")]
pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result<impl Stream<Item=Result<E::Model, DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub fn paginate<'a, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
page_size: usize, page_size: usize,
) -> Paginator<'_, SelectModel<E::Model>> { ) -> Paginator<'a, C, SelectModel<E::Model>>
where C: ConnectionTrait<'a> {
self.into_model().paginate(db, page_size) self.into_model().paginate(db, page_size)
} }
pub async fn count(self, db: &DatabaseConnection) -> Result<usize, DbErr> { pub async fn count<'a, C>(self, db: &'a C) -> Result<usize, DbErr>
where C: ConnectionTrait<'a> {
self.paginate(db, 1).num_items().await self.paginate(db, 1).num_items().await
} }
} }
@ -280,29 +292,41 @@ where
} }
} }
pub async fn one( pub async fn one<'a, C>(
self, self,
db: &DatabaseConnection, db: &C,
) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr> { ) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr>
where C: ConnectionTrait<'a> {
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all( pub async fn all<'a, C>(
self, self,
db: &DatabaseConnection, db: &C,
) -> Result<Vec<(E::Model, Option<F::Model>)>, DbErr> { ) -> Result<Vec<(E::Model, Option<F::Model>)>, DbErr>
where C: ConnectionTrait<'a> {
self.into_model().all(db).await self.into_model().all(db).await
} }
pub fn paginate( #[cfg(feature = "sqlx-dep")]
pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result<impl Stream<Item=Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub fn paginate<'a, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
page_size: usize, page_size: usize,
) -> Paginator<'_, SelectTwoModel<E::Model, F::Model>> { ) -> Paginator<'a, C, SelectTwoModel<E::Model, F::Model>>
where C: ConnectionTrait<'a> {
self.into_model().paginate(db, page_size) self.into_model().paginate(db, page_size)
} }
pub async fn count(self, db: &DatabaseConnection) -> Result<usize, DbErr> { pub async fn count<'a, C>(self, db: &'a C) -> Result<usize, DbErr>
where C: ConnectionTrait<'a> {
self.paginate(db, 1).num_items().await self.paginate(db, 1).num_items().await
} }
} }
@ -331,17 +355,27 @@ where
} }
} }
pub async fn one( pub async fn one<'a, C>(
self, self,
db: &DatabaseConnection, db: &C,
) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr> { ) -> Result<Option<(E::Model, Option<F::Model>)>, DbErr>
where C: ConnectionTrait<'a> {
self.into_model().one(db).await self.into_model().one(db).await
} }
pub async fn all( #[cfg(feature = "sqlx-dep")]
pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result<impl Stream<Item=Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
{
self.into_model().stream(db).await
}
pub async fn all<'a, C>(
self, self,
db: &DatabaseConnection, db: &C,
) -> Result<Vec<(E::Model, Vec<F::Model>)>, DbErr> { ) -> Result<Vec<(E::Model, Vec<F::Model>)>, DbErr>
where C: ConnectionTrait<'a> {
let rows = self.into_model().all(db).await?; let rows = self.into_model().all(db).await?;
Ok(consolidate_query_result::<E, F>(rows)) Ok(consolidate_query_result::<E, F>(rows))
} }
@ -376,7 +410,8 @@ where
} }
} }
pub async fn one(mut self, db: &DatabaseConnection) -> Result<Option<S::Item>, DbErr> { pub async fn one<'a, C>(mut self, db: &C) -> Result<Option<S::Item>, DbErr>
where C: ConnectionTrait<'a> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
self.query.limit(1); self.query.limit(1);
let row = db.query_one(builder.build(&self.query)).await?; let row = db.query_one(builder.build(&self.query)).await?;
@ -386,7 +421,8 @@ where
} }
} }
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<S::Item>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<S::Item>, DbErr>
where C: ConnectionTrait<'a> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
let rows = db.query_all(builder.build(&self.query)).await?; let rows = db.query_all(builder.build(&self.query)).await?;
let mut models = Vec::new(); let mut models = Vec::new();
@ -396,7 +432,21 @@ where
Ok(models) Ok(models)
} }
pub fn paginate(self, db: &DatabaseConnection, page_size: usize) -> Paginator<'_, S> { #[cfg(feature = "sqlx-dep")]
pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result<Pin<Box<dyn Stream<Item=Result<S::Item, DbErr>> + 'b>>, DbErr>
where
C: ConnectionTrait<'a>,
S: 'b,
{
let builder = db.get_database_backend();
let stream = db.stream(builder.build(&self.query)).await?;
Ok(Box::pin(stream.and_then(|row| {
futures::future::ready(S::from_raw_query_result(row))
})))
}
pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S>
where C: ConnectionTrait<'a> {
Paginator { Paginator {
query: self.query, query: self.query,
page: 0, page: 0,
@ -606,7 +656,8 @@ where
/// ),] /// ),]
/// ); /// );
/// ``` /// ```
pub async fn one(self, db: &DatabaseConnection) -> Result<Option<S::Item>, DbErr> { pub async fn one<'a, C>(self, db: &C) -> Result<Option<S::Item>, DbErr>
where C: ConnectionTrait<'a> {
let row = db.query_one(self.stmt).await?; let row = db.query_one(self.stmt).await?;
match row { match row {
Some(row) => Ok(Some(S::from_raw_query_result(row)?)), Some(row) => Ok(Some(S::from_raw_query_result(row)?)),
@ -645,7 +696,8 @@ where
/// ),] /// ),]
/// ); /// );
/// ``` /// ```
pub async fn all(self, db: &DatabaseConnection) -> Result<Vec<S::Item>, DbErr> { pub async fn all<'a, C>(self, db: &C) -> Result<Vec<S::Item>, DbErr>
where C: ConnectionTrait<'a> {
let rows = db.query_all(self.stmt).await?; let rows = db.query_all(self.stmt).await?;
let mut models = Vec::new(); let mut models = Vec::new();
for row in rows.into_iter() { for row in rows.into_iter() {

View File

@ -1,6 +1,4 @@
use crate::{ use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Statement, UpdateMany, UpdateOne, error::*};
error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Statement, UpdateMany, UpdateOne,
};
use sea_query::UpdateStatement; use sea_query::UpdateStatement;
use std::future::Future; use std::future::Future;
@ -18,9 +16,10 @@ impl<'a, A: 'a> UpdateOne<A>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
{ {
pub fn exec(self, db: &'a DatabaseConnection) -> impl Future<Output = Result<A, DbErr>> + 'a { pub async fn exec<'b, C>(self, db: &'b C) -> Result<A, DbErr>
where C: ConnectionTrait<'b> {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_update_and_return_original(self.query, self.model, db) exec_update_and_return_original(self.query, self.model, db).await
} }
} }
@ -28,10 +27,11 @@ impl<'a, E> UpdateMany<E>
where where
E: EntityTrait, E: EntityTrait,
{ {
pub fn exec( pub fn exec<C>(
self, self,
db: &'a DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<UpdateResult, DbErr>> + 'a { ) -> impl Future<Output = Result<UpdateResult, DbErr>> + 'a
where C: ConnectionTrait<'a> {
// so that self is dropped before entering await // so that self is dropped before entering await
exec_update_only(self.query, db) exec_update_only(self.query, db)
} }
@ -42,36 +42,40 @@ impl Updater {
Self { query } Self { query }
} }
pub fn exec( pub async fn exec<'a, C>(
self, self,
db: &DatabaseConnection, db: &'a C,
) -> impl Future<Output = Result<UpdateResult, DbErr>> + '_ { ) -> Result<UpdateResult, DbErr>
where C: ConnectionTrait<'a> {
let builder = db.get_database_backend(); let builder = db.get_database_backend();
exec_update(builder.build(&self.query), db) exec_update(builder.build(&self.query), db).await
} }
} }
async fn exec_update_only( async fn exec_update_only<'a, C>(
query: UpdateStatement, query: UpdateStatement,
db: &DatabaseConnection, db: &'a C,
) -> Result<UpdateResult, DbErr> { ) -> Result<UpdateResult, DbErr>
where C: ConnectionTrait<'a> {
Updater::new(query).exec(db).await Updater::new(query).exec(db).await
} }
async fn exec_update_and_return_original<A>( async fn exec_update_and_return_original<'a, A, C>(
query: UpdateStatement, query: UpdateStatement,
model: A, model: A,
db: &DatabaseConnection, db: &'a C,
) -> Result<A, DbErr> ) -> Result<A, DbErr>
where where
A: ActiveModelTrait, A: ActiveModelTrait,
C: ConnectionTrait<'a>,
{ {
Updater::new(query).exec(db).await?; Updater::new(query).exec(db).await?;
Ok(model) Ok(model)
} }
// Only Statement impl Send // Only Statement impl Send
async fn exec_update(statement: Statement, db: &DatabaseConnection) -> Result<UpdateResult, DbErr> { async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result<UpdateResult, DbErr>
where C: ConnectionTrait<'a> {
let result = db.execute(statement).await?; let result = db.execute(statement).await?;
Ok(UpdateResult { Ok(UpdateResult {
rows_affected: result.rows_affected(), rows_affected: result.rows_affected(),

View File

@ -20,4 +20,4 @@ pub use select::*;
pub use traits::*; pub use traits::*;
pub use update::*; pub use update::*;
pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; pub use crate::{InsertResult, Statement, UpdateResult, Value, Values, ConnectionTrait};

View File

@ -1,6 +1,6 @@
pub mod common; pub mod common;
pub use sea_orm::{entity::*, error::*, sea_query, tests_cfg::*, Database, DbConn}; pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Database, DbConn};
// cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic // cargo test --features sqlx-sqlite,runtime-async-std-native-tls --test basic
#[sea_orm_macros::test] #[sea_orm_macros::test]

View File

@ -1,4 +1,4 @@
use sea_orm::{Database, DatabaseBackend, DatabaseConnection, Statement}; use sea_orm::{Database, DatabaseBackend, DatabaseConnection, ConnectionTrait, Statement};
pub mod schema; pub mod schema;
pub use schema::*; pub use schema::*;

View File

@ -1,6 +1,6 @@
pub use super::super::bakery_chain::*; pub use super::super::bakery_chain::*;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use sea_orm::{error::*, sea_query, DbBackend, DbConn, EntityTrait, ExecResult, Schema}; use sea_orm::{error::*, sea_query, ConnectionTrait, DbBackend, DbConn, EntityTrait, ExecResult, Schema};
use sea_query::{ use sea_query::{
Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement, Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, Table, TableCreateStatement,
}; };

View File

@ -2,7 +2,7 @@ pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext}; pub use common::{bakery_chain::*, setup::*, TestContext};
pub use sea_orm::entity::*; pub use sea_orm::entity::*;
pub use sea_orm::QueryFilter; pub use sea_orm::{QueryFilter, ConnectionTrait};
// Run the test locally: // Run the test locally:
// DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests // DATABASE_URL="mysql://root:@localhost" cargo test --features sqlx-mysql,runtime-async-std --test query_tests

View File

@ -179,7 +179,7 @@ async fn find_baker_least_sales(db: &DatabaseConnection) -> Option<baker::Model>
let mut results: Vec<LeastSalesBakerResult> = select let mut results: Vec<LeastSalesBakerResult> = select
.into_model::<SelectResult>() .into_model::<SelectResult>()
.all(&db) .all(db)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()

37
tests/stream_tests.rs Normal file
View File

@ -0,0 +1,37 @@
pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext};
pub use sea_orm::entity::*;
pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr};
use futures::StreamExt;
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn stream() -> Result<(), DbErr> {
let ctx = TestContext::new("stream").await;
let bakery = bakery::ActiveModel {
name: Set("SeaSide Bakery".to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(&ctx.db)
.await?;
let result = Bakery::find_by_id(bakery.id.clone().unwrap())
.stream(&ctx.db)
.await?
.next()
.await
.unwrap()?;
assert_eq!(result.id, bakery.id.unwrap());
ctx.delete().await;
Ok(())
}

View File

@ -0,0 +1,90 @@
pub mod common;
pub use common::{bakery_chain::*, setup::*, TestContext};
use sea_orm::{DatabaseTransaction, DbErr};
pub use sea_orm::entity::*;
pub use sea_orm::{QueryFilter, ConnectionTrait};
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction() {
let ctx = TestContext::new("transaction_test").await;
ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set("SeaSide Bakery".to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(txn)
.await?;
let _ = bakery::ActiveModel {
name: Set("Top Bakery".to_owned()),
profit_margin: Set(15.0),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains("Bakery"))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 2);
Ok(())
})).await.unwrap();
ctx.delete().await;
}
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
pub async fn transaction_with_reference() {
let ctx = TestContext::new("transaction_with_reference_test").await;
let name1 = "SeaSide Bakery";
let name2 = "Top Bakery";
let search_name = "Bakery";
ctx.db.transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)).await.unwrap();
ctx.delete().await;
}
fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin<Box<dyn std::future::Future<Output=Result<(), DbErr>> + Send + 'a>> {
Box::pin(async move {
let _ = bakery::ActiveModel {
name: Set(name1.to_owned()),
profit_margin: Set(10.4),
..Default::default()
}
.save(txn)
.await?;
let _ = bakery::ActiveModel {
name: Set(name2.to_owned()),
profit_margin: Set(15.0),
..Default::default()
}
.save(txn)
.await?;
let bakeries = Bakery::find()
.filter(bakery::Column::Name.contains(search_name))
.all(txn)
.await?;
assert_eq!(bakeries.len(), 2);
Ok(())
})
}