Implement FromStr for DeriveColumn

This commit is contained in:
Ari Seyhun 2021-08-18 15:20:33 +09:30 committed by Chris Tsang
parent c8989dfb23
commit 46a4cafaa9
5 changed files with 72 additions and 6 deletions

View File

@ -1,6 +1,6 @@
use heck::SnakeCase; use heck::{MixedCase, SnakeCase};
use proc_macro2::{Ident, TokenStream}; use proc_macro2::{Ident, TokenStream};
use quote::{quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
use syn::{Data, DataEnum, Fields, Variant}; use syn::{Data, DataEnum, Fields, Variant};
pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result<TokenStream> { pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
@ -41,6 +41,44 @@ pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result<TokenStrea
)) ))
} }
pub fn impl_col_from_str(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let parse_error_iden = format_ident!("Parse{}Err", ident);
let data_enum = match data {
Data::Enum(data_enum) => data_enum,
_ => {
return Ok(quote_spanned! {
ident.span() => compile_error!("you can only derive DeriveColumn on enums");
})
}
};
let columns = data_enum.variants.iter().map(|column| {
let column_iden = column.ident.clone();
let column_str_snake = column_iden.to_string().to_snake_case();
let column_str_mixed = column_iden.to_string().to_mixed_case();
quote!(
#column_str_snake | #column_str_mixed => Ok(#ident::#column_iden)
)
});
Ok(quote!(
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct #parse_error_iden;
impl std::str::FromStr for #ident {
type Err = #parse_error_iden;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
#(#columns),*,
_ => Err(#parse_error_iden),
}
}
}
))
}
pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> { pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let impl_iden = expand_derive_custom_column(ident, data)?; let impl_iden = expand_derive_custom_column(ident, data)?;
@ -57,10 +95,13 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result<TokenStre
pub fn expand_derive_custom_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> { pub fn expand_derive_custom_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let impl_default_as_str = impl_default_as_str(ident, data)?; let impl_default_as_str = impl_default_as_str(ident, data)?;
let impl_col_from_str = impl_col_from_str(ident, data)?;
Ok(quote!( Ok(quote!(
#impl_default_as_str #impl_default_as_str
#impl_col_from_str
impl sea_orm::Iden for #ident { impl sea_orm::Iden for #ident {
fn unquoted(&self, s: &mut dyn std::fmt::Write) { fn unquoted(&self, s: &mut dyn std::fmt::Write) {
write!(s, "{}", self.as_str()).unwrap(); write!(s, "{}", self.as_str()).unwrap();

View File

@ -324,6 +324,7 @@ mod tests {
tests_cfg::*, ColumnTrait, Condition, DbBackend, EntityTrait, QueryFilter, QueryTrait, tests_cfg::*, ColumnTrait, Condition, DbBackend, EntityTrait, QueryFilter, QueryTrait,
}; };
use sea_query::Query; use sea_query::Query;
use std::str::FromStr;
#[test] #[test]
fn test_in_subquery() { fn test_in_subquery() {
@ -348,4 +349,28 @@ mod tests {
.join(" ") .join(" ")
); );
} }
#[test]
fn test_col_from_str() {
match fruit::Column::from_str("id") {
Ok(col) => assert_eq!(col, fruit::Column::Id),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("name") {
Ok(col) => assert_eq!(col, fruit::Column::Name),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("cake_id") {
Ok(col) => assert_eq!(col, fruit::Column::CakeId),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("cakeId") {
Ok(col) => assert_eq!(col, fruit::Column::CakeId),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("does_not_exist") {
Ok(_) => panic!("fruit from_str found match when it shouldn't have"),
Err(err) => assert_eq!(err, fruit::ParseColumnErr),
}
}
} }

View File

@ -167,7 +167,7 @@ mod tests {
.left_join(fruit::Entity) .left_join(fruit::Entity)
.select_also(fruit::Entity) .select_also(fruit::Entity)
.filter(cake::Column::Id.eq(1)) .filter(cake::Column::Id.eq(1))
.filter(fruit::Column::Id.eq(2)) .filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::MySql) .build(DbBackend::MySql)
.to_string(), .to_string(),
[ [
@ -186,7 +186,7 @@ mod tests {
.left_join(fruit::Entity) .left_join(fruit::Entity)
.select_with(fruit::Entity) .select_with(fruit::Entity)
.filter(cake::Column::Id.eq(1)) .filter(cake::Column::Id.eq(1))
.filter(fruit::Column::Id.eq(2)) .filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::MySql) .build(DbBackend::MySql)
.to_string(), .to_string(),
[ [

View File

@ -233,7 +233,7 @@ mod tests {
assert_eq!( assert_eq!(
Update::many(fruit::Entity) Update::many(fruit::Entity)
.col_expr(fruit::Column::CakeId, Expr::value(Value::Int(None))) .col_expr(fruit::Column::CakeId, Expr::value(Value::Int(None)))
.filter(fruit::Column::Id.eq(2)) .filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::Postgres) .build(DbBackend::Postgres)
.to_string(), .to_string(),
r#"UPDATE "fruit" SET "cake_id" = NULL WHERE "fruit"."id" = 2"#, r#"UPDATE "fruit" SET "cake_id" = NULL WHERE "fruit"."id" = 2"#,

View File

@ -17,7 +17,7 @@ pub struct Model {
pub cake_id: Option<i32>, pub cake_id: Option<i32>,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] #[derive(Copy, Clone, PartialEq, Debug, EnumIter, DeriveColumn)]
pub enum Column { pub enum Column {
Id, Id,
Name, Name,