Merge pull request #253 from AngelOnFira/improve-database-url-parsing
Changed manual url parsing to use Url crate
This commit is contained in:
commit
d6b83d3346
8
.github/workflows/rust.yml
vendored
8
.github/workflows/rust.yml
vendored
@ -145,7 +145,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
command: test
|
command: test
|
||||||
args: >
|
args: >
|
||||||
--all
|
--workspace
|
||||||
|
|
||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
@ -153,6 +153,12 @@ jobs:
|
|||||||
args: >
|
args: >
|
||||||
--manifest-path sea-orm-rocket/Cargo.toml
|
--manifest-path sea-orm-rocket/Cargo.toml
|
||||||
|
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: >
|
||||||
|
--manifest-path sea-orm-cli/Cargo.toml
|
||||||
|
|
||||||
cli:
|
cli:
|
||||||
name: CLI
|
name: CLI
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
@ -37,6 +37,7 @@ serde_json = { version = "^1", optional = true }
|
|||||||
sqlx = { version = "^0.5", optional = true }
|
sqlx = { version = "^0.5", optional = true }
|
||||||
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
|
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
|
||||||
ouroboros = "0.11"
|
ouroboros = "0.11"
|
||||||
|
url = "^2.2"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
smol = { version = "^1.2" }
|
smol = { version = "^1.2" }
|
||||||
|
@ -32,6 +32,8 @@ sea-schema = { version = "^0.2.9", default-features = false, features = [
|
|||||||
sqlx = { version = "^0.5", default-features = false, features = [ "mysql", "postgres" ] }
|
sqlx = { version = "^0.5", default-features = false, features = [ "mysql", "postgres" ] }
|
||||||
env_logger = { version = "^0.9" }
|
env_logger = { version = "^0.9" }
|
||||||
log = { version = "^0.4" }
|
log = { version = "^0.4" }
|
||||||
|
url = "^2.2"
|
||||||
|
smol = "1.2.5"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "runtime-async-std-native-tls" ]
|
default = [ "runtime-async-std-native-tls" ]
|
||||||
|
@ -3,6 +3,7 @@ use dotenv::dotenv;
|
|||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde};
|
use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde};
|
||||||
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr};
|
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
mod cli;
|
mod cli;
|
||||||
|
|
||||||
@ -23,7 +24,6 @@ async fn main() {
|
|||||||
async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Error>> {
|
async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Error>> {
|
||||||
match matches.subcommand() {
|
match matches.subcommand() {
|
||||||
("entity", Some(args)) => {
|
("entity", Some(args)) => {
|
||||||
let url = args.value_of("DATABASE_URL").unwrap();
|
|
||||||
let output_dir = args.value_of("OUTPUT_DIR").unwrap();
|
let output_dir = args.value_of("OUTPUT_DIR").unwrap();
|
||||||
let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES");
|
let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES");
|
||||||
let tables = args
|
let tables = args
|
||||||
@ -32,8 +32,67 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let expanded_format = args.is_present("EXPANDED_FORMAT");
|
let expanded_format = args.is_present("EXPANDED_FORMAT");
|
||||||
let with_serde = args.value_of("WITH_SERDE").unwrap();
|
let with_serde = args.value_of("WITH_SERDE").unwrap();
|
||||||
|
if args.is_present("VERBOSE") {
|
||||||
|
let _ = ::env_logger::builder()
|
||||||
|
.filter_level(LevelFilter::Debug)
|
||||||
|
.is_test(true)
|
||||||
|
.try_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The database should be a valid URL that can be parsed
|
||||||
|
// protocol://username:password@host/database_name
|
||||||
|
let url = Url::parse(
|
||||||
|
args.value_of("DATABASE_URL")
|
||||||
|
.expect("No database url could be found"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Make sure we have all the required url components
|
||||||
|
//
|
||||||
|
// Missing scheme will have been caught by the Url::parse() call
|
||||||
|
// above
|
||||||
|
|
||||||
|
let url_username = url.username();
|
||||||
|
let url_password = url.password();
|
||||||
|
let url_host = url.host_str();
|
||||||
|
|
||||||
|
// Panic on any that are missing
|
||||||
|
if url_username.is_empty() {
|
||||||
|
panic!("No username was found in the database url");
|
||||||
|
}
|
||||||
|
if url_password.is_none() {
|
||||||
|
panic!("No password was found in the database url");
|
||||||
|
}
|
||||||
|
if url_host.is_none() {
|
||||||
|
panic!("No host was found in the database url");
|
||||||
|
}
|
||||||
|
|
||||||
|
// The database name should be the first element of the path string
|
||||||
|
//
|
||||||
|
// Throwing an error if there is no database name since it might be
|
||||||
|
// accepted by the database without it, while we're looking to dump
|
||||||
|
// information from a particular database
|
||||||
|
let database_name = url
|
||||||
|
.path_segments()
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
panic!(
|
||||||
|
"There is no database name as part of the url path: {}",
|
||||||
|
url.as_str()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.next()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// An empty string as the database name is also an error
|
||||||
|
if database_name.is_empty() {
|
||||||
|
panic!(
|
||||||
|
"There is no database name as part of the url path: {}",
|
||||||
|
url.as_str()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closures for filtering tables
|
||||||
let filter_tables = |table: &str| -> bool {
|
let filter_tables = |table: &str| -> bool {
|
||||||
if tables.len() > 0 {
|
if !tables.is_empty() {
|
||||||
return tables.contains(&table);
|
return tables.contains(&table);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,49 +102,43 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
|
|||||||
if include_hidden_tables {
|
if include_hidden_tables {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
!table.starts_with("_")
|
!table.starts_with('_')
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if args.is_present("VERBOSE") {
|
|
||||||
let _ = ::env_logger::builder()
|
|
||||||
.filter_level(LevelFilter::Debug)
|
|
||||||
.is_test(true)
|
|
||||||
.try_init();
|
|
||||||
}
|
|
||||||
|
|
||||||
let table_stmts = if url.starts_with("mysql://") {
|
let table_stmts = match url.scheme() {
|
||||||
use sea_schema::mysql::discovery::SchemaDiscovery;
|
"mysql" => {
|
||||||
use sqlx::MySqlPool;
|
use sea_schema::mysql::discovery::SchemaDiscovery;
|
||||||
|
use sqlx::MySqlPool;
|
||||||
|
|
||||||
let url_parts: Vec<&str> = url.split("/").collect();
|
let connection = MySqlPool::connect(url.as_str()).await?;
|
||||||
let schema = url_parts.last().unwrap();
|
let schema_discovery = SchemaDiscovery::new(connection, database_name);
|
||||||
let connection = MySqlPool::connect(url).await?;
|
let schema = schema_discovery.discover().await;
|
||||||
let schema_discovery = SchemaDiscovery::new(connection, schema);
|
schema
|
||||||
let schema = schema_discovery.discover().await;
|
.tables
|
||||||
schema
|
.into_iter()
|
||||||
.tables
|
.filter(|schema| filter_tables(&schema.info.name))
|
||||||
.into_iter()
|
.filter(|schema| filter_hidden_tables(&schema.info.name))
|
||||||
.filter(|schema| filter_tables(&schema.info.name))
|
.map(|schema| schema.write())
|
||||||
.filter(|schema| filter_hidden_tables(&schema.info.name))
|
.collect()
|
||||||
.map(|schema| schema.write())
|
}
|
||||||
.collect()
|
"postgres" | "postgresql" => {
|
||||||
} else if url.starts_with("postgres://") || url.starts_with("postgresql://") {
|
use sea_schema::postgres::discovery::SchemaDiscovery;
|
||||||
use sea_schema::postgres::discovery::SchemaDiscovery;
|
use sqlx::PgPool;
|
||||||
use sqlx::PgPool;
|
|
||||||
|
|
||||||
let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public");
|
let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public");
|
||||||
let connection = PgPool::connect(url).await?;
|
let connection = PgPool::connect(url.as_str()).await?;
|
||||||
let schema_discovery = SchemaDiscovery::new(connection, schema);
|
let schema_discovery = SchemaDiscovery::new(connection, schema);
|
||||||
let schema = schema_discovery.discover().await;
|
let schema = schema_discovery.discover().await;
|
||||||
schema
|
schema
|
||||||
.tables
|
.tables
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|schema| filter_tables(&schema.info.name))
|
.filter(|schema| filter_tables(&schema.info.name))
|
||||||
.filter(|schema| filter_hidden_tables(&schema.info.name))
|
.filter(|schema| filter_hidden_tables(&schema.info.name))
|
||||||
.map(|schema| schema.write())
|
.map(|schema| schema.write())
|
||||||
.collect()
|
.collect()
|
||||||
} else {
|
}
|
||||||
panic!("This database is not supported ({})", url)
|
_ => unimplemented!("{} is not supported", url.scheme()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let output = EntityTransformer::transform(table_stmts)?
|
let output = EntityTransformer::transform(table_stmts)?
|
||||||
@ -99,6 +152,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
|
|||||||
let mut file = fs::File::create(file_path)?;
|
let mut file = fs::File::create(file_path)?;
|
||||||
file.write_all(content.as_bytes())?;
|
file.write_all(content.as_bytes())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Format each of the files
|
||||||
for OutputFile { name, .. } in output.files.iter() {
|
for OutputFile { name, .. } in output.files.iter() {
|
||||||
Command::new("rustfmt")
|
Command::new("rustfmt")
|
||||||
.arg(dir.join(name))
|
.arg(dir.join(name))
|
||||||
@ -119,3 +174,125 @@ where
|
|||||||
eprintln!("{}", error);
|
eprintln!("{}", error);
|
||||||
::std::process::exit(1);
|
::std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use clap::AppSettings;
|
||||||
|
use url::ParseError;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[async_std::test]
|
||||||
|
async fn test_generate_entity_no_protocol() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"://root:root@localhost:3306/database",
|
||||||
|
]);
|
||||||
|
|
||||||
|
let result = std::panic::catch_unwind(|| {
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
});
|
||||||
|
|
||||||
|
// Make sure result is a ParseError
|
||||||
|
match result {
|
||||||
|
Ok(Err(e)) => match e.downcast::<ParseError>() {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(e) => panic!("Expected ParseError but got: {:?}", e),
|
||||||
|
},
|
||||||
|
_ => panic!("Should have panicked"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_generate_entity_no_database_section() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"postgresql://root:root@localhost:3306",
|
||||||
|
]);
|
||||||
|
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
.unwrap_or_else(handle_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_generate_entity_no_database_path() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"mysql://root:root@localhost:3306/",
|
||||||
|
]);
|
||||||
|
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
.unwrap_or_else(handle_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_generate_entity_no_username() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"mysql://:root@localhost:3306/database",
|
||||||
|
]);
|
||||||
|
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
.unwrap_or_else(handle_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn test_generate_entity_no_password() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"mysql://root:@localhost:3306/database",
|
||||||
|
]);
|
||||||
|
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
.unwrap_or_else(handle_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_std::test]
|
||||||
|
async fn test_generate_entity_no_host() {
|
||||||
|
let matches = cli::build_cli()
|
||||||
|
.setting(AppSettings::NoBinaryName)
|
||||||
|
.get_matches_from(vec![
|
||||||
|
"generate",
|
||||||
|
"entity",
|
||||||
|
"--database-url",
|
||||||
|
"postgres://root:root@/database",
|
||||||
|
]);
|
||||||
|
|
||||||
|
let result = std::panic::catch_unwind(|| {
|
||||||
|
smol::block_on(run_generate_command(matches.subcommand().1.unwrap()))
|
||||||
|
});
|
||||||
|
|
||||||
|
// Make sure result is a ParseError
|
||||||
|
match result {
|
||||||
|
Ok(Err(e)) => match e.downcast::<ParseError>() {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(e) => panic!("Expected ParseError but got: {:?}", e),
|
||||||
|
},
|
||||||
|
_ => panic!("Should have panicked"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
|
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
|
||||||
use std::{future::Future, pin::Pin};
|
use std::{future::Future, pin::Pin};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
#[cfg(feature = "sqlx-dep")]
|
#[cfg(feature = "sqlx-dep")]
|
||||||
use sqlx::pool::PoolConnection;
|
use sqlx::pool::PoolConnection;
|
||||||
@ -223,12 +224,13 @@ impl DatabaseConnection {
|
|||||||
|
|
||||||
impl DbBackend {
|
impl DbBackend {
|
||||||
pub fn is_prefix_of(self, base_url: &str) -> bool {
|
pub fn is_prefix_of(self, base_url: &str) -> bool {
|
||||||
|
let base_url_parsed = Url::parse(base_url).unwrap();
|
||||||
match self {
|
match self {
|
||||||
Self::Postgres => {
|
Self::Postgres => {
|
||||||
base_url.starts_with("postgres://") || base_url.starts_with("postgresql://")
|
base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
|
||||||
}
|
}
|
||||||
Self::MySql => base_url.starts_with("mysql://"),
|
Self::MySql => base_url_parsed.scheme() == "mysql",
|
||||||
Self::Sqlite => base_url.starts_with("sqlite:"),
|
Self::Sqlite => base_url_parsed.scheme() == "sqlite",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user