diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 28e73ad6..c9073f5c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -145,7 +145,7 @@ jobs: with: command: test args: > - --all + --workspace - uses: actions-rs/cargo@v1 with: @@ -153,6 +153,12 @@ jobs: args: > --manifest-path sea-orm-rocket/Cargo.toml + - uses: actions-rs/cargo@v1 + with: + command: test + args: > + --manifest-path sea-orm-cli/Cargo.toml + cli: name: CLI runs-on: ${{ matrix.os }} diff --git a/Cargo.toml b/Cargo.toml index 7dfb7cb8..71690374 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } ouroboros = "0.11" +url = "^2.2" [dev-dependencies] smol = { version = "^1.2" } diff --git a/sea-orm-cli/Cargo.toml b/sea-orm-cli/Cargo.toml index 02fddcb1..aead20df 100644 --- a/sea-orm-cli/Cargo.toml +++ b/sea-orm-cli/Cargo.toml @@ -32,6 +32,8 @@ sea-schema = { version = "^0.2.9", default-features = false, features = [ sqlx = { version = "^0.5", default-features = false, features = [ "mysql", "postgres" ] } env_logger = { version = "^0.9" } log = { version = "^0.4" } +url = "^2.2" +smol = "1.2.5" [features] default = [ "runtime-async-std-native-tls" ] diff --git a/sea-orm-cli/src/main.rs b/sea-orm-cli/src/main.rs index 528ce5f5..2529b0f1 100644 --- a/sea-orm-cli/src/main.rs +++ b/sea-orm-cli/src/main.rs @@ -3,6 +3,7 @@ use dotenv::dotenv; use log::LevelFilter; use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde}; use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr}; +use url::Url; mod cli; @@ -23,7 +24,6 @@ async fn main() { async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box> { match matches.subcommand() { ("entity", Some(args)) => { - let url = args.value_of("DATABASE_URL").unwrap(); let output_dir = args.value_of("OUTPUT_DIR").unwrap(); let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES"); let tables = args @@ -32,8 +32,67 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box>(); let expanded_format = args.is_present("EXPANDED_FORMAT"); let with_serde = args.value_of("WITH_SERDE").unwrap(); + if args.is_present("VERBOSE") { + let _ = ::env_logger::builder() + .filter_level(LevelFilter::Debug) + .is_test(true) + .try_init(); + } + + // The database should be a valid URL that can be parsed + // protocol://username:password@host/database_name + let url = Url::parse( + args.value_of("DATABASE_URL") + .expect("No database url could be found"), + )?; + + // Make sure we have all the required url components + // + // Missing scheme will have been caught by the Url::parse() call + // above + + let url_username = url.username(); + let url_password = url.password(); + let url_host = url.host_str(); + + // Panic on any that are missing + if url_username.is_empty() { + panic!("No username was found in the database url"); + } + if url_password.is_none() { + panic!("No password was found in the database url"); + } + if url_host.is_none() { + panic!("No host was found in the database url"); + } + + // The database name should be the first element of the path string + // + // Throwing an error if there is no database name since it might be + // accepted by the database without it, while we're looking to dump + // information from a particular database + let database_name = url + .path_segments() + .unwrap_or_else(|| { + panic!( + "There is no database name as part of the url path: {}", + url.as_str() + ) + }) + .next() + .unwrap(); + + // An empty string as the database name is also an error + if database_name.is_empty() { + panic!( + "There is no database name as part of the url path: {}", + url.as_str() + ); + } + + // Closures for filtering tables let filter_tables = |table: &str| -> bool { - if tables.len() > 0 { + if !tables.is_empty() { return tables.contains(&table); } @@ -43,49 +102,43 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box { + use sea_schema::mysql::discovery::SchemaDiscovery; + use sqlx::MySqlPool; - let url_parts: Vec<&str> = url.split("/").collect(); - let schema = url_parts.last().unwrap(); - let connection = MySqlPool::connect(url).await?; - let schema_discovery = SchemaDiscovery::new(connection, schema); - let schema = schema_discovery.discover().await; - schema - .tables - .into_iter() - .filter(|schema| filter_tables(&schema.info.name)) - .filter(|schema| filter_hidden_tables(&schema.info.name)) - .map(|schema| schema.write()) - .collect() - } else if url.starts_with("postgres://") || url.starts_with("postgresql://") { - use sea_schema::postgres::discovery::SchemaDiscovery; - use sqlx::PgPool; + let connection = MySqlPool::connect(url.as_str()).await?; + let schema_discovery = SchemaDiscovery::new(connection, database_name); + let schema = schema_discovery.discover().await; + schema + .tables + .into_iter() + .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter_hidden_tables(&schema.info.name)) + .map(|schema| schema.write()) + .collect() + } + "postgres" | "postgresql" => { + use sea_schema::postgres::discovery::SchemaDiscovery; + use sqlx::PgPool; - let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public"); - let connection = PgPool::connect(url).await?; - let schema_discovery = SchemaDiscovery::new(connection, schema); - let schema = schema_discovery.discover().await; - schema - .tables - .into_iter() - .filter(|schema| filter_tables(&schema.info.name)) - .filter(|schema| filter_hidden_tables(&schema.info.name)) - .map(|schema| schema.write()) - .collect() - } else { - panic!("This database is not supported ({})", url) + let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public"); + let connection = PgPool::connect(url.as_str()).await?; + let schema_discovery = SchemaDiscovery::new(connection, schema); + let schema = schema_discovery.discover().await; + schema + .tables + .into_iter() + .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter_hidden_tables(&schema.info.name)) + .map(|schema| schema.write()) + .collect() + } + _ => unimplemented!("{} is not supported", url.scheme()), }; let output = EntityTransformer::transform(table_stmts)? @@ -99,6 +152,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box match e.downcast::() { + Ok(_) => (), + Err(e) => panic!("Expected ParseError but got: {:?}", e), + }, + _ => panic!("Should have panicked"), + } + } + + #[test] + #[should_panic] + fn test_generate_entity_no_database_section() { + let matches = cli::build_cli() + .setting(AppSettings::NoBinaryName) + .get_matches_from(vec![ + "generate", + "entity", + "--database-url", + "postgresql://root:root@localhost:3306", + ]); + + smol::block_on(run_generate_command(matches.subcommand().1.unwrap())) + .unwrap_or_else(handle_error); + } + + #[test] + #[should_panic] + fn test_generate_entity_no_database_path() { + let matches = cli::build_cli() + .setting(AppSettings::NoBinaryName) + .get_matches_from(vec![ + "generate", + "entity", + "--database-url", + "mysql://root:root@localhost:3306/", + ]); + + smol::block_on(run_generate_command(matches.subcommand().1.unwrap())) + .unwrap_or_else(handle_error); + } + + #[test] + #[should_panic] + fn test_generate_entity_no_username() { + let matches = cli::build_cli() + .setting(AppSettings::NoBinaryName) + .get_matches_from(vec![ + "generate", + "entity", + "--database-url", + "mysql://:root@localhost:3306/database", + ]); + + smol::block_on(run_generate_command(matches.subcommand().1.unwrap())) + .unwrap_or_else(handle_error); + } + + #[test] + #[should_panic] + fn test_generate_entity_no_password() { + let matches = cli::build_cli() + .setting(AppSettings::NoBinaryName) + .get_matches_from(vec![ + "generate", + "entity", + "--database-url", + "mysql://root:@localhost:3306/database", + ]); + + smol::block_on(run_generate_command(matches.subcommand().1.unwrap())) + .unwrap_or_else(handle_error); + } + + #[async_std::test] + async fn test_generate_entity_no_host() { + let matches = cli::build_cli() + .setting(AppSettings::NoBinaryName) + .get_matches_from(vec![ + "generate", + "entity", + "--database-url", + "postgres://root:root@/database", + ]); + + let result = std::panic::catch_unwind(|| { + smol::block_on(run_generate_command(matches.subcommand().1.unwrap())) + }); + + // Make sure result is a ParseError + match result { + Ok(Err(e)) => match e.downcast::() { + Ok(_) => (), + Err(e) => panic!("Expected ParseError but got: {:?}", e), + }, + _ => panic!("Should have panicked"), + } + } +} diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index e8f37546..6b8e3356 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -4,6 +4,7 @@ use crate::{ }; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; use std::{future::Future, pin::Pin}; +use url::Url; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; @@ -223,12 +224,13 @@ impl DatabaseConnection { impl DbBackend { pub fn is_prefix_of(self, base_url: &str) -> bool { + let base_url_parsed = Url::parse(base_url).unwrap(); match self { Self::Postgres => { - base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") + base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql" } - Self::MySql => base_url.starts_with("mysql://"), - Self::Sqlite => base_url.starts_with("sqlite:"), + Self::MySql => base_url_parsed.scheme() == "mysql", + Self::Sqlite => base_url_parsed.scheme() == "sqlite", } }