From d6ca491d80c231826b16c04d9216b08ff32079db Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 9 Nov 2022 15:01:44 +0800 Subject: [PATCH] [CLI] Generate entity Postgres connection with schema search path (#1212) --- sea-orm-cli/src/commands/generate.rs | 34 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/sea-orm-cli/src/commands/generate.rs b/sea-orm-cli/src/commands/generate.rs index b1d0971b..b73fe3c3 100644 --- a/sea-orm-cli/src/commands/generate.rs +++ b/sea-orm-cli/src/commands/generate.rs @@ -115,7 +115,7 @@ pub async fn run_generate_command( use sea_schema::mysql::discovery::SchemaDiscovery; use sqlx::MySql; - let connection = connect::(max_connections, url.as_str()).await?; + let connection = connect::(max_connections, url.as_str(), None).await?; let schema_discovery = SchemaDiscovery::new(connection, database_name); let schema = schema_discovery.discover().await; let table_stmts = schema @@ -132,7 +132,7 @@ pub async fn run_generate_command( use sea_schema::sqlite::discovery::SchemaDiscovery; use sqlx::Sqlite; - let connection = connect::(max_connections, url.as_str()).await?; + let connection = connect::(max_connections, url.as_str(), None).await?; let schema_discovery = SchemaDiscovery::new(connection); let schema = schema_discovery.discover().await?; let table_stmts = schema @@ -150,7 +150,8 @@ pub async fn run_generate_command( use sqlx::Postgres; let schema = &database_schema; - let connection = connect::(max_connections, url.as_str()).await?; + let connection = + connect::(max_connections, url.as_str(), Some(schema)).await?; let schema_discovery = SchemaDiscovery::new(connection, schema); let schema = schema_discovery.discover().await; let table_stmts = schema @@ -198,15 +199,30 @@ pub async fn run_generate_command( Ok(()) } -async fn connect(max_connections: u32, url: &str) -> Result, Box> +async fn connect( + max_connections: u32, + url: &str, + schema: Option<&str>, +) -> Result, Box> where DB: sqlx::Database, + for<'a> &'a mut ::Connection: sqlx::Executor<'a>, { - sqlx::pool::PoolOptions::::new() - .max_connections(max_connections) - .connect(url) - .await - .map_err(Into::into) + let mut pool_options = sqlx::pool::PoolOptions::::new().max_connections(max_connections); + // Set search_path for Postgres, E.g. Some("public") by default + // MySQL & SQLite connection initialize with schema `None` + if let Some(schema) = schema { + let sql = format!("SET search_path = '{}'", schema); + pool_options = pool_options.after_connect(move |conn, _| { + let sql = sql.clone(); + Box::pin(async move { + sqlx::Executor::execute(conn, sql.as_str()) + .await + .map(|_| ()) + }) + }); + } + pool_options.connect(url).await.map_err(Into::into) } impl From for CodegenDateTimeCrate {