Returning support for SQLite

This commit is contained in:
Billy Chan 2021-11-09 16:10:52 +08:00
parent 533c3cf175
commit ec637b26a0
No known key found for this signature in database
GPG Key ID: A2D690CAC7DF3CC7
4 changed files with 62 additions and 23 deletions

View File

@ -228,9 +228,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
true true
} }
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => { DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12) // Supported by SQLite on or after version 3.35.0 (2021-03-12)
false conn.support_returning
} }
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
@ -255,9 +255,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
true true
} }
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => { DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12) // Supported by SQLite on or after version 3.35.0 (2021-03-12)
false conn.support_returning
} }
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
@ -301,7 +301,7 @@ impl DatabaseConnection {
#[cfg(feature = "sqlx-postgres")] #[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(),
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => "".to_string(), DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.version.to_string(),
DatabaseConnection::Disconnected => panic!("Disconnected"), DatabaseConnection::Disconnected => panic!("Disconnected"),
_ => unimplemented!(), _ => unimplemented!(),
} }

View File

@ -54,11 +54,12 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-sqlite")] #[cfg(feature = "sqlx-sqlite")]
pub(crate) async fn new_sqlite( pub(crate) async fn new_sqlite(
inner: PoolConnection<sqlx::Sqlite>, inner: PoolConnection<sqlx::Sqlite>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> { ) -> Result<DatabaseTransaction, DbErr> {
Self::begin( Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite, DbBackend::Sqlite,
false, support_returning,
) )
.await .await
} }
@ -378,7 +379,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
} }
DbBackend::Sqlite => { DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12) // Supported by SQLite on or after version 3.35.0 (2021-03-12)
false self.support_returning
} }
} }
} }
@ -395,7 +396,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
} }
DbBackend::Sqlite => { DbBackend::Sqlite => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12) // Supported by SQLite on or after version 3.35.0 (2021-03-12)
false self.support_returning
} }
} }
} }

View File

@ -1,8 +1,9 @@
use regex::Regex;
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use sqlx::{ use sqlx::{
sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow}, sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow},
Sqlite, SqlitePool, Row, Sqlite, SqlitePool,
}; };
sea_query::sea_query_driver_sqlite!(); sea_query::sea_query_driver_sqlite!();
@ -10,7 +11,7 @@ use sea_query_driver_sqlite::bind_query;
use crate::{ use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
QueryStream, Statement, TransactionError, DbBackend, QueryStream, Statement, TransactionError,
}; };
use super::sqlx_common::*; use super::sqlx_common::*;
@ -23,6 +24,8 @@ pub struct SqlxSqliteConnector;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SqlxSqlitePoolConnection { pub struct SqlxSqlitePoolConnection {
pool: SqlitePool, pool: SqlitePool,
pub(crate) version: String,
pub(crate) support_returning: bool,
} }
impl SqlxSqliteConnector { impl SqlxSqliteConnector {
@ -46,9 +49,7 @@ impl SqlxSqliteConnector {
options.max_connections(1); options.max_connections(1);
} }
if let Ok(pool) = options.pool_options().connect_with(opt).await { if let Ok(pool) = options.pool_options().connect_with(opt).await {
Ok(DatabaseConnection::SqlxSqlitePoolConnection( into_db_connection(pool).await
SqlxSqlitePoolConnection { pool },
))
} else { } else {
Err(DbErr::Conn("Failed to connect.".to_owned())) Err(DbErr::Conn("Failed to connect.".to_owned()))
} }
@ -57,8 +58,8 @@ impl SqlxSqliteConnector {
impl SqlxSqliteConnector { impl SqlxSqliteConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection] /// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection { pub async fn from_sqlx_sqlite_pool(pool: SqlitePool) -> Result<DatabaseConnection, DbErr> {
DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool }) into_db_connection(pool).await
} }
} }
@ -133,7 +134,7 @@ impl SqlxSqlitePoolConnection {
/// Bundle a set of SQL statements that execute together. /// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> { pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await { if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_sqlite(conn).await DatabaseTransaction::new_sqlite(conn, self.support_returning).await
} else { } else {
Err(DbErr::Query( Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(), "Failed to acquire connection from pool.".to_owned(),
@ -152,7 +153,7 @@ impl SqlxSqlitePoolConnection {
E: std::error::Error + Send, E: std::error::Error + Send,
{ {
if let Ok(conn) = self.pool.acquire().await { if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_sqlite(conn) let transaction = DatabaseTransaction::new_sqlite(conn, self.support_returning)
.await .await
.map_err(|e| TransactionError::Connection(e))?; .map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await transaction.run(callback).await
@ -187,3 +188,44 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, Sql
} }
query query
} }
async fn into_db_connection(pool: SqlitePool) -> Result<DatabaseConnection, DbErr> {
let (version, support_returning) = parse_support_returning(&pool).await?;
Ok(DatabaseConnection::SqlxSqlitePoolConnection(
SqlxSqlitePoolConnection {
pool,
version,
support_returning,
},
))
}
async fn parse_support_returning(pool: &SqlitePool) -> Result<(String, bool), DbErr> {
let stmt = Statement::from_string(
DbBackend::Sqlite,
r#"SELECT sqlite_version() AS version"#.to_owned(),
);
let query = sqlx_query(&stmt);
let row = query
.fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("version").map_err(sqlx_error_to_query_err)?;
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
// Supported if it's version 3.35.0 (2021-03-12) or after
let support_returning = ver_major >= 3 && ver_minor >= 35;
Ok((version, support_returning))
}

View File

@ -34,17 +34,12 @@ async fn main() -> Result<(), DbErr> {
let mut returning = Query::select(); let mut returning = Query::select();
returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]);
if db.returning_on_insert() {
insert.returning(returning.clone());
}
if db.returning_on_update() {
update.returning(returning.clone());
}
create_tables(db).await?; create_tables(db).await?;
println!("db_version: {:#?}", db.version()); println!("db_version: {:#?}", db.version());
if db.returning_on_insert() { if db.returning_on_insert() {
insert.returning(returning.clone());
let insert_res = db let insert_res = db
.query_one(builder.build(&insert)) .query_one(builder.build(&insert))
.await? .await?
@ -57,6 +52,7 @@ async fn main() -> Result<(), DbErr> {
assert!(insert_res.rows_affected() > 0); assert!(insert_res.rows_affected() > 0);
} }
if db.returning_on_update() { if db.returning_on_update() {
update.returning(returning.clone());
let update_res = db let update_res = db
.query_one(builder.build(&update)) .query_one(builder.build(&update))
.await? .await?