diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9e3bb66e..4e1259c3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -28,6 +28,19 @@ jobs: ports: - "3306:3306" options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + postgres: + image: postgres:11 + env: + POSTGRES_HOST: 127.0.0.1 + POSTGRES_USER: root + POSTGRES_PASSWORD: root + ports: + - "5432:5432" + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - uses: actions/checkout@v2 diff --git a/Cargo.toml b/Cargo.toml index 436b5042..bfcca7c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,12 +64,13 @@ maplit = { version = "^1" } rust_decimal_macros = { version = "^1" } sea-orm = { path = ".", features = [ + "sqlx-postgres", + "sqlx-mysql", "sqlx-sqlite", "sqlx-json", "sqlx-chrono", "sqlx-decimal", "runtime-async-std-native-tls", - "sqlx-mysql", ] } [features] diff --git a/src/database/connection.rs b/src/database/connection.rs index ee48ed43..4215d8fc 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -4,6 +4,8 @@ use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQue pub enum DatabaseConnection { #[cfg(feature = "sqlx-mysql")] SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection), #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] @@ -36,6 +38,8 @@ impl std::fmt::Debug for DatabaseConnection { match self { #[cfg(feature = "sqlx-mysql")] Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection", + #[cfg(feature = "sqlx-postgres")] + Self::SqlxPostgresPoolConnection(_) => "SqlxMySqlPoolConnection", #[cfg(feature = "sqlx-sqlite")] Self::SqlxSqlitePoolConnection(_) => "SqlxSqlitePoolConnection", #[cfg(feature = "mock")] @@ -51,6 +55,8 @@ impl DatabaseConnection { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(_) => DbBackend::Sqlite, #[cfg(feature = "mock")] @@ -63,6 +69,8 @@ impl DatabaseConnection { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] @@ -75,6 +83,8 @@ impl DatabaseConnection { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] @@ -87,6 +97,8 @@ impl DatabaseConnection { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] diff --git a/src/database/mock.rs b/src/database/mock.rs index 2ca13d33..e35280cc 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -1,5 +1,5 @@ use crate::{ - error::*, DbBackend, DatabaseConnection, EntityTrait, ExecResult, ExecResultHolder, Iden, + error::*, DatabaseConnection, DbBackend, EntityTrait, ExecResult, ExecResultHolder, Iden, Iterable, MockDatabaseConnection, MockDatabaseTrait, ModelTrait, QueryResult, QueryResultRow, Statement, Transaction, }; diff --git a/src/database/mod.rs b/src/database/mod.rs index a5182a34..36f825bb 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -21,6 +21,10 @@ impl Database { if crate::SqlxMySqlConnector::accepts(string) { return crate::SqlxMySqlConnector::connect(string).await; } + #[cfg(feature = "sqlx-postgres")] + if crate::SqlxPostgresConnector::accepts(string) { + return crate::SqlxPostgresConnector::connect(string).await; + } #[cfg(feature = "sqlx-sqlite")] if crate::SqlxSqliteConnector::accepts(string) { return crate::SqlxSqliteConnector::connect(string).await; diff --git a/src/database/statement.rs b/src/database/statement.rs index 91a1c3c8..12b07487 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -1,7 +1,6 @@ use crate::DbBackend; -use sea_query::{ - inject_parameters, MysqlQueryBuilder, PostgresQueryBuilder, SqliteQueryBuilder, Value, Values -}; +use sea_query::{inject_parameters, MysqlQueryBuilder, PostgresQueryBuilder, SqliteQueryBuilder}; +pub use sea_query::{Value, Values}; use std::fmt; #[derive(Debug, Clone, PartialEq)] diff --git a/src/driver/mock.rs b/src/driver/mock.rs index e16185ee..7de6aaaf 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -1,6 +1,6 @@ use crate::{ - debug_print, error::*, DbBackend, DatabaseConnection, ExecResult, MockDatabase, - QueryResult, Statement, Transaction, + debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, + Statement, Transaction, }; use std::sync::{ atomic::{AtomicUsize, Ordering}, @@ -30,6 +30,10 @@ impl MockDatabaseConnector { if crate::SqlxMySqlConnector::accepts(string) { return true; } + #[cfg(feature = "sqlx-postgres")] + if crate::SqlxPostgresConnector::accepts(string) { + return true; + } #[cfg(feature = "sqlx-sqlite")] if crate::SqlxSqliteConnector::accepts(string) { return true; @@ -50,6 +54,10 @@ impl MockDatabaseConnector { if crate::SqlxMySqlConnector::accepts(string) { return connect_mock_db!(DbBackend::MySql); } + #[cfg(feature = "sqlx-postgres")] + if crate::SqlxPostgresConnector::accepts(string) { + return connect_mock_db!(DbBackend::Postgres); + } #[cfg(feature = "sqlx-sqlite")] if crate::SqlxSqliteConnector::accepts(string) { return connect_mock_db!(DbBackend::Sqlite); diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 904e06a9..6f6cfb64 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -4,6 +4,8 @@ mod mock; mod sqlx_common; #[cfg(feature = "sqlx-mysql")] mod sqlx_mysql; +#[cfg(feature = "sqlx-postgres")] +mod sqlx_postgres; #[cfg(feature = "sqlx-sqlite")] mod sqlx_sqlite; @@ -13,5 +15,7 @@ pub use mock::*; pub use sqlx_common::*; #[cfg(feature = "sqlx-mysql")] pub use sqlx_mysql::*; +#[cfg(feature = "sqlx-postgres")] +pub use sqlx_postgres::*; #[cfg(feature = "sqlx-sqlite")] pub use sqlx_sqlite::*; diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs new file mode 100644 index 00000000..a283ddf3 --- /dev/null +++ b/src/driver/sqlx_postgres.rs @@ -0,0 +1,129 @@ +use sqlx::{ + postgres::{PgArguments, PgQueryResult, PgRow}, + PgPool, Postgres, +}; + +sea_query::sea_query_driver_postgres!(); +use sea_query_driver_postgres::bind_query; + +use crate::{debug_print, error::*, executor::*, DatabaseConnection, Statement}; + +use super::sqlx_common::*; + +pub struct SqlxPostgresConnector; + +pub struct SqlxPostgresPoolConnection { + pool: PgPool, +} + +impl SqlxPostgresConnector { + pub fn accepts(string: &str) -> bool { + string.starts_with("postgres://") + } + + pub async fn connect(string: &str) -> Result { + if let Ok(pool) = PgPool::connect(string).await { + Ok(DatabaseConnection::SqlxPostgresPoolConnection( + SqlxPostgresPoolConnection { pool }, + )) + } else { + Err(DbErr::Conn("Failed to connect.".to_owned())) + } + } +} + +impl SqlxPostgresConnector { + pub fn from_sqlx_postgres_pool(pool: PgPool) -> DatabaseConnection { + DatabaseConnection::SqlxPostgresPoolConnection(SqlxPostgresPoolConnection { pool }) + } +} + +impl SqlxPostgresPoolConnection { + pub async fn execute(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + let query = sqlx_query(&stmt); + if let Ok(conn) = &mut self.pool.acquire().await { + match query.execute(conn).await { + Ok(res) => Ok(res.into()), + Err(err) => Err(sqlx_error_to_exec_err(err)), + } + } else { + Err(DbErr::Exec( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let query = sqlx_query(&stmt); + if let Ok(conn) = &mut self.pool.acquire().await { + match query.fetch_one(conn).await { + Ok(row) => Ok(Some(row.into())), + Err(err) => match err { + sqlx::Error::RowNotFound => Ok(None), + _ => Err(DbErr::Query(err.to_string())), + }, + } + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + debug_print!("{}", stmt); + + let query = sqlx_query(&stmt); + if let Ok(conn) = &mut self.pool.acquire().await { + match query.fetch_all(conn).await { + Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()), + Err(err) => Err(sqlx_error_to_query_err(err)), + } + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } +} + +impl From for QueryResult { + fn from(row: PgRow) -> QueryResult { + QueryResult { + row: QueryResultRow::SqlxPostgres(row), + } + } +} + +impl From for ExecResult { + fn from(result: PgQueryResult) -> ExecResult { + ExecResult { + result: ExecResultHolder::SqlxPostgres { + last_insert_id: 0, + rows_affected: result.rows_affected(), + }, + } + } +} + +pub(crate) fn query_result_into_exec_result(res: QueryResult) -> Result { + let last_insert_id: i32 = res.try_get("", "last_insert_id")?; + Ok(ExecResult { + result: ExecResultHolder::SqlxPostgres { + last_insert_id: last_insert_id as u64, + rows_affected: 0, + }, + }) +} + +fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> { + let mut query = sqlx::query(&stmt.sql); + if let Some(values) = &stmt.values { + query = bind_query(query, values); + } + query +} diff --git a/src/executor/execute.rs b/src/executor/execute.rs index b795466a..00375bb7 100644 --- a/src/executor/execute.rs +++ b/src/executor/execute.rs @@ -7,6 +7,11 @@ pub struct ExecResult { pub(crate) enum ExecResultHolder { #[cfg(feature = "sqlx-mysql")] SqlxMySql(sqlx::mysql::MySqlQueryResult), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgres { + last_insert_id: u64, + rows_affected: u64, + }, #[cfg(feature = "sqlx-sqlite")] SqlxSqlite(sqlx::sqlite::SqliteQueryResult), #[cfg(feature = "mock")] @@ -20,6 +25,8 @@ impl ExecResult { match &self.result { #[cfg(feature = "sqlx-mysql")] ExecResultHolder::SqlxMySql(result) => result.last_insert_id(), + #[cfg(feature = "sqlx-postgres")] + ExecResultHolder::SqlxPostgres { last_insert_id, .. } => last_insert_id.to_owned(), #[cfg(feature = "sqlx-sqlite")] ExecResultHolder::SqlxSqlite(result) => { let last_insert_rowid = result.last_insert_rowid(); @@ -38,6 +45,8 @@ impl ExecResult { match &self.result { #[cfg(feature = "sqlx-mysql")] ExecResultHolder::SqlxMySql(result) => result.rows_affected(), + #[cfg(feature = "sqlx-postgres")] + ExecResultHolder::SqlxPostgres { rows_affected, .. } => rows_affected.to_owned(), #[cfg(feature = "sqlx-sqlite")] ExecResultHolder::SqlxSqlite(result) => result.rows_affected(), #[cfg(feature = "mock")] diff --git a/src/executor/insert.rs b/src/executor/insert.rs index b7fd1163..f8d60137 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,4 +1,6 @@ -use crate::{error::*, ActiveModelTrait, DatabaseConnection, Insert, QueryTrait, Statement}; +use crate::{ + error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, Iterable, Statement, +}; use sea_query::InsertStatement; use std::future::Future; @@ -21,7 +23,19 @@ where db: &DatabaseConnection, ) -> impl Future> + '_ { // so that self is dropped before entering await - Inserter::new(self.into_query()).exec(db) + let mut query = self.query; + #[cfg(feature = "sqlx-postgres")] + if let DatabaseConnection::SqlxPostgresPoolConnection(_) = db { + use sea_query::{Alias, Expr, Query}; + for key in ::PrimaryKey::iter() { + query.returning( + Query::select() + .expr_as(Expr::col(key), Alias::new("last_insert_id")) + .to_owned(), + ); + } + } + Inserter::new(query).exec(db) } } @@ -41,8 +55,15 @@ impl Inserter { // Only Statement impl Send async fn exec_insert(statement: Statement, db: &DatabaseConnection) -> Result { - let result = db.execute(statement).await?; // TODO: Postgres instead use query_one + returning clause + let result = match db { + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => { + let res = conn.query_one(statement).await?.unwrap(); + crate::query_result_into_exec_result(res)? + } + _ => db.execute(statement).await?, + }; Ok(InsertResult { last_insert_id: result.last_insert_id(), }) diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 3698456a..a5aa3445 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -97,7 +97,7 @@ where /// let mut cake_pages = cake::Entity::find() /// .order_by_asc(cake::Column::Id) /// .paginate(db, 50); - /// + /// /// while let Some(cakes) = cake_pages.fetch_and_next().await? { /// // Do something on cakes: Vec /// } @@ -127,7 +127,7 @@ where /// .order_by_asc(cake::Column::Id) /// .paginate(db, 50) /// .into_stream(); - /// + /// /// while let Some(cakes) = cake_stream.try_next().await? { /// // Do something on cakes: Vec /// } @@ -153,7 +153,7 @@ where mod tests { use crate::entity::prelude::*; use crate::tests_cfg::*; - use crate::{DbBackend, DatabaseConnection, MockDatabase, Transaction}; + use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; use sea_query::{Alias, Expr, SelectStatement, Value}; diff --git a/src/executor/query.rs b/src/executor/query.rs index aad4c3fa..06e8ec7b 100644 --- a/src/executor/query.rs +++ b/src/executor/query.rs @@ -11,6 +11,8 @@ pub struct QueryResult { pub(crate) enum QueryResultRow { #[cfg(feature = "sqlx-mysql")] SqlxMySql(sqlx::mysql::MySqlRow), + #[cfg(feature = "sqlx-postgres")] + SqlxPostgres(sqlx::postgres::PgRow), #[cfg(feature = "sqlx-sqlite")] SqlxSqlite(sqlx::sqlite::SqliteRow), #[cfg(feature = "mock")] @@ -39,6 +41,8 @@ impl fmt::Debug for QueryResultRow { match self { #[cfg(feature = "sqlx-mysql")] Self::SqlxMySql(row) => write!(f, "{:?}", row), + #[cfg(feature = "sqlx-postgres")] + Self::SqlxPostgres(_) => panic!("QueryResultRow::SqlxPostgres cannot be inspected"), #[cfg(feature = "sqlx-sqlite")] Self::SqlxSqlite(_) => panic!("QueryResultRow::SqlxSqlite cannot be inspected"), #[cfg(feature = "mock")] @@ -61,6 +65,12 @@ macro_rules! try_getable_all { row.try_get(column.as_str()) .map_err(crate::sqlx_error_to_query_err) } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use sqlx::Row; + row.try_get(column.as_str()) + .map_err(crate::sqlx_error_to_query_err) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use sqlx::Row; @@ -85,6 +95,75 @@ macro_rules! try_getable_all { Err(_) => Ok(None), } } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use sqlx::Row; + match row.try_get(column.as_str()) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } + } + #[cfg(feature = "sqlx-sqlite")] + QueryResultRow::SqlxSqlite(row) => { + use sqlx::Row; + match row.try_get(column.as_str()) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } + } + #[cfg(feature = "mock")] + QueryResultRow::Mock(row) => match row.try_get(column.as_str()) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + }, + } + } + } + }; +} + +macro_rules! try_getable_unsigned { + ( $type: ty ) => { + impl TryGetable for $type { + fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { + let column = format!("{}{}", pre, col); + match &res.row { + #[cfg(feature = "sqlx-mysql")] + QueryResultRow::SqlxMySql(row) => { + use sqlx::Row; + row.try_get(column.as_str()) + .map_err(crate::sqlx_error_to_query_err) + } + QueryResultRow::SqlxPostgres(_) => { + panic!("{} unsupported by sqlx-postgres", stringify!($type)) + } + #[cfg(feature = "sqlx-sqlite")] + QueryResultRow::SqlxSqlite(row) => { + use sqlx::Row; + row.try_get(column.as_str()) + .map_err(crate::sqlx_error_to_query_err) + } + #[cfg(feature = "mock")] + QueryResultRow::Mock(row) => Ok(row.try_get(column.as_str())?), + } + } + } + + impl TryGetable for Option<$type> { + fn try_get(res: &QueryResult, pre: &str, col: &str) -> Result { + let column = format!("{}{}", pre, col); + match &res.row { + #[cfg(feature = "sqlx-mysql")] + QueryResultRow::SqlxMySql(row) => { + use sqlx::Row; + match row.try_get(column.as_str()) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } + } + QueryResultRow::SqlxPostgres(_) => { + panic!("{} unsupported by sqlx-postgres", stringify!($type)) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use sqlx::Row; @@ -116,6 +195,10 @@ macro_rules! try_getable_mysql { row.try_get(column.as_str()) .map_err(crate::sqlx_error_to_query_err) } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(_) => { + panic!("{} unsupported by sqlx-postgres", stringify!($type)) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(_) => { panic!("{} unsupported by sqlx-sqlite", stringify!($type)) @@ -138,6 +221,10 @@ macro_rules! try_getable_mysql { Err(_) => Ok(None), } } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(_) => { + panic!("{} unsupported by sqlx-sqlite", stringify!($type)) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(_) => { panic!("{} unsupported by sqlx-sqlite", stringify!($type)) @@ -158,8 +245,8 @@ try_getable_all!(i8); try_getable_all!(i16); try_getable_all!(i32); try_getable_all!(i64); -try_getable_all!(u8); -try_getable_all!(u16); +try_getable_unsigned!(u8); +try_getable_unsigned!(u16); try_getable_all!(u32); try_getable_mysql!(u64); try_getable_all!(f32); @@ -188,6 +275,12 @@ impl TryGetable for Decimal { row.try_get(column.as_str()) .map_err(crate::sqlx_error_to_query_err) } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use sqlx::Row; + row.try_get(column.as_str()) + .map_err(crate::sqlx_error_to_query_err) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use sqlx::Row; @@ -217,6 +310,14 @@ impl TryGetable for Option { Err(_) => Ok(None), } } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use sqlx::Row; + match row.try_get(column.as_str()) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(_) => { let result: Result = TryGetable::try_get(res, pre, col); diff --git a/src/executor/select.rs b/src/executor/select.rs index ffc94d64..647b6c95 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -80,32 +80,17 @@ where /// # #[cfg(feature = "mock")] /// # use sea_orm::{error::*, tests_cfg::*, MockDatabase, Transaction, DbBackend}; /// # - /// # let db = MockDatabase::new(DbBackend::Postgres) - /// # .append_query_results(vec![ - /// # vec![ - /// # cake::Model { - /// # id: 1, - /// # name: "New York Cheese".to_owned(), - /// # }, - /// # ], - /// # ]) - /// # .into_connection(); + /// # let db = MockDatabase::new(DbBackend::Postgres).into_connection(); /// # /// use sea_orm::{entity::*, query::*, tests_cfg::cake}; /// /// # let _: Result<(), DbErr> = async_std::task::block_on(async { /// # - /// assert_eq!( - /// cake::Entity::find().from_raw_sql( - /// Statement::from_sql_and_values( - /// DbBackend::Postgres, r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, vec![] - /// ) - /// ).one(&db).await?, - /// Some(cake::Model { - /// id: 1, - /// name: "New York Cheese".to_owned(), - /// }) - /// ); + /// let cheese: Option = cake::Entity::find().from_raw_sql( + /// Statement::from_sql_and_values( + /// DbBackend::Postgres, r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "id" = $1"#, vec![1.into()] + /// ) + /// ).one(&db).await?; /// # /// # Ok(()) /// # }); @@ -114,7 +99,7 @@ where /// db.into_transaction_log(), /// vec![ /// Transaction::from_sql_and_values( - /// DbBackend::Postgres, r#"SELECT "cake"."id", "cake"."name" FROM "cake""#, vec![] + /// DbBackend::Postgres, r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "id" = $1"#, vec![1.into()] /// ), /// ]); /// ``` diff --git a/src/query/json.rs b/src/query/json.rs index 1d372308..95538032 100644 --- a/src/query/json.rs +++ b/src/query/json.rs @@ -43,6 +43,44 @@ impl FromQueryResult for JsonValue { } Ok(JsonValue::Object(map)) } + #[cfg(feature = "sqlx-postgres")] + QueryResultRow::SqlxPostgres(row) => { + use serde_json::json; + use sqlx::{Column, Postgres, Row, Type}; + let mut map = Map::new(); + for column in row.columns() { + let col = if !column.name().starts_with(pre) { + continue; + } else { + column.name().replacen(pre, "", 1) + }; + let col_type = column.type_info(); + macro_rules! match_postgres_type { + ( $type: ty ) => { + if <$type as Type>::type_info().eq(col_type) { + map.insert( + col.to_owned(), + json!(res.try_get::>(pre, &col)?), + ); + continue; + } + }; + } + match_postgres_type!(bool); + match_postgres_type!(i8); + match_postgres_type!(i16); + match_postgres_type!(i32); + match_postgres_type!(i64); + // match_postgres_type!(u8); // unsupported by SQLx Postgres + // match_postgres_type!(u16); // unsupported by SQLx Postgres + match_postgres_type!(u32); + // match_postgres_type!(u64); // unsupported by SQLx Postgres + match_postgres_type!(f32); + match_postgres_type!(f64); + match_postgres_type!(String); + } + Ok(JsonValue::Object(map)) + } #[cfg(feature = "sqlx-sqlite")] QueryResultRow::SqlxSqlite(row) => { use serde_json::json; diff --git a/src/query/mod.rs b/src/query/mod.rs index e41b8dc2..899882ba 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -20,4 +20,4 @@ pub use select::*; pub use traits::*; pub use update::*; -pub use crate::{Statement, InsertResult, UpdateResult}; \ No newline at end of file +pub use crate::{InsertResult, Statement, UpdateResult, Value, Values}; diff --git a/tests/pg_tests.rs b/tests/pg_tests.rs new file mode 100644 index 00000000..f3ccc471 --- /dev/null +++ b/tests/pg_tests.rs @@ -0,0 +1,99 @@ +use sea_orm::{ + entity::prelude::*, Database, DatabaseBackend, DatabaseConnection, DbErr, ExecResult, Set, + Statement, +}; + +pub mod bakery_chain; +pub use bakery_chain::*; +use sea_query::{ColumnDef, TableCreateStatement}; + +// cargo test --test pg_tests -- --nocapture +#[async_std::test] +async fn main() { + let base_url = "postgres://root:root@localhost"; + let db_name = "bakery_chain_schema_crud_tests"; + + let db = setup(base_url, db_name).await; + setup_schema(&db).await; + create_entities(&db).await; +} + +pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { + let url = format!("{}/postgres", base_url); + let db = Database::connect(&url).await.unwrap(); + + let _drop_db_result = db + .execute(Statement::from_string( + DatabaseBackend::Postgres, + format!("DROP DATABASE IF EXISTS \"{}\";", db_name), + )) + .await + .unwrap(); + + let _create_db_result = db + .execute(Statement::from_string( + DatabaseBackend::Postgres, + format!("CREATE DATABASE \"{}\";", db_name), + )) + .await + .unwrap(); + + let url = format!("{}/{}", base_url, db_name); + Database::connect(&url).await.unwrap() +} + +async fn setup_schema(db: &DatabaseConnection) { + assert!(create_bakery_table(db).await.is_ok()); +} + +async fn create_table( + db: &DatabaseConnection, + stmt: &TableCreateStatement, +) -> Result { + let builder = db.get_database_backend(); + db.execute(builder.build(stmt)).await +} + +pub async fn create_bakery_table(db: &DatabaseConnection) -> Result { + let stmt = sea_query::Table::create() + .table(bakery::Entity) + .if_not_exists() + .col( + ColumnDef::new(bakery::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col(ColumnDef::new(bakery::Column::Name).string()) + .col(ColumnDef::new(bakery::Column::ProfitMargin).double()) + .to_owned(); + + create_table(db, &stmt).await +} + +async fn create_entities(db: &DatabaseConnection) { + test_create_bakery(db).await; +} + +pub async fn test_create_bakery(db: &DatabaseConnection) { + let seaside_bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + }; + let res = Bakery::insert(seaside_bakery) + .exec(db) + .await + .expect("could not insert bakery"); + + let bakery = Bakery::find_by_id(res.last_insert_id) + .one(db) + .await + .expect("could not find bakery"); + + assert!(bakery.is_some()); + let bakery_model = bakery.unwrap(); + assert_eq!(bakery_model.name, "SeaSide Bakery"); + assert_eq!(bakery_model.profit_margin, 10.4); +}