diff --git a/sea-orm-macros/src/derives/column.rs b/sea-orm-macros/src/derives/column.rs index 034e966d..0315fa1e 100644 --- a/sea-orm-macros/src/derives/column.rs +++ b/sea-orm-macros/src/derives/column.rs @@ -1,6 +1,6 @@ -use heck::SnakeCase; +use heck::{MixedCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; -use quote::{quote, quote_spanned}; +use quote::{format_ident, quote, quote_spanned}; use syn::{Data, DataEnum, Fields, Variant}; pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result { @@ -41,6 +41,44 @@ pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result syn::Result { + let parse_error_iden = format_ident!("Parse{}Err", ident); + + let data_enum = match data { + Data::Enum(data_enum) => data_enum, + _ => { + return Ok(quote_spanned! { + ident.span() => compile_error!("you can only derive DeriveColumn on enums"); + }) + } + }; + + let columns = data_enum.variants.iter().map(|column| { + let column_iden = column.ident.clone(); + let column_str_snake = column_iden.to_string().to_snake_case(); + let column_str_mixed = column_iden.to_string().to_mixed_case(); + quote!( + #column_str_snake | #column_str_mixed => Ok(#ident::#column_iden) + ) + }); + + Ok(quote!( + #[derive(Debug, Clone, Copy, PartialEq)] + pub struct #parse_error_iden; + + impl std::str::FromStr for #ident { + type Err = #parse_error_iden; + + fn from_str(s: &str) -> Result { + match s { + #(#columns),*, + _ => Err(#parse_error_iden), + } + } + } + )) +} + pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result { let impl_iden = expand_derive_custom_column(ident, data)?; @@ -57,10 +95,13 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result syn::Result { let impl_default_as_str = impl_default_as_str(ident, data)?; + let impl_col_from_str = impl_col_from_str(ident, data)?; Ok(quote!( #impl_default_as_str + #impl_col_from_str + impl sea_orm::Iden for #ident { fn unquoted(&self, s: &mut dyn std::fmt::Write) { write!(s, "{}", self.as_str()).unwrap(); diff --git a/src/entity/column.rs b/src/entity/column.rs index 045a85ba..2456aba7 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -324,6 +324,7 @@ mod tests { tests_cfg::*, ColumnTrait, Condition, DbBackend, EntityTrait, QueryFilter, QueryTrait, }; use sea_query::Query; + use std::str::FromStr; #[test] fn test_in_subquery() { @@ -348,4 +349,28 @@ mod tests { .join(" ") ); } + + #[test] + fn test_col_from_str() { + match fruit::Column::from_str("id") { + Ok(col) => assert_eq!(col, fruit::Column::Id), + Err(_) => panic!("fruit from_str fails"), + } + match fruit::Column::from_str("name") { + Ok(col) => assert_eq!(col, fruit::Column::Name), + Err(_) => panic!("fruit from_str fails"), + } + match fruit::Column::from_str("cake_id") { + Ok(col) => assert_eq!(col, fruit::Column::CakeId), + Err(_) => panic!("fruit from_str fails"), + } + match fruit::Column::from_str("cakeId") { + Ok(col) => assert_eq!(col, fruit::Column::CakeId), + Err(_) => panic!("fruit from_str fails"), + } + match fruit::Column::from_str("does_not_exist") { + Ok(_) => panic!("fruit from_str found match when it shouldn't have"), + Err(err) => assert_eq!(err, fruit::ParseColumnErr), + } + } } diff --git a/src/query/combine.rs b/src/query/combine.rs index 8cce0510..71783e8b 100644 --- a/src/query/combine.rs +++ b/src/query/combine.rs @@ -167,7 +167,7 @@ mod tests { .left_join(fruit::Entity) .select_also(fruit::Entity) .filter(cake::Column::Id.eq(1)) - .filter(fruit::Column::Id.eq(2)) + .filter(ColumnTrait::eq(&fruit::Column::Id, 2)) .build(DbBackend::MySql) .to_string(), [ @@ -186,7 +186,7 @@ mod tests { .left_join(fruit::Entity) .select_with(fruit::Entity) .filter(cake::Column::Id.eq(1)) - .filter(fruit::Column::Id.eq(2)) + .filter(ColumnTrait::eq(&fruit::Column::Id, 2)) .build(DbBackend::MySql) .to_string(), [ diff --git a/src/query/update.rs b/src/query/update.rs index 21fd39cb..014f6ae1 100644 --- a/src/query/update.rs +++ b/src/query/update.rs @@ -233,7 +233,7 @@ mod tests { assert_eq!( Update::many(fruit::Entity) .col_expr(fruit::Column::CakeId, Expr::value(Value::Int(None))) - .filter(fruit::Column::Id.eq(2)) + .filter(ColumnTrait::eq(&fruit::Column::Id, 2)) .build(DbBackend::Postgres) .to_string(), r#"UPDATE "fruit" SET "cake_id" = NULL WHERE "fruit"."id" = 2"#, diff --git a/src/tests_cfg/fruit.rs b/src/tests_cfg/fruit.rs index 0511ae58..48d7fcfb 100644 --- a/src/tests_cfg/fruit.rs +++ b/src/tests_cfg/fruit.rs @@ -17,7 +17,7 @@ pub struct Model { pub cake_id: Option, } -#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +#[derive(Copy, Clone, PartialEq, Debug, EnumIter, DeriveColumn)] pub enum Column { Id, Name,