diff --git a/sea-orm-macros/src/derives/active_enum.rs b/sea-orm-macros/src/derives/active_enum.rs index 6f6b7fbc..cd2e1c0a 100644 --- a/sea-orm-macros/src/derives/active_enum.rs +++ b/sea-orm-macros/src/derives/active_enum.rs @@ -1,25 +1,38 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; -use syn::{punctuated::Punctuated, token::Comma, Lit, Meta}; +use syn::{punctuated::Punctuated, token::Comma, Lit, LitInt, LitStr, Meta}; enum Error { InputNotEnum, Syn(syn::Error), + TT(TokenStream), } struct ActiveEnum { ident: syn::Ident, rs_type: TokenStream, db_type: TokenStream, - variants: syn::punctuated::Punctuated, + is_string: bool, + variants: Vec, +} + +struct ActiveEnumVariant { + ident: syn::Ident, + string_value: Option, + num_value: Option, } impl ActiveEnum { fn new(input: syn::DeriveInput) -> Result { + let ident_span = input.ident.span(); let ident = input.ident; - let mut rs_type = None; - let mut db_type = None; + let mut rs_type = Err(Error::TT(quote_spanned! { + ident_span => compile_error!("Missing macro attribute `rs_type`"); + })); + let mut db_type = Err(Error::TT(quote_spanned! { + ident_span => compile_error!("Missing macro attribute `db_type`"); + })); for attr in input.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { @@ -34,11 +47,13 @@ impl ActiveEnum { if let Some(name) = nv.path.get_ident() { if name == "rs_type" { if let Lit::Str(litstr) = &nv.lit { - rs_type = syn::parse_str::(&litstr.value()).ok(); + rs_type = syn::parse_str::(&litstr.value()) + .map_err(Error::Syn); } } else if name == "db_type" { if let Lit::Str(litstr) = &nv.lit { - db_type = syn::parse_str::(&litstr.value()).ok(); + db_type = syn::parse_str::(&litstr.value()) + .map_err(Error::Syn); } } } @@ -46,18 +61,73 @@ impl ActiveEnum { } } } - let rs_type = rs_type.expect("Missing rs_type"); - let db_type = db_type.expect("Missing db_type"); - let variants = match input.data { + let variant_vec = match input.data { syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, _ => return Err(Error::InputNotEnum), }; + let mut is_string = false; + let mut is_int = false; + let mut variants = Vec::new(); + for variant in variant_vec { + let variant_span = variant.ident.span(); + let mut string_value = None; + let mut num_value = None; + for attr in variant.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "string_value" { + if let Lit::Str(lit) = nv.lit { + is_string = true; + string_value = Some(lit); + } + } else if name == "num_value" { + if let Lit::Int(lit) = nv.lit { + is_int = true; + num_value = Some(lit); + } + } + } + } + } + } + } + + if is_string && is_int { + return Err(Error::TT(quote_spanned! { + ident_span => compile_error!("All enum variants should specify the same `*_value` macro attribute, either `string_value` or `num_value` but not both"); + })); + } + + if string_value.is_none() && num_value.is_none() { + return Err(Error::TT(quote_spanned! { + variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified"); + })); + } + + variants.push(ActiveEnumVariant { + ident: variant.ident, + string_value, + num_value, + }); + } + Ok(ActiveEnum { ident, - rs_type, - db_type, + rs_type: rs_type?, + db_type: db_type?, + is_string, variants, }) } @@ -73,6 +143,7 @@ impl ActiveEnum { ident, rs_type, db_type, + is_string, variants, } = self; @@ -81,54 +152,25 @@ impl ActiveEnum { .map(|variant| variant.ident.clone()) .collect(); - let mut is_string = false; - let variant_values: Vec = variants .iter() .map(|variant| { - let mut string_value = None; - let mut num_value = None; - for attr in variant.attrs.iter() { - if let Some(ident) = attr.path.get_ident() { - if ident != "sea_orm" { - continue; - } - } else { - continue; - } - if let Ok(list) = - attr.parse_args_with(Punctuated::::parse_terminated) - { - for meta in list.iter() { - if let Meta::NameValue(nv) = meta { - if let Some(name) = nv.path.get_ident() { - if name == "string_value" { - if let Lit::Str(litstr) = &nv.lit { - string_value = Some(litstr.value()); - } - } else if name == "num_value" { - if let Lit::Int(litstr) = &nv.lit { - num_value = litstr.base10_parse::().ok(); - } - } - } - } - } - } - } + let variant_span = variant.ident.span(); - if let Some(string_value) = string_value { - is_string = true; - quote! { #string_value } - } else if let Some(num_value) = num_value { + if let Some(string_value) = &variant.string_value { + let string = string_value.value(); + quote! { #string } + } else if let Some(num_value) = &variant.num_value { quote! { #num_value } } else { - panic!("Either string_value or num_value should be specified") + quote_spanned! { + variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified"); + } } }) .collect(); - let val = if is_string { + let val = if *is_string { quote! { v.as_ref() } } else { quote! { v } @@ -214,6 +256,7 @@ pub fn expand_derive_active_enum(input: syn::DeriveInput) -> syn::Result Ok(quote_spanned! { ident_span => compile_error!("you can only derive ActiveEnum on enums"); }), - Err(Error::Syn(err)) => Err(err), + Err(Error::TT(token_stream)) => Ok(token_stream), + Err(Error::Syn(e)) => Err(e), } }