From 2eda8aa3f2b5dbb9683e60f8d45227942c9ee276 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Wed, 22 Mar 2023 11:47:15 +0800 Subject: [PATCH] Drop the use of `sea-strum` and depends on the original `strum` with a tailored `EnumIter` provided (#1535) * Update heck dependency * Fix formatter error * Drop the use of `sea-strum` and depends on the original `strum` with a tailored `EnumIter` provided * fmt * Depends on `strum` 0.23 * Depends on `strum` 0.24 * Source code adapted from https://github.com/Peternator7/strum * Update LICENSE --------- Co-authored-by: Sergei Ivankov Co-authored-by: Sergei Ivankov <96142843+sergeiivankov@users.noreply.github.com> Co-authored-by: Chris Tsang --- Cargo.toml | 8 +- sea-orm-macros/Cargo.toml | 5 +- sea-orm-macros/src/derives/active_model.rs | 2 +- .../src/{ => derives}/attributes.rs | 0 sea-orm-macros/src/derives/entity.rs | 2 +- sea-orm-macros/src/derives/entity_model.rs | 2 +- sea-orm-macros/src/derives/mod.rs | 2 + sea-orm-macros/src/derives/model.rs | 2 +- sea-orm-macros/src/derives/relation.rs | 2 +- sea-orm-macros/src/{ => derives}/util.rs | 0 sea-orm-macros/src/lib.rs | 37 ++- sea-orm-macros/src/strum/LICENSE | 23 ++ sea-orm-macros/src/strum/enum_iter.rs | 172 ++++++++++ .../src/strum/helpers/case_style.rs | 117 +++++++ sea-orm-macros/src/strum/helpers/metadata.rs | 309 ++++++++++++++++++ sea-orm-macros/src/strum/helpers/mod.rs | 24 ++ .../src/strum/helpers/type_props.rs | 116 +++++++ .../src/strum/helpers/variant_props.rs | 102 ++++++ sea-orm-macros/src/strum/mod.rs | 4 + src/lib.rs | 2 +- 20 files changed, 918 insertions(+), 13 deletions(-) rename sea-orm-macros/src/{ => derives}/attributes.rs (100%) rename sea-orm-macros/src/{ => derives}/util.rs (100%) create mode 100644 sea-orm-macros/src/strum/LICENSE create mode 100644 sea-orm-macros/src/strum/enum_iter.rs create mode 100644 sea-orm-macros/src/strum/helpers/case_style.rs create mode 100644 sea-orm-macros/src/strum/helpers/metadata.rs create mode 100644 sea-orm-macros/src/strum/helpers/mod.rs create mode 100644 sea-orm-macros/src/strum/helpers/type_props.rs create mode 100644 sea-orm-macros/src/strum/helpers/variant_props.rs create mode 100644 sea-orm-macros/src/strum/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 9e9a8f3d..f89b5c7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,10 +33,10 @@ log = { version = "0.4", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true } -sea-orm-macros = { version = "0.12.0", path = "sea-orm-macros", default-features = false, optional = true } +sea-orm-macros = { version = "0.12.0", path = "sea-orm-macros", default-features = false, features = ["strum"] } sea-query = { version = "0.28.3", features = ["thread-safe"] } sea-query-binder = { version = "0.3", default-features = false, optional = true } -sea-strum = { version = "0.23", default-features = false, features = ["derive", "sea-orm"] } +strum = { version = "0.24", default-features = false } serde = { version = "1.0", default-features = false } serde_json = { version = "1.0", default-features = false, optional = true } sqlx = { version = "0.6", default-features = false, optional = true } @@ -72,7 +72,7 @@ default = [ "with-uuid", "with-time", ] -macros = ["sea-orm-macros", "sea-query/derive"] +macros = ["sea-orm-macros/derive", "sea-query/derive"] mock = [] with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "time?/serde", "uuid?/serde", "sea-query-binder?/with-json", "sqlx?/json"] with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"] @@ -80,7 +80,7 @@ with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-b with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"] with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"] with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"] -postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros?/postgres-array"] +postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros/postgres-array"] sea-orm-internal = [] sqlx-dep = [] sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"] diff --git a/sea-orm-macros/Cargo.toml b/sea-orm-macros/Cargo.toml index a470a788..09b0c2d8 100644 --- a/sea-orm-macros/Cargo.toml +++ b/sea-orm-macros/Cargo.toml @@ -18,7 +18,7 @@ path = "src/lib.rs" proc-macro = true [dependencies] -bae = { version = "0.1", default-features = false } +bae = { version = "0.1", default-features = false, optional = true } syn = { version = "1", default-features = false, features = ["parsing", "proc-macro", "derive", "printing"] } quote = { version = "1", default-features = false } heck = { version = "0.4", default-features = false } @@ -29,4 +29,7 @@ sea-orm = { path = "../", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } [features] +default = ["derive"] postgres-array = [] +derive = ["bae"] +strum = [] diff --git a/sea-orm-macros/src/derives/active_model.rs b/sea-orm-macros/src/derives/active_model.rs index 5ed5d9e4..dbea86f1 100644 --- a/sea-orm-macros/src/derives/active_model.rs +++ b/sea-orm-macros/src/derives/active_model.rs @@ -1,4 +1,4 @@ -use crate::util::{ +use super::util::{ escape_rust_keyword, field_not_ignored, format_field_ident, trim_starting_raw_identifier, }; use heck::ToUpperCamelCase; diff --git a/sea-orm-macros/src/attributes.rs b/sea-orm-macros/src/derives/attributes.rs similarity index 100% rename from sea-orm-macros/src/attributes.rs rename to sea-orm-macros/src/derives/attributes.rs diff --git a/sea-orm-macros/src/derives/entity.rs b/sea-orm-macros/src/derives/entity.rs index 8af90a60..5f9a0733 100644 --- a/sea-orm-macros/src/derives/entity.rs +++ b/sea-orm-macros/src/derives/entity.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use crate::attributes::derive_attr; +use super::attributes::derive_attr; struct DeriveEntity { column_ident: syn::Ident, diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index 2bdb06c7..c642d9a8 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -1,4 +1,4 @@ -use crate::util::{escape_rust_keyword, trim_starting_raw_identifier}; +use super::util::{escape_rust_keyword, trim_starting_raw_identifier}; use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, quote_spanned}; diff --git a/sea-orm-macros/src/derives/mod.rs b/sea-orm-macros/src/derives/mod.rs index 295d3e7d..a3c71fe1 100644 --- a/sea-orm-macros/src/derives/mod.rs +++ b/sea-orm-macros/src/derives/mod.rs @@ -1,6 +1,7 @@ mod active_enum; mod active_model; mod active_model_behavior; +mod attributes; mod column; mod entity; mod entity_model; @@ -11,6 +12,7 @@ mod model; mod primary_key; mod relation; mod try_getable_from_json; +mod util; pub use active_enum::*; pub use active_model::*; diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index f99e4644..d3813f28 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -1,4 +1,4 @@ -use crate::{ +use super::{ attributes::derive_attr, util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}, }; diff --git a/sea-orm-macros/src/derives/relation.rs b/sea-orm-macros/src/derives/relation.rs index 972d9e17..f7f15c32 100644 --- a/sea-orm-macros/src/derives/relation.rs +++ b/sea-orm-macros/src/derives/relation.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; -use crate::attributes::{derive_attr, field_attr}; +use super::attributes::{derive_attr, field_attr}; enum Error { InputNotEnum, diff --git a/sea-orm-macros/src/util.rs b/sea-orm-macros/src/derives/util.rs similarity index 100% rename from sea-orm-macros/src/util.rs rename to sea-orm-macros/src/derives/util.rs diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index b5d88d41..89c32013 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -3,9 +3,11 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput, Error}; -mod attributes; +#[cfg(feature = "derive")] mod derives; -mod util; + +#[cfg(feature = "strum")] +mod strum; /// Create an Entity /// @@ -70,6 +72,7 @@ mod util; /// # /// # impl ActiveModelBehavior for ActiveModel {} /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveEntity, attributes(sea_orm))] pub fn derive_entity(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -129,6 +132,7 @@ pub fn derive_entity(input: TokenStream) -> TokenStream { /// # /// # impl ActiveModelBehavior for ActiveModel {} /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveEntityModel, attributes(sea_orm))] pub fn derive_entity_model(input: TokenStream) -> TokenStream { let input_ts = input.clone(); @@ -216,6 +220,7 @@ pub fn derive_entity_model(input: TokenStream) -> TokenStream { /// # /// # impl ActiveModelBehavior for ActiveModel {} /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DerivePrimaryKey, attributes(sea_orm))] pub fn derive_primary_key(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -241,6 +246,7 @@ pub fn derive_primary_key(input: TokenStream) -> TokenStream { /// FillingId, /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveColumn, attributes(sea_orm))] pub fn derive_column(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -274,6 +280,7 @@ pub fn derive_column(input: TokenStream) -> TokenStream { /// } /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveCustomColumn)] pub fn derive_custom_column(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -349,6 +356,7 @@ pub fn derive_custom_column(input: TokenStream) -> TokenStream { /// # /// # impl ActiveModelBehavior for ActiveModel {} /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveModel, attributes(sea_orm))] pub fn derive_model(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -421,6 +429,7 @@ pub fn derive_model(input: TokenStream) -> TokenStream { /// # /// # impl ActiveModelBehavior for ActiveModel {} /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveActiveModel, attributes(sea_orm))] pub fn derive_active_model(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -432,6 +441,7 @@ pub fn derive_active_model(input: TokenStream) -> TokenStream { } /// Derive into an active model +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveIntoActiveModel, attributes(sea_orm))] pub fn derive_into_active_model(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -503,6 +513,7 @@ pub fn derive_into_active_model(input: TokenStream) -> TokenStream { /// # } /// # } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveActiveModelBehavior)] pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -552,6 +563,7 @@ pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream { /// White = 1, /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveActiveEnum, attributes(sea_orm))] pub fn derive_active_enum(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -574,6 +586,7 @@ pub fn derive_active_enum(input: TokenStream) -> TokenStream { /// num_of_fruits: i32, /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(FromQueryResult)] pub fn derive_from_query_result(input: TokenStream) -> TokenStream { let DeriveInput { ident, data, .. } = parse_macro_input!(input); @@ -608,6 +621,7 @@ pub fn derive_from_query_result(input: TokenStream) -> TokenStream { /// CakeExpanded, /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveRelation, attributes(sea_orm))] pub fn derive_relation(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -635,6 +649,7 @@ pub fn derive_relation(input: TokenStream) -> TokenStream { /// } /// } /// ``` +#[cfg(feature = "derive")] #[proc_macro_derive(DeriveMigrationName)] pub fn derive_migration_name(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -643,6 +658,7 @@ pub fn derive_migration_name(input: TokenStream) -> TokenStream { .into() } +#[cfg(feature = "derive")] #[proc_macro_derive(FromJsonQueryResult)] pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream { let DeriveInput { ident, .. } = parse_macro_input!(input); @@ -654,6 +670,7 @@ pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream { } #[doc(hidden)] +#[cfg(feature = "derive")] #[proc_macro_attribute] pub fn test(_: TokenStream, input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::ItemFn); @@ -676,3 +693,19 @@ pub fn test(_: TokenStream, input: TokenStream) -> TokenStream { ) .into() } + +/// Creates a new type that iterates of the variants of an enum. +/// +/// Iterate over the variants of an Enum. Any additional data on your variants will be set to `Default::default()`. +/// The macro implements `strum::IntoEnumIterator` on your enum and creates a new type called `YourEnumIter` that is the iterator object. +/// You cannot derive `EnumIter` on any type with a lifetime bound (`<'a>`) because the iterator would surely +/// create [unbounded lifetimes](https://doc.rust-lang.org/nightly/nomicon/unbounded-lifetimes.html). +#[cfg(feature = "strum")] +#[proc_macro_derive(EnumIter, attributes(strum))] +pub fn enum_iter(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + + strum::enum_iter::enum_iter_inner(&ast) + .unwrap_or_else(Error::into_compile_error) + .into() +} diff --git a/sea-orm-macros/src/strum/LICENSE b/sea-orm-macros/src/strum/LICENSE new file mode 100644 index 00000000..2bab9c5a --- /dev/null +++ b/sea-orm-macros/src/strum/LICENSE @@ -0,0 +1,23 @@ +> The `strum` module is adapted from https://github.com/Peternator7/strum + +MIT License + +Copyright (c) 2019 Peter Glotfelty + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sea-orm-macros/src/strum/enum_iter.rs b/sea-orm-macros/src/strum/enum_iter.rs new file mode 100644 index 00000000..e4f153ad --- /dev/null +++ b/sea-orm-macros/src/strum/enum_iter.rs @@ -0,0 +1,172 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{Data, DeriveInput, Fields, Ident}; + +use super::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties}; + +pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result { + let name = &ast.ident; + let gen = &ast.generics; + let (impl_generics, ty_generics, where_clause) = gen.split_for_impl(); + let vis = &ast.vis; + let type_properties = ast.get_type_properties()?; + let strum_module_path = type_properties.crate_module_path(); + let doc_comment = format!("An iterator over the variants of [{}]", name); + + if gen.lifetimes().count() > 0 { + return Err(syn::Error::new( + Span::call_site(), + "This macro doesn't support enums with lifetimes. \ + The resulting enums would be unbounded.", + )); + } + + let phantom_data = if gen.type_params().count() > 0 { + let g = gen.type_params().map(|param| ¶m.ident); + quote! { < ( #(#g),* ) > } + } else { + quote! { < () > } + }; + + let variants = match &ast.data { + Data::Enum(v) => &v.variants, + _ => return Err(non_enum_error()), + }; + + let mut arms = Vec::new(); + let mut idx = 0usize; + for variant in variants { + if variant.get_variant_properties()?.disabled.is_some() { + continue; + } + + let ident = &variant.ident; + let params = match &variant.fields { + Fields::Unit => quote! {}, + Fields::Unnamed(fields) => { + let defaults = ::core::iter::repeat(quote!(::core::default::Default::default())) + .take(fields.unnamed.len()); + quote! { (#(#defaults),*) } + } + Fields::Named(fields) => { + let fields = fields + .named + .iter() + .map(|field| field.ident.as_ref().unwrap()); + quote! { {#(#fields: ::core::default::Default::default()),*} } + } + }; + + arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)}); + idx += 1; + } + + let variant_count = arms.len(); + arms.push(quote! { _ => ::core::option::Option::None }); + let iter_name = syn::parse_str::(&format!("{}Iter", name)).unwrap(); + + // Create a string literal "MyEnumIter" to use in the debug impl. + let iter_name_debug_struct = + syn::parse_str::(&format!("\"{}\"", iter_name)).unwrap(); + + Ok(quote! { + #[doc = #doc_comment] + #[allow( + missing_copy_implementations, + )] + #vis struct #iter_name #ty_generics { + idx: usize, + back_idx: usize, + marker: ::core::marker::PhantomData #phantom_data, + } + + impl #impl_generics core::fmt::Debug for #iter_name #ty_generics #where_clause { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // We don't know if the variants implement debug themselves so the only thing we + // can really show is how many elements are left. + f.debug_struct(#iter_name_debug_struct) + .field("len", &self.len()) + .finish() + } + } + + impl #impl_generics #iter_name #ty_generics #where_clause { + fn get(&self, idx: usize) -> Option<#name #ty_generics> { + match idx { + #(#arms),* + } + } + } + + impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause { + type Iterator = #iter_name #ty_generics; + fn iter() -> #iter_name #ty_generics { + #iter_name { + idx: 0, + back_idx: 0, + marker: ::core::marker::PhantomData, + } + } + } + + impl #impl_generics Iterator for #iter_name #ty_generics #where_clause { + type Item = #name #ty_generics; + + fn next(&mut self) -> Option<::Item> { + self.nth(0) + } + + fn size_hint(&self) -> (usize, Option) { + let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx }; + (t, Some(t)) + } + + fn nth(&mut self, n: usize) -> Option<::Item> { + let idx = self.idx + n + 1; + if idx + self.back_idx > #variant_count { + // We went past the end of the iterator. Freeze idx at #variant_count + // so that it doesn't overflow if the user calls this repeatedly. + // See PR #76 for context. + self.idx = #variant_count; + ::core::option::Option::None + } else { + self.idx = idx; + self.get(idx - 1) + } + } + } + + impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause { + fn len(&self) -> usize { + self.size_hint().0 + } + } + + impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause { + fn next_back(&mut self) -> Option<::Item> { + let back_idx = self.back_idx + 1; + + if self.idx + back_idx > #variant_count { + // We went past the end of the iterator. Freeze back_idx at #variant_count + // so that it doesn't overflow if the user calls this repeatedly. + // See PR #76 for context. + self.back_idx = #variant_count; + ::core::option::Option::None + } else { + self.back_idx = back_idx; + self.get(#variant_count - self.back_idx) + } + } + } + + impl #impl_generics Clone for #iter_name #ty_generics #where_clause { + fn clone(&self) -> #iter_name #ty_generics { + #iter_name { + idx: self.idx, + back_idx: self.back_idx, + marker: self.marker.clone(), + } + } + } + }) +} diff --git a/sea-orm-macros/src/strum/helpers/case_style.rs b/sea-orm-macros/src/strum/helpers/case_style.rs new file mode 100644 index 00000000..42538260 --- /dev/null +++ b/sea-orm-macros/src/strum/helpers/case_style.rs @@ -0,0 +1,117 @@ +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToTitleCase, ToUpperCamelCase, +}; +use std::str::FromStr; +use syn::{ + parse::{Parse, ParseStream}, + Ident, LitStr, +}; + +#[allow(clippy::enum_variant_names)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum CaseStyle { + CamelCase, + KebabCase, + MixedCase, + ShoutySnakeCase, + SnakeCase, + TitleCase, + UpperCase, + LowerCase, + ScreamingKebabCase, + PascalCase, +} + +const VALID_CASE_STYLES: &[&str] = &[ + "camelCase", + "PascalCase", + "kebab-case", + "snake_case", + "SCREAMING_SNAKE_CASE", + "SCREAMING-KEBAB-CASE", + "lowercase", + "UPPERCASE", + "title_case", + "mixed_case", +]; + +impl Parse for CaseStyle { + fn parse(input: ParseStream) -> syn::Result { + let text = input.parse::()?; + let val = text.value(); + + val.as_str().parse().map_err(|_| { + syn::Error::new_spanned( + &text, + format!( + "Unexpected case style for serialize_all: `{}`. Valid values are: `{:?}`", + val, VALID_CASE_STYLES + ), + ) + }) + } +} + +impl FromStr for CaseStyle { + type Err = (); + + fn from_str(text: &str) -> Result { + Ok(match text { + "camel_case" | "PascalCase" => CaseStyle::PascalCase, + "camelCase" => CaseStyle::CamelCase, + "snake_case" | "snek_case" => CaseStyle::SnakeCase, + "kebab_case" | "kebab-case" => CaseStyle::KebabCase, + "SCREAMING-KEBAB-CASE" => CaseStyle::ScreamingKebabCase, + "shouty_snake_case" | "shouty_snek_case" | "SCREAMING_SNAKE_CASE" => { + CaseStyle::ShoutySnakeCase + } + "title_case" => CaseStyle::TitleCase, + "mixed_case" => CaseStyle::MixedCase, + "lowercase" => CaseStyle::LowerCase, + "UPPERCASE" => CaseStyle::UpperCase, + _ => return Err(()), + }) + } +} + +pub trait CaseStyleHelpers { + fn convert_case(&self, case_style: Option) -> String; +} + +impl CaseStyleHelpers for Ident { + fn convert_case(&self, case_style: Option) -> String { + let ident_string = self.to_string(); + if let Some(case_style) = case_style { + match case_style { + CaseStyle::PascalCase => ident_string.to_upper_camel_case(), + CaseStyle::KebabCase => ident_string.to_kebab_case(), + CaseStyle::MixedCase => ident_string.to_lower_camel_case(), + CaseStyle::ShoutySnakeCase => ident_string.to_shouty_snake_case(), + CaseStyle::SnakeCase => ident_string.to_snake_case(), + CaseStyle::TitleCase => ident_string.to_title_case(), + CaseStyle::UpperCase => ident_string.to_uppercase(), + CaseStyle::LowerCase => ident_string.to_lowercase(), + CaseStyle::ScreamingKebabCase => ident_string.to_kebab_case().to_uppercase(), + CaseStyle::CamelCase => { + let camel_case = ident_string.to_upper_camel_case(); + let mut pascal = String::with_capacity(camel_case.len()); + let mut it = camel_case.chars(); + if let Some(ch) = it.next() { + pascal.extend(ch.to_lowercase()); + } + pascal.extend(it); + pascal + } + } + } else { + ident_string + } + } +} + +#[test] +fn test_convert_case() { + let id = Ident::new("test_me", proc_macro2::Span::call_site()); + assert_eq!("testMe", id.convert_case(Some(CaseStyle::CamelCase))); + assert_eq!("TestMe", id.convert_case(Some(CaseStyle::PascalCase))); +} diff --git a/sea-orm-macros/src/strum/helpers/metadata.rs b/sea-orm-macros/src/strum/helpers/metadata.rs new file mode 100644 index 00000000..56e4c78b --- /dev/null +++ b/sea-orm-macros/src/strum/helpers/metadata.rs @@ -0,0 +1,309 @@ +use proc_macro2::{Span, TokenStream}; +use syn::{ + parenthesized, + parse::{Parse, ParseStream}, + parse2, parse_str, + punctuated::Punctuated, + spanned::Spanned, + Attribute, DeriveInput, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token, Variant, + Visibility, +}; + +use super::case_style::CaseStyle; + +pub mod kw { + use syn::custom_keyword; + pub use syn::token::Crate; + + // enum metadata + custom_keyword!(serialize_all); + custom_keyword!(use_phf); + + // enum discriminant metadata + custom_keyword!(derive); + custom_keyword!(name); + custom_keyword!(vis); + + // variant metadata + custom_keyword!(message); + custom_keyword!(detailed_message); + custom_keyword!(serialize); + custom_keyword!(to_string); + custom_keyword!(disabled); + custom_keyword!(default); + custom_keyword!(props); + custom_keyword!(ascii_case_insensitive); +} + +pub enum EnumMeta { + SerializeAll { + kw: kw::serialize_all, + case_style: CaseStyle, + }, + AsciiCaseInsensitive(kw::ascii_case_insensitive), + Crate { + kw: kw::Crate, + crate_module_path: Path, + }, + UsePhf(kw::use_phf), +} + +impl Parse for EnumMeta { + fn parse(input: ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::serialize_all) { + let kw = input.parse::()?; + input.parse::()?; + let case_style = input.parse()?; + Ok(EnumMeta::SerializeAll { kw, case_style }) + } else if lookahead.peek(kw::Crate) { + let kw = input.parse::()?; + input.parse::()?; + let path_str: LitStr = input.parse()?; + let path_tokens = parse_str(&path_str.value())?; + let crate_module_path = parse2(path_tokens)?; + Ok(EnumMeta::Crate { + kw, + crate_module_path, + }) + } else if lookahead.peek(kw::ascii_case_insensitive) { + Ok(EnumMeta::AsciiCaseInsensitive(input.parse()?)) + } else if lookahead.peek(kw::use_phf) { + Ok(EnumMeta::UsePhf(input.parse()?)) + } else { + Err(lookahead.error()) + } + } +} + +impl Spanned for EnumMeta { + fn span(&self) -> Span { + match self { + EnumMeta::SerializeAll { kw, .. } => kw.span(), + EnumMeta::AsciiCaseInsensitive(kw) => kw.span(), + EnumMeta::Crate { kw, .. } => kw.span(), + EnumMeta::UsePhf(use_phf) => use_phf.span(), + } + } +} + +pub enum EnumDiscriminantsMeta { + Derive { kw: kw::derive, paths: Vec }, + Name { kw: kw::name, name: Ident }, + Vis { kw: kw::vis, vis: Visibility }, + Other { path: Path, nested: TokenStream }, +} + +impl Parse for EnumDiscriminantsMeta { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(kw::derive) { + let kw = input.parse()?; + let content; + parenthesized!(content in input); + let paths = content.parse_terminated::<_, Token![,]>(Path::parse)?; + Ok(EnumDiscriminantsMeta::Derive { + kw, + paths: paths.into_iter().collect(), + }) + } else if input.peek(kw::name) { + let kw = input.parse()?; + let content; + parenthesized!(content in input); + let name = content.parse()?; + Ok(EnumDiscriminantsMeta::Name { kw, name }) + } else if input.peek(kw::vis) { + let kw = input.parse()?; + let content; + parenthesized!(content in input); + let vis = content.parse()?; + Ok(EnumDiscriminantsMeta::Vis { kw, vis }) + } else { + let path = input.parse()?; + let content; + parenthesized!(content in input); + let nested = content.parse()?; + Ok(EnumDiscriminantsMeta::Other { path, nested }) + } + } +} + +impl Spanned for EnumDiscriminantsMeta { + fn span(&self) -> Span { + match self { + EnumDiscriminantsMeta::Derive { kw, .. } => kw.span, + EnumDiscriminantsMeta::Name { kw, .. } => kw.span, + EnumDiscriminantsMeta::Vis { kw, .. } => kw.span, + EnumDiscriminantsMeta::Other { path, .. } => path.span(), + } + } +} + +pub trait DeriveInputExt { + /// Get all the strum metadata associated with an enum. + fn get_metadata(&self) -> syn::Result>; + + /// Get all the `strum_discriminants` metadata associated with an enum. + fn get_discriminants_metadata(&self) -> syn::Result>; +} + +impl DeriveInputExt for DeriveInput { + fn get_metadata(&self) -> syn::Result> { + get_metadata_inner("strum", &self.attrs) + } + + fn get_discriminants_metadata(&self) -> syn::Result> { + get_metadata_inner("strum_discriminants", &self.attrs) + } +} + +pub enum VariantMeta { + Message { + kw: kw::message, + value: LitStr, + }, + DetailedMessage { + kw: kw::detailed_message, + value: LitStr, + }, + Serialize { + kw: kw::serialize, + value: LitStr, + }, + Documentation { + value: LitStr, + }, + ToString { + kw: kw::to_string, + value: LitStr, + }, + Disabled(kw::disabled), + Default(kw::default), + AsciiCaseInsensitive { + kw: kw::ascii_case_insensitive, + value: bool, + }, + Props { + kw: kw::props, + props: Vec<(LitStr, LitStr)>, + }, +} + +impl Parse for VariantMeta { + fn parse(input: ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::message) { + let kw = input.parse()?; + let _: Token![=] = input.parse()?; + let value = input.parse()?; + Ok(VariantMeta::Message { kw, value }) + } else if lookahead.peek(kw::detailed_message) { + let kw = input.parse()?; + let _: Token![=] = input.parse()?; + let value = input.parse()?; + Ok(VariantMeta::DetailedMessage { kw, value }) + } else if lookahead.peek(kw::serialize) { + let kw = input.parse()?; + let _: Token![=] = input.parse()?; + let value = input.parse()?; + Ok(VariantMeta::Serialize { kw, value }) + } else if lookahead.peek(kw::to_string) { + let kw = input.parse()?; + let _: Token![=] = input.parse()?; + let value = input.parse()?; + Ok(VariantMeta::ToString { kw, value }) + } else if lookahead.peek(kw::disabled) { + Ok(VariantMeta::Disabled(input.parse()?)) + } else if lookahead.peek(kw::default) { + Ok(VariantMeta::Default(input.parse()?)) + } else if lookahead.peek(kw::ascii_case_insensitive) { + let kw = input.parse()?; + let value = if input.peek(Token![=]) { + let _: Token![=] = input.parse()?; + input.parse::()?.value + } else { + true + }; + Ok(VariantMeta::AsciiCaseInsensitive { kw, value }) + } else if lookahead.peek(kw::props) { + let kw = input.parse()?; + let content; + parenthesized!(content in input); + let props = content.parse_terminated::<_, Token![,]>(Prop::parse)?; + Ok(VariantMeta::Props { + kw, + props: props + .into_iter() + .map(|Prop(k, v)| (LitStr::new(&k.to_string(), k.span()), v)) + .collect(), + }) + } else { + Err(lookahead.error()) + } + } +} + +struct Prop(Ident, LitStr); + +impl Parse for Prop { + fn parse(input: ParseStream) -> syn::Result { + use syn::ext::IdentExt; + + let k = Ident::parse_any(input)?; + let _: Token![=] = input.parse()?; + let v = input.parse()?; + + Ok(Prop(k, v)) + } +} + +impl Spanned for VariantMeta { + fn span(&self) -> Span { + match self { + VariantMeta::Message { kw, .. } => kw.span, + VariantMeta::DetailedMessage { kw, .. } => kw.span, + VariantMeta::Documentation { value } => value.span(), + VariantMeta::Serialize { kw, .. } => kw.span, + VariantMeta::ToString { kw, .. } => kw.span, + VariantMeta::Disabled(kw) => kw.span, + VariantMeta::Default(kw) => kw.span, + VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span, + VariantMeta::Props { kw, .. } => kw.span, + } + } +} + +pub trait VariantExt { + /// Get all the metadata associated with an enum variant. + fn get_metadata(&self) -> syn::Result>; +} + +impl VariantExt for Variant { + fn get_metadata(&self) -> syn::Result> { + let result = get_metadata_inner("strum", &self.attrs)?; + self.attrs + .iter() + .filter(|attr| attr.path.is_ident("doc")) + .try_fold(result, |mut vec, attr| { + if let Meta::NameValue(MetaNameValue { + lit: Lit::Str(value), + .. + }) = attr.parse_meta()? + { + vec.push(VariantMeta::Documentation { value }) + } + Ok(vec) + }) + } +} + +fn get_metadata_inner<'a, T: Parse + Spanned>( + ident: &str, + it: impl IntoIterator, +) -> syn::Result> { + it.into_iter() + .filter(|attr| attr.path.is_ident(ident)) + .try_fold(Vec::new(), |mut vec, attr| { + vec.extend(attr.parse_args_with(Punctuated::::parse_terminated)?); + Ok(vec) + }) +} diff --git a/sea-orm-macros/src/strum/helpers/mod.rs b/sea-orm-macros/src/strum/helpers/mod.rs new file mode 100644 index 00000000..77049eb4 --- /dev/null +++ b/sea-orm-macros/src/strum/helpers/mod.rs @@ -0,0 +1,24 @@ +pub use self::case_style::CaseStyleHelpers; +pub use self::type_props::HasTypeProperties; +pub use self::variant_props::HasStrumVariantProperties; + +pub mod case_style; +mod metadata; +pub mod type_props; +pub mod variant_props; + +use proc_macro2::Span; +use quote::ToTokens; + +pub fn non_enum_error() -> syn::Error { + syn::Error::new(Span::call_site(), "This macro only supports enums.") +} + +pub fn occurrence_error(fst: T, snd: T, attr: &str) -> syn::Error { + let mut e = syn::Error::new_spanned( + snd, + format!("Found multiple occurrences of strum({})", attr), + ); + e.combine(syn::Error::new_spanned(fst, "first one here")); + e +} diff --git a/sea-orm-macros/src/strum/helpers/type_props.rs b/sea-orm-macros/src/strum/helpers/type_props.rs new file mode 100644 index 00000000..3b127eb3 --- /dev/null +++ b/sea-orm-macros/src/strum/helpers/type_props.rs @@ -0,0 +1,116 @@ +use proc_macro2::TokenStream; +use quote::quote; +use std::default::Default; +use syn::{parse_quote, DeriveInput, Ident, Path, Visibility}; + +use super::case_style::CaseStyle; +use super::metadata::{DeriveInputExt, EnumDiscriminantsMeta, EnumMeta}; +use super::occurrence_error; + +pub trait HasTypeProperties { + fn get_type_properties(&self) -> syn::Result; +} + +#[derive(Debug, Clone, Default)] +pub struct StrumTypeProperties { + pub case_style: Option, + pub ascii_case_insensitive: bool, + pub crate_module_path: Option, + pub discriminant_derives: Vec, + pub discriminant_name: Option, + pub discriminant_others: Vec, + pub discriminant_vis: Option, + pub use_phf: bool, +} + +impl HasTypeProperties for DeriveInput { + fn get_type_properties(&self) -> syn::Result { + let mut output = StrumTypeProperties::default(); + + let strum_meta = self.get_metadata()?; + let discriminants_meta = self.get_discriminants_metadata()?; + + let mut serialize_all_kw = None; + let mut ascii_case_insensitive_kw = None; + let mut use_phf_kw = None; + let mut crate_module_path_kw = None; + for meta in strum_meta { + match meta { + EnumMeta::SerializeAll { case_style, kw } => { + if let Some(fst_kw) = serialize_all_kw { + return Err(occurrence_error(fst_kw, kw, "serialize_all")); + } + + serialize_all_kw = Some(kw); + output.case_style = Some(case_style); + } + EnumMeta::AsciiCaseInsensitive(kw) => { + if let Some(fst_kw) = ascii_case_insensitive_kw { + return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive")); + } + + ascii_case_insensitive_kw = Some(kw); + output.ascii_case_insensitive = true; + } + EnumMeta::UsePhf(kw) => { + if let Some(fst_kw) = use_phf_kw { + return Err(occurrence_error(fst_kw, kw, "use_phf")); + } + + use_phf_kw = Some(kw); + output.use_phf = true; + } + EnumMeta::Crate { + crate_module_path, + kw, + } => { + if let Some(fst_kw) = crate_module_path_kw { + return Err(occurrence_error(fst_kw, kw, "Crate")); + } + + crate_module_path_kw = Some(kw); + output.crate_module_path = Some(crate_module_path); + } + } + } + + let mut name_kw = None; + let mut vis_kw = None; + for meta in discriminants_meta { + match meta { + EnumDiscriminantsMeta::Derive { paths, .. } => { + output.discriminant_derives.extend(paths); + } + EnumDiscriminantsMeta::Name { name, kw } => { + if let Some(fst_kw) = name_kw { + return Err(occurrence_error(fst_kw, kw, "name")); + } + + name_kw = Some(kw); + output.discriminant_name = Some(name); + } + EnumDiscriminantsMeta::Vis { vis, kw } => { + if let Some(fst_kw) = vis_kw { + return Err(occurrence_error(fst_kw, kw, "vis")); + } + + vis_kw = Some(kw); + output.discriminant_vis = Some(vis); + } + EnumDiscriminantsMeta::Other { path, nested } => { + output.discriminant_others.push(quote! { #path(#nested) }); + } + } + } + + Ok(output) + } +} + +impl StrumTypeProperties { + pub fn crate_module_path(&self) -> Path { + self.crate_module_path + .as_ref() + .map_or_else(|| parse_quote!(sea_orm::strum), |path| parse_quote!(#path)) + } +} diff --git a/sea-orm-macros/src/strum/helpers/variant_props.rs b/sea-orm-macros/src/strum/helpers/variant_props.rs new file mode 100644 index 00000000..3b4d201f --- /dev/null +++ b/sea-orm-macros/src/strum/helpers/variant_props.rs @@ -0,0 +1,102 @@ +use std::default::Default; +use syn::{Ident, LitStr, Variant}; + +use super::metadata::{kw, VariantExt, VariantMeta}; +use super::occurrence_error; + +pub trait HasStrumVariantProperties { + fn get_variant_properties(&self) -> syn::Result; +} + +#[derive(Clone, Eq, PartialEq, Debug, Default)] +pub struct StrumVariantProperties { + pub disabled: Option, + pub default: Option, + pub ascii_case_insensitive: Option, + pub message: Option, + pub detailed_message: Option, + pub documentation: Vec, + pub string_props: Vec<(LitStr, LitStr)>, + serialize: Vec, + to_string: Option, + ident: Option, +} + +impl HasStrumVariantProperties for Variant { + fn get_variant_properties(&self) -> syn::Result { + let mut output = StrumVariantProperties { + ident: Some(self.ident.clone()), + ..Default::default() + }; + + let mut message_kw = None; + let mut detailed_message_kw = None; + let mut to_string_kw = None; + let mut disabled_kw = None; + let mut default_kw = None; + let mut ascii_case_insensitive_kw = None; + for meta in self.get_metadata()? { + match meta { + VariantMeta::Message { value, kw } => { + if let Some(fst_kw) = message_kw { + return Err(occurrence_error(fst_kw, kw, "message")); + } + + message_kw = Some(kw); + output.message = Some(value); + } + VariantMeta::DetailedMessage { value, kw } => { + if let Some(fst_kw) = detailed_message_kw { + return Err(occurrence_error(fst_kw, kw, "detailed_message")); + } + + detailed_message_kw = Some(kw); + output.detailed_message = Some(value); + } + VariantMeta::Documentation { value } => { + output.documentation.push(value); + } + VariantMeta::Serialize { value, .. } => { + output.serialize.push(value); + } + VariantMeta::ToString { value, kw } => { + if let Some(fst_kw) = to_string_kw { + return Err(occurrence_error(fst_kw, kw, "to_string")); + } + + to_string_kw = Some(kw); + output.to_string = Some(value); + } + VariantMeta::Disabled(kw) => { + if let Some(fst_kw) = disabled_kw { + return Err(occurrence_error(fst_kw, kw, "disabled")); + } + + disabled_kw = Some(kw); + output.disabled = Some(kw); + } + VariantMeta::Default(kw) => { + if let Some(fst_kw) = default_kw { + return Err(occurrence_error(fst_kw, kw, "default")); + } + + default_kw = Some(kw); + output.default = Some(kw); + } + VariantMeta::AsciiCaseInsensitive { kw, value } => { + if let Some(fst_kw) = ascii_case_insensitive_kw { + return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive")); + } + + ascii_case_insensitive_kw = Some(kw); + output.ascii_case_insensitive = Some(value); + } + VariantMeta::Props { props, .. } => { + output.string_props.extend(props); + } + } + } + + Ok(output) + } +} diff --git a/sea-orm-macros/src/strum/mod.rs b/sea-orm-macros/src/strum/mod.rs new file mode 100644 index 00000000..5f9c7072 --- /dev/null +++ b/sea-orm-macros/src/strum/mod.rs @@ -0,0 +1,4 @@ +//! Source code adapted from https://github.com/Peternator7/strum + +pub mod enum_iter; +pub mod helpers; diff --git a/src/lib.rs b/src/lib.rs index 510c5f4a..37da8aa5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -359,5 +359,5 @@ pub use sea_query::Iden; #[cfg(feature = "macros")] pub use sea_query::Iden as DeriveIden; +pub use sea_orm_macros::EnumIter; pub use strum; -pub use strum::EnumIter;