Changed manual url parsing to use Url crate

This commit is contained in:
Forest Anderson 2021-10-16 23:19:48 -04:00
parent 10b101b142
commit bcc5b5066a
4 changed files with 49 additions and 37 deletions

View File

@ -37,6 +37,7 @@ serde_json = { version = "^1", optional = true }
sqlx = { version = "^0.5", optional = true } sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11" ouroboros = "0.11"
url = "^2.2"
[dev-dependencies] [dev-dependencies]
smol = { version = "^1.2" } smol = { version = "^1.2" }

View File

@ -32,6 +32,7 @@ sea-schema = { version = "^0.2.9", default-features = false, features = [
sqlx = { version = "^0.5", default-features = false, features = [ "mysql", "postgres" ] } sqlx = { version = "^0.5", default-features = false, features = [ "mysql", "postgres" ] }
env_logger = { version = "^0.9" } env_logger = { version = "^0.9" }
log = { version = "^0.4" } log = { version = "^0.4" }
url = "^2.2"
[features] [features]
default = [ "runtime-async-std-native-tls" ] default = [ "runtime-async-std-native-tls" ]

View File

@ -3,6 +3,7 @@ use dotenv::dotenv;
use log::LevelFilter; use log::LevelFilter;
use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde}; use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde};
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr}; use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr};
use url::Url;
mod cli; mod cli;
@ -23,7 +24,8 @@ async fn main() {
async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Error>> { async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Error>> {
match matches.subcommand() { match matches.subcommand() {
("entity", Some(args)) => { ("entity", Some(args)) => {
let url = args.value_of("DATABASE_URL").unwrap(); // The database should be a valid URL that can be parsed
let url = Url::parse(args.value_of("DATABASE_URL").unwrap())?;
let output_dir = args.value_of("OUTPUT_DIR").unwrap(); let output_dir = args.value_of("OUTPUT_DIR").unwrap();
let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES"); let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES");
let tables = args let tables = args
@ -33,7 +35,7 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
let expanded_format = args.is_present("EXPANDED_FORMAT"); let expanded_format = args.is_present("EXPANDED_FORMAT");
let with_serde = args.value_of("WITH_SERDE").unwrap(); let with_serde = args.value_of("WITH_SERDE").unwrap();
let filter_tables = |table: &str| -> bool { let filter_tables = |table: &str| -> bool {
if tables.len() > 0 { if !tables.is_empty() {
return tables.contains(&table); return tables.contains(&table);
} }
@ -43,7 +45,7 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
if include_hidden_tables { if include_hidden_tables {
true true
} else { } else {
!table.starts_with("_") !table.starts_with('_')
} }
}; };
if args.is_present("VERBOSE") { if args.is_present("VERBOSE") {
@ -53,13 +55,16 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
.try_init(); .try_init();
} }
let table_stmts = if url.starts_with("mysql://") { let table_stmts = match url.scheme() {
"mysql" => {
use sea_schema::mysql::discovery::SchemaDiscovery; use sea_schema::mysql::discovery::SchemaDiscovery;
use sqlx::MySqlPool; use sqlx::MySqlPool;
let url_parts: Vec<&str> = url.split("/").collect(); // TODO: as far as I can tell, this used to be the last
let schema = url_parts.last().unwrap(); // value of the url, which should have been the database
let connection = MySqlPool::connect(url).await?; // name?
let schema = url.path_segments().unwrap().last().unwrap();
let connection = MySqlPool::connect(url.as_str()).await?;
let schema_discovery = SchemaDiscovery::new(connection, schema); let schema_discovery = SchemaDiscovery::new(connection, schema);
let schema = schema_discovery.discover().await; let schema = schema_discovery.discover().await;
schema schema
@ -69,12 +74,13 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
.filter(|schema| filter_hidden_tables(&schema.info.name)) .filter(|schema| filter_hidden_tables(&schema.info.name))
.map(|schema| schema.write()) .map(|schema| schema.write())
.collect() .collect()
} else if url.starts_with("postgres://") || url.starts_with("postgresql://") { }
"postgres" | "postgresql" => {
use sea_schema::postgres::discovery::SchemaDiscovery; use sea_schema::postgres::discovery::SchemaDiscovery;
use sqlx::PgPool; use sqlx::PgPool;
let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public"); let schema = args.value_of("DATABASE_SCHEMA").unwrap_or("public");
let connection = PgPool::connect(url).await?; let connection = PgPool::connect(url.as_str()).await?;
let schema_discovery = SchemaDiscovery::new(connection, schema); let schema_discovery = SchemaDiscovery::new(connection, schema);
let schema = schema_discovery.discover().await; let schema = schema_discovery.discover().await;
schema schema
@ -84,8 +90,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
.filter(|schema| filter_hidden_tables(&schema.info.name)) .filter(|schema| filter_hidden_tables(&schema.info.name))
.map(|schema| schema.write()) .map(|schema| schema.write())
.collect() .collect()
} else { }
panic!("This database is not supported ({})", url) _ => unimplemented!("{} is not supported", url.scheme()),
}; };
let output = EntityTransformer::transform(table_stmts)? let output = EntityTransformer::transform(table_stmts)?
@ -99,6 +105,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
let mut file = fs::File::create(file_path)?; let mut file = fs::File::create(file_path)?;
file.write_all(content.as_bytes())?; file.write_all(content.as_bytes())?;
} }
// Format each of the files
for OutputFile { name, .. } in output.files.iter() { for OutputFile { name, .. } in output.files.iter() {
Command::new("rustfmt") Command::new("rustfmt")
.arg(dir.join(name)) .arg(dir.join(name))

View File

@ -4,6 +4,7 @@ use crate::{
}; };
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use url::Url;
#[cfg(feature = "sqlx-dep")] #[cfg(feature = "sqlx-dep")]
use sqlx::pool::PoolConnection; use sqlx::pool::PoolConnection;
@ -223,12 +224,13 @@ impl DatabaseConnection {
impl DbBackend { impl DbBackend {
pub fn is_prefix_of(self, base_url: &str) -> bool { pub fn is_prefix_of(self, base_url: &str) -> bool {
let base_url_parsed = Url::parse(base_url).unwrap();
match self { match self {
Self::Postgres => { Self::Postgres => {
base_url.starts_with("postgres://") || base_url.starts_with("postgresql://") base_url_parsed.scheme() == "postgres" || base_url_parsed.scheme() == "postgresql"
} }
Self::MySql => base_url.starts_with("mysql://"), Self::MySql => base_url_parsed.scheme() == "mysql",
Self::Sqlite => base_url.starts_with("sqlite:"), Self::Sqlite => base_url_parsed.scheme() == "sqlite",
} }
} }