diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 2f78bdd7..476fe8e0 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -228,9 +228,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { true } #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => { + DatabaseConnection::SqlxSqlitePoolConnection(conn) => { // Supported by SQLite on or after version 3.35.0 (2021-03-12) - false + conn.support_returning } #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { @@ -255,9 +255,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { true } #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => { + DatabaseConnection::SqlxSqlitePoolConnection(conn) => { // Supported by SQLite on or after version 3.35.0 (2021-03-12) - false + conn.support_returning } #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() { @@ -301,7 +301,7 @@ impl DatabaseConnection { #[cfg(feature = "sqlx-postgres")] DatabaseConnection::SqlxPostgresPoolConnection(_) => "".to_string(), #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(_) => "".to_string(), + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.version.to_string(), DatabaseConnection::Disconnected => panic!("Disconnected"), _ => unimplemented!(), } diff --git a/src/database/transaction.rs b/src/database/transaction.rs index 71573f0d..168461c4 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -54,11 +54,12 @@ impl DatabaseTransaction { #[cfg(feature = "sqlx-sqlite")] pub(crate) async fn new_sqlite( inner: PoolConnection, + support_returning: bool, ) -> Result { Self::begin( Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite, - false, + support_returning, ) .await } @@ -378,7 +379,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { } DbBackend::Sqlite => { // Supported by SQLite on or after version 3.35.0 (2021-03-12) - false + self.support_returning } } } @@ -395,7 +396,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { } DbBackend::Sqlite => { // Supported by SQLite on or after version 3.35.0 (2021-03-12) - false + self.support_returning } } } diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 69eee575..c98a3dea 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,8 +1,9 @@ +use regex::Regex; use std::{future::Future, pin::Pin}; use sqlx::{ sqlite::{SqliteArguments, SqliteConnectOptions, SqliteQueryResult, SqliteRow}, - Sqlite, SqlitePool, + Row, Sqlite, SqlitePool, }; sea_query::sea_query_driver_sqlite!(); @@ -10,7 +11,7 @@ use sea_query_driver_sqlite::bind_query; use crate::{ debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, - QueryStream, Statement, TransactionError, + DbBackend, QueryStream, Statement, TransactionError, }; use super::sqlx_common::*; @@ -23,6 +24,8 @@ pub struct SqlxSqliteConnector; #[derive(Debug, Clone)] pub struct SqlxSqlitePoolConnection { pool: SqlitePool, + pub(crate) version: String, + pub(crate) support_returning: bool, } impl SqlxSqliteConnector { @@ -46,9 +49,7 @@ impl SqlxSqliteConnector { options.max_connections(1); } if let Ok(pool) = options.pool_options().connect_with(opt).await { - Ok(DatabaseConnection::SqlxSqlitePoolConnection( - SqlxSqlitePoolConnection { pool }, - )) + into_db_connection(pool).await } else { Err(DbErr::Conn("Failed to connect.".to_owned())) } @@ -57,8 +58,8 @@ impl SqlxSqliteConnector { impl SqlxSqliteConnector { /// Instantiate a sqlx pool connection to a [DatabaseConnection] - pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection { - DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection { pool }) + pub async fn from_sqlx_sqlite_pool(pool: SqlitePool) -> Result { + into_db_connection(pool).await } } @@ -133,7 +134,7 @@ impl SqlxSqlitePoolConnection { /// Bundle a set of SQL statements that execute together. pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { - DatabaseTransaction::new_sqlite(conn).await + DatabaseTransaction::new_sqlite(conn, self.support_returning).await } else { Err(DbErr::Query( "Failed to acquire connection from pool.".to_owned(), @@ -152,7 +153,7 @@ impl SqlxSqlitePoolConnection { E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { - let transaction = DatabaseTransaction::new_sqlite(conn) + let transaction = DatabaseTransaction::new_sqlite(conn, self.support_returning) .await .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await @@ -187,3 +188,44 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, Sql } query } + +async fn into_db_connection(pool: SqlitePool) -> Result { + let (version, support_returning) = parse_support_returning(&pool).await?; + Ok(DatabaseConnection::SqlxSqlitePoolConnection( + SqlxSqlitePoolConnection { + pool, + version, + support_returning, + }, + )) +} + +async fn parse_support_returning(pool: &SqlitePool) -> Result<(String, bool), DbErr> { + let stmt = Statement::from_string( + DbBackend::Sqlite, + r#"SELECT sqlite_version() AS version"#.to_owned(), + ); + let query = sqlx_query(&stmt); + let row = query + .fetch_one(pool) + .await + .map_err(sqlx_error_to_query_err)?; + let version: String = row.try_get("version").map_err(sqlx_error_to_query_err)?; + let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap(); + let captures = regex.captures(&version).unwrap(); + macro_rules! parse_captures { + ( $idx: expr ) => { + captures.get($idx).map_or(0, |m| { + m.as_str() + .parse::() + .map_err(|e| DbErr::Conn(e.to_string())) + .unwrap() + }) + }; + } + let ver_major = parse_captures!(1); + let ver_minor = parse_captures!(2); + // Supported if it's version 3.35.0 (2021-03-12) or after + let support_returning = ver_major >= 3 && ver_minor >= 35; + Ok((version, support_returning)) +} diff --git a/tests/returning_tests.rs b/tests/returning_tests.rs index 4f2a5a88..94e8a399 100644 --- a/tests/returning_tests.rs +++ b/tests/returning_tests.rs @@ -34,17 +34,12 @@ async fn main() -> Result<(), DbErr> { let mut returning = Query::select(); returning.columns(vec![Column::Id, Column::Name, Column::ProfitMargin]); - if db.returning_on_insert() { - insert.returning(returning.clone()); - } - if db.returning_on_update() { - update.returning(returning.clone()); - } create_tables(db).await?; println!("db_version: {:#?}", db.version()); if db.returning_on_insert() { + insert.returning(returning.clone()); let insert_res = db .query_one(builder.build(&insert)) .await? @@ -57,6 +52,7 @@ async fn main() -> Result<(), DbErr> { assert!(insert_res.rows_affected() > 0); } if db.returning_on_update() { + update.returning(returning.clone()); let update_res = db .query_one(builder.build(&update)) .await?