diff --git a/sea-orm-cli/src/bin/main.rs b/sea-orm-cli/src/bin/main.rs index e9d40401..e847d003 100644 --- a/sea-orm-cli/src/bin/main.rs +++ b/sea-orm-cli/src/bin/main.rs @@ -17,8 +17,16 @@ async fn main() { } Commands::Migrate { migration_dir, + database_schema, + database_url, command, - } => run_migrate_command(command, migration_dir.as_str(), verbose) - .unwrap_or_else(handle_error), + } => run_migrate_command( + command, + &migration_dir, + database_schema, + database_url, + verbose, + ) + .unwrap_or_else(handle_error), } } diff --git a/sea-orm-cli/src/bin/sea.rs b/sea-orm-cli/src/bin/sea.rs index 4a8b3d14..edf15c83 100644 --- a/sea-orm-cli/src/bin/sea.rs +++ b/sea-orm-cli/src/bin/sea.rs @@ -19,8 +19,16 @@ async fn main() { } Commands::Migrate { migration_dir, + database_schema, + database_url, command, - } => run_migrate_command(command, migration_dir.as_str(), verbose) - .unwrap_or_else(handle_error), + } => run_migrate_command( + command, + &migration_dir, + database_schema, + database_url, + verbose, + ) + .unwrap_or_else(handle_error), } } diff --git a/sea-orm-cli/src/cli.rs b/sea-orm-cli/src/cli.rs index 827a17e1..ebbbf6e9 100644 --- a/sea-orm-cli/src/cli.rs +++ b/sea-orm-cli/src/cli.rs @@ -34,6 +34,28 @@ you should provide the directory of that submodule.", )] migration_dir: String, + #[clap( + value_parser, + global = true, + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL and SQLite, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'.\n" + )] + database_schema: Option, + + #[clap( + value_parser, + global = true, + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL" + )] + database_url: Option, + #[clap(subcommand)] command: Option, }, diff --git a/sea-orm-cli/src/commands/migrate.rs b/sea-orm-cli/src/commands/migrate.rs index ac612a59..1ab5f33a 100644 --- a/sea-orm-cli/src/commands/migrate.rs +++ b/sea-orm-cli/src/commands/migrate.rs @@ -13,6 +13,8 @@ use crate::MigrateSubcommands; pub fn run_migrate_command( command: Option, migration_dir: &str, + database_schema: Option, + database_url: Option, verbose: bool, ) -> Result<(), Box> { match command { @@ -41,20 +43,20 @@ pub fn run_migrate_command( format!("{}/Cargo.toml", migration_dir) }; // Construct the arguments that will be supplied to `cargo` command - let mut args = vec![ - "run", - "--manifest-path", - manifest_path.as_str(), - "--", - subcommand, - ]; + let mut args = vec!["run", "--manifest-path", &manifest_path, "--", subcommand]; let mut num: String = "".to_string(); if let Some(steps) = steps { num = steps.to_string(); } if !num.is_empty() { - args.extend(["-n", num.as_str()]) + args.extend(["-n", &num]) + } + if let Some(database_url) = &database_url { + args.extend(["-u", database_url]); + } + if let Some(database_schema) = &database_schema { + args.extend(["-s", database_schema]); } if verbose { args.push("-v"); diff --git a/sea-orm-migration/src/cli.rs b/sea-orm-migration/src/cli.rs index d098cb81..93c7ef38 100644 --- a/sea-orm-migration/src/cli.rs +++ b/sea-orm-migration/src/cli.rs @@ -3,7 +3,7 @@ use dotenvy::dotenv; use std::{error::Error, fmt::Display, process::exit}; use tracing_subscriber::{prelude::*, EnvFilter}; -use sea_orm::{Database, DbConn}; +use sea_orm::{ConnectOptions, Database, DbConn}; use sea_orm_cli::{run_migrate_generate, run_migrate_init, MigrateSubcommands}; use super::MigratorTrait; @@ -15,10 +15,20 @@ where M: MigratorTrait, { dotenv().ok(); - let url = std::env::var("DATABASE_URL").expect("Environment variable 'DATABASE_URL' not set"); - let db = &Database::connect(&url).await.unwrap(); let cli = Cli::parse(); + let url = cli + .database_url + .expect("Environment variable 'DATABASE_URL' not set"); + let schema = cli.database_schema.unwrap_or_else(|| "public".to_owned()); + + let connect_options = ConnectOptions::new(url) + .set_schema_search_path(schema) + .to_owned(); + let db = &Database::connect(connect_options) + .await + .expect("Fail to acquire database connection"); + run_migrate(migrator, db, cli.command, cli.verbose) .await .unwrap_or_else(handle_error); @@ -81,6 +91,28 @@ pub struct Cli { #[clap(action, short = 'v', long, global = true, help = "Show debug messages")] verbose: bool, + #[clap( + value_parser, + global = true, + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL and SQLite, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'.\n" + )] + database_schema: Option, + + #[clap( + value_parser, + global = true, + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL" + )] + database_url: Option, + #[clap(subcommand)] command: Option, } diff --git a/sea-orm-migration/tests/main.rs b/sea-orm-migration/tests/main.rs index 2dfcf93d..fe006c47 100644 --- a/sea-orm-migration/tests/main.rs +++ b/sea-orm-migration/tests/main.rs @@ -1,7 +1,7 @@ mod migrator; use migrator::Migrator; -use sea_orm::{ConnectionTrait, Database, DbBackend, DbErr, Statement}; +use sea_orm::{ConnectOptions, ConnectionTrait, Database, DbBackend, DbErr, Statement}; use sea_orm_migration::prelude::*; #[async_std::test] @@ -11,9 +11,26 @@ async fn main() -> Result<(), DbErr> { .with_test_writer() .init(); - let url = std::env::var("DATABASE_URL").expect("Environment variable 'DATABASE_URL' not set"); - let db_name = "sea_orm_migration"; - let db = Database::connect(&url).await?; + let url = &std::env::var("DATABASE_URL").expect("Environment variable 'DATABASE_URL' not set"); + + run_migration(url, "sea_orm_migration", "public").await?; + + run_migration(url, "sea_orm_migration_schema", "my_schema").await?; + + Ok(()) +} + +async fn run_migration(url: &str, db_name: &str, schema: &str) -> Result<(), DbErr> { + let db_connect = |url: String| async { + let connect_options = ConnectOptions::new(url) + .set_schema_search_path(schema.to_owned()) + .to_owned(); + + Database::connect(connect_options).await + }; + + let db = db_connect(url.to_owned()).await?; + let db = &match db.get_database_backend() { DbBackend::MySql => { db.execute(Statement::from_string( @@ -23,7 +40,7 @@ async fn main() -> Result<(), DbErr> { .await?; let url = format!("{}/{}", url, db_name); - Database::connect(&url).await? + db_connect(url).await? } DbBackend::Postgres => { db.execute(Statement::from_string( @@ -38,7 +55,15 @@ async fn main() -> Result<(), DbErr> { .await?; let url = format!("{}/{}", url, db_name); - Database::connect(&url).await? + let db = db_connect(url).await?; + + db.execute(Statement::from_string( + db.get_database_backend(), + format!("CREATE SCHEMA IF NOT EXISTS \"{}\";", schema), + )) + .await?; + + db } DbBackend::Sqlite => db, }; diff --git a/src/database/mod.rs b/src/database/mod.rs index 1a1a399e..795789e6 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -48,6 +48,8 @@ pub struct ConnectOptions { pub(crate) sqlx_logging_level: log::LevelFilter, /// set sqlcipher key pub(crate) sqlcipher_key: Option>, + /// Schema search path (PostgreSQL only) + pub(crate) schema_search_path: Option, } impl Database { @@ -114,6 +116,7 @@ impl ConnectOptions { sqlx_logging: true, sqlx_logging_level: log::LevelFilter::Info, sqlcipher_key: None, + schema_search_path: None, } } @@ -251,4 +254,10 @@ impl ConnectOptions { self.sqlcipher_key = Some(value.into()); self } + + /// Set schema search path (PostgreSQL only) + pub fn set_schema_search_path(&mut self, schema_search_path: String) -> &mut Self { + self.schema_search_path = Some(schema_search_path); + self + } } diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 085abaa8..72432435 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -52,7 +52,22 @@ impl SqlxPostgresConnector { } else { opt.log_statements(options.sqlx_logging_level); } - match options.pool_options().connect_with(opt).await { + let set_search_path_sql = options + .schema_search_path + .as_ref() + .map(|schema| format!("SET search_path = '{}'", schema)); + let mut pool_options = options.pool_options(); + if let Some(sql) = set_search_path_sql { + pool_options = pool_options.after_connect(move |conn, _| { + let sql = sql.clone(); + Box::pin(async move { + sqlx::Executor::execute(conn, sql.as_str()) + .await + .map(|_| ()) + }) + }); + } + match pool_options.connect_with(opt).await { Ok(pool) => Ok(DatabaseConnection::SqlxPostgresPoolConnection( SqlxPostgresPoolConnection { pool,