use heck::CamelCase; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use syn::{parse, punctuated::Punctuated, token::Comma, Expr, Lit, LitInt, LitStr, Meta, UnOp}; enum Error { InputNotEnum, Syn(syn::Error), TT(TokenStream), } struct ActiveEnum { ident: syn::Ident, enum_name: String, rs_type: TokenStream, db_type: TokenStream, is_string: bool, variants: Vec, } struct ActiveEnumVariant { ident: syn::Ident, string_value: Option, num_value: Option, } impl ActiveEnum { fn new(input: syn::DeriveInput) -> Result { let ident_span = input.ident.span(); let ident = input.ident; let mut enum_name = ident.to_string().to_camel_case(); let mut rs_type = Err(Error::TT(quote_spanned! { ident_span => compile_error!("Missing macro attribute `rs_type`"); })); let mut db_type = Err(Error::TT(quote_spanned! { ident_span => compile_error!("Missing macro attribute `db_type`"); })); for attr in input.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { continue; } } else { continue; } if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) { for meta in list.iter() { if let Meta::NameValue(nv) = meta { if let Some(name) = nv.path.get_ident() { if name == "rs_type" { if let Lit::Str(litstr) = &nv.lit { rs_type = syn::parse_str::(&litstr.value()) .map_err(Error::Syn); } } else if name == "db_type" { if let Lit::Str(litstr) = &nv.lit { let s = litstr.value(); match s.as_ref() { "Enum" => { db_type = Ok(quote! { Enum { name: Self::name(), variants: Self::iden_values(), } }) } _ => { db_type = syn::parse_str::(&s) .map_err(Error::Syn); } } } } else if name == "enum_name" { if let Lit::Str(litstr) = &nv.lit { enum_name = litstr.value(); } } } } } } } let variant_vec = match input.data { syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, _ => return Err(Error::InputNotEnum), }; let mut is_string = false; let mut is_int = false; let mut variants = Vec::new(); for variant in variant_vec { let variant_span = variant.ident.span(); let mut string_value = None; let mut num_value = None; for attr in variant.attrs.iter() { if let Some(ident) = attr.path.get_ident() { if ident != "sea_orm" { continue; } } else { continue; } if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) { for meta in list { if let Meta::NameValue(nv) = meta { if let Some(name) = nv.path.get_ident() { if name == "string_value" { if let Lit::Str(lit) = nv.lit { is_string = true; string_value = Some(lit); } } else if name == "num_value" { if let Lit::Int(lit) = nv.lit { is_int = true; num_value = Some(lit); } } } } } } } if is_string && is_int { return Err(Error::TT(quote_spanned! { ident_span => compile_error!("All enum variants should specify the same `*_value` macro attribute, either `string_value` or `num_value` but not both"); })); } if string_value.is_none() && num_value.is_none() { match variant.discriminant { Some((_, Expr::Lit(exprlit))) => { if let Lit::Int(litint) = exprlit.lit { is_int = true; num_value = Some(litint); } else { return Err(Error::TT(quote_spanned! { variant_span => compile_error!("Enum variant discriminant is not an integer"); })); } } //rust doesn't provide negative variants in enums as a single LitInt, this workarounds that Some((_, Expr::Unary(exprnlit))) => { if let UnOp::Neg(_) = exprnlit.op { if let Expr::Lit(exprlit) = *exprnlit.expr { if let Lit::Int(litint) = exprlit.lit { let negative_token = quote! { -#litint }; let litint = parse(negative_token.into()).unwrap(); is_int = true; num_value = Some(litint); } } } else { return Err(Error::TT(quote_spanned! { variant_span => compile_error!("Only - token is supported in enum variants, not ! and *"); })); } } _ => { return Err(Error::TT(quote_spanned! { variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified or specify repr[X] and have a value for every entry"); })); } } } variants.push(ActiveEnumVariant { ident: variant.ident, string_value, num_value, }); } Ok(ActiveEnum { ident, enum_name, rs_type: rs_type?, db_type: db_type?, is_string, variants, }) } fn expand(&self) -> syn::Result { let expanded_impl_active_enum = self.impl_active_enum(); Ok(expanded_impl_active_enum) } fn impl_active_enum(&self) -> TokenStream { let Self { ident, enum_name, rs_type, db_type, is_string, variants, } = self; let variant_idents: Vec = variants .iter() .map(|variant| variant.ident.clone()) .collect(); let variant_values: Vec = variants .iter() .map(|variant| { let variant_span = variant.ident.span(); if let Some(string_value) = &variant.string_value { let string = string_value.value(); quote! { #string } } else if let Some(num_value) = &variant.num_value { quote! { #num_value } } else { quote_spanned! { variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified"); } } }) .collect(); let val = if *is_string { quote! { v.as_ref() } } else { quote! { v } }; let enum_name_iden = format_ident!("{}Enum", ident); let str_variants: Vec = variants .iter() .filter_map(|variant| { variant .string_value .as_ref() .map(|string_value| string_value.value()) }) .collect(); let impl_enum_variant_iden = if !str_variants.is_empty() { let enum_variant_iden = format_ident!("{}Variant", ident); let enum_variants: Vec = str_variants .iter() .map(|v| format_ident!("{}", v.to_camel_case())) .collect(); quote!( #[derive(Debug, Clone, PartialEq, Eq, sea_orm::EnumIter)] pub enum #enum_variant_iden { #( #enum_variants, )* } #[automatically_derived] impl sea_orm::sea_query::Iden for #enum_variant_iden { fn unquoted(&self, s: &mut dyn std::fmt::Write) { write!(s, "{}", match self { #( Self::#enum_variants => #str_variants, )* }).unwrap(); } } #[automatically_derived] impl #ident { pub fn iden_values() -> Vec { <#enum_variant_iden as sea_orm::strum::IntoEnumIterator>::iter() .map(|v| sea_orm::sea_query::SeaRc::new(v) as sea_orm::sea_query::DynIden) .collect() } } ) } else { quote!() }; quote!( #[derive(Debug, Clone, PartialEq, Eq)] pub struct #enum_name_iden; #[automatically_derived] impl sea_orm::sea_query::Iden for #enum_name_iden { fn unquoted(&self, s: &mut dyn std::fmt::Write) { write!(s, "{}", #enum_name).unwrap(); } } #impl_enum_variant_iden #[automatically_derived] impl sea_orm::ActiveEnum for #ident { type Value = #rs_type; fn name() -> sea_orm::sea_query::DynIden { sea_orm::sea_query::SeaRc::new(#enum_name_iden) as sea_orm::sea_query::DynIden } fn to_value(&self) -> Self::Value { match self { #( Self::#variant_idents => #variant_values, )* } .to_owned() } fn try_from_value(v: &Self::Value) -> std::result::Result { match #val { #( #variant_values => Ok(Self::#variant_idents), )* _ => Err(sea_orm::DbErr::Type(format!( "unexpected value for {} enum: {}", stringify!(#ident), v ))), } } fn db_type() -> sea_orm::ColumnDef { sea_orm::ColumnType::#db_type.def() } } #[automatically_derived] #[allow(clippy::from_over_into)] impl Into for #ident { fn into(self) -> sea_orm::sea_query::Value { ::to_value(&self).into() } } #[automatically_derived] impl sea_orm::TryGetable for #ident { fn try_get(res: &sea_orm::QueryResult, pre: &str, col: &str) -> std::result::Result { let value = <::Value as sea_orm::TryGetable>::try_get(res, pre, col)?; ::try_from_value(&value).map_err(sea_orm::TryGetError::DbErr) } } #[automatically_derived] impl sea_orm::sea_query::ValueType for #ident { fn try_from(v: sea_orm::sea_query::Value) -> std::result::Result { let value = <::Value as sea_orm::sea_query::ValueType>::try_from(v)?; ::try_from_value(&value).map_err(|_| sea_orm::sea_query::ValueTypeErr) } fn type_name() -> String { <::Value as sea_orm::sea_query::ValueType>::type_name() } fn array_type() -> sea_orm::sea_query::ArrayType { unimplemented!("Array of Enum is not supported") } fn column_type() -> sea_orm::sea_query::ColumnType { ::db_type() .get_column_type() .to_owned() .into() } } #[automatically_derived] impl sea_orm::sea_query::Nullable for #ident { fn null() -> sea_orm::sea_query::Value { <::Value as sea_orm::sea_query::Nullable>::null() } } #[automatically_derived] impl std::fmt::Display for #ident { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let v: sea_orm::sea_query::Value = ::to_value(&self).into(); write!(f, "{}", v) } } ) } } pub fn expand_derive_active_enum(input: syn::DeriveInput) -> syn::Result { let ident_span = input.ident.span(); match ActiveEnum::new(input) { Ok(model) => model.expand(), Err(Error::InputNotEnum) => Ok(quote_spanned! { ident_span => compile_error!("you can only derive ActiveEnum on enums"); }), Err(Error::TT(token_stream)) => Ok(token_stream), Err(Error::Syn(e)) => Err(e), } }