Drop the use of sea-strum and depend on the original strum with a tailored EnumIter provided (#1535)

* Update heck dependency

* Fix formatter error

* Drop the use of `sea-strum` and depend on the original `strum` with a tailored `EnumIter` provided

* fmt

* Depends on `strum` 0.23

* Depends on `strum` 0.24

* Source code adapted from https://github.com/Peternator7/strum

* Update LICENSE

---------

Co-authored-by: Sergei Ivankov <sergeiivankov@pm.me>
Co-authored-by: Sergei Ivankov <96142843+sergeiivankov@users.noreply.github.com>
Co-authored-by: Chris Tsang <chris.2y3@outlook.com>
This commit is contained in:
Billy Chan 2023-03-22 11:47:15 +08:00 committed by GitHub
parent 162303cd0d
commit 2eda8aa3f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 918 additions and 13 deletions

View File

@ -33,10 +33,10 @@ log = { version = "0.4", default-features = false }
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] } tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
rust_decimal = { version = "1", default-features = false, optional = true } rust_decimal = { version = "1", default-features = false, optional = true }
bigdecimal = { version = "0.3", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true }
sea-orm-macros = { version = "0.12.0", path = "sea-orm-macros", default-features = false, optional = true } sea-orm-macros = { version = "0.12.0", path = "sea-orm-macros", default-features = false, features = ["strum"] }
sea-query = { version = "0.28.3", features = ["thread-safe"] } sea-query = { version = "0.28.3", features = ["thread-safe"] }
sea-query-binder = { version = "0.3", default-features = false, optional = true } sea-query-binder = { version = "0.3", default-features = false, optional = true }
sea-strum = { version = "0.23", default-features = false, features = ["derive", "sea-orm"] } strum = { version = "0.24", default-features = false }
serde = { version = "1.0", default-features = false } serde = { version = "1.0", default-features = false }
serde_json = { version = "1.0", default-features = false, optional = true } serde_json = { version = "1.0", default-features = false, optional = true }
sqlx = { version = "0.6", default-features = false, optional = true } sqlx = { version = "0.6", default-features = false, optional = true }
@ -72,7 +72,7 @@ default = [
"with-uuid", "with-uuid",
"with-time", "with-time",
] ]
macros = ["sea-orm-macros", "sea-query/derive"] macros = ["sea-orm-macros/derive", "sea-query/derive"]
mock = [] mock = []
with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "time?/serde", "uuid?/serde", "sea-query-binder?/with-json", "sqlx?/json"] with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "time?/serde", "uuid?/serde", "sea-query-binder?/with-json", "sqlx?/json"]
with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"] with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"]
@ -80,7 +80,7 @@ with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-b
with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"] with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"]
with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"] with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"]
with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"] with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"]
postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros?/postgres-array"] postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros/postgres-array"]
sea-orm-internal = [] sea-orm-internal = []
sqlx-dep = [] sqlx-dep = []
sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"] sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"]

View File

@ -18,7 +18,7 @@ path = "src/lib.rs"
proc-macro = true proc-macro = true
[dependencies] [dependencies]
bae = { version = "0.1", default-features = false } bae = { version = "0.1", default-features = false, optional = true }
syn = { version = "1", default-features = false, features = ["parsing", "proc-macro", "derive", "printing"] } syn = { version = "1", default-features = false, features = ["parsing", "proc-macro", "derive", "printing"] }
quote = { version = "1", default-features = false } quote = { version = "1", default-features = false }
heck = { version = "0.4", default-features = false } heck = { version = "0.4", default-features = false }
@ -29,4 +29,7 @@ sea-orm = { path = "../", features = ["macros"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
[features] [features]
default = ["derive"]
postgres-array = [] postgres-array = []
derive = ["bae"]
strum = []

View File

@ -1,4 +1,4 @@
use crate::util::{ use super::util::{
escape_rust_keyword, field_not_ignored, format_field_ident, trim_starting_raw_identifier, escape_rust_keyword, field_not_ignored, format_field_ident, trim_starting_raw_identifier,
}; };
use heck::ToUpperCamelCase; use heck::ToUpperCamelCase;

View File

@ -3,7 +3,7 @@ use std::iter::FromIterator;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use crate::attributes::derive_attr; use super::attributes::derive_attr;
struct DeriveEntity { struct DeriveEntity {
column_ident: syn::Ident, column_ident: syn::Ident,

View File

@ -1,4 +1,4 @@
use crate::util::{escape_rust_keyword, trim_starting_raw_identifier}; use super::util::{escape_rust_keyword, trim_starting_raw_identifier};
use heck::{ToSnakeCase, ToUpperCamelCase}; use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned}; use quote::{quote, quote_spanned};

View File

@ -1,6 +1,7 @@
mod active_enum; mod active_enum;
mod active_model; mod active_model;
mod active_model_behavior; mod active_model_behavior;
mod attributes;
mod column; mod column;
mod entity; mod entity;
mod entity_model; mod entity_model;
@ -11,6 +12,7 @@ mod model;
mod primary_key; mod primary_key;
mod relation; mod relation;
mod try_getable_from_json; mod try_getable_from_json;
mod util;
pub use active_enum::*; pub use active_enum::*;
pub use active_model::*; pub use active_model::*;

View File

@ -1,4 +1,4 @@
use crate::{ use super::{
attributes::derive_attr, attributes::derive_attr,
util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier}, util::{escape_rust_keyword, field_not_ignored, trim_starting_raw_identifier},
}; };

View File

@ -1,7 +1,7 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
use crate::attributes::{derive_attr, field_attr}; use super::attributes::{derive_attr, field_attr};
enum Error { enum Error {
InputNotEnum, InputNotEnum,

View File

@ -3,9 +3,11 @@ extern crate proc_macro;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput, Error}; use syn::{parse_macro_input, DeriveInput, Error};
mod attributes; #[cfg(feature = "derive")]
mod derives; mod derives;
mod util;
#[cfg(feature = "strum")]
mod strum;
/// Create an Entity /// Create an Entity
/// ///
@ -70,6 +72,7 @@ mod util;
/// # /// #
/// # impl ActiveModelBehavior for ActiveModel {} /// # impl ActiveModelBehavior for ActiveModel {}
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveEntity, attributes(sea_orm))] #[proc_macro_derive(DeriveEntity, attributes(sea_orm))]
pub fn derive_entity(input: TokenStream) -> TokenStream { pub fn derive_entity(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -129,6 +132,7 @@ pub fn derive_entity(input: TokenStream) -> TokenStream {
/// # /// #
/// # impl ActiveModelBehavior for ActiveModel {} /// # impl ActiveModelBehavior for ActiveModel {}
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveEntityModel, attributes(sea_orm))] #[proc_macro_derive(DeriveEntityModel, attributes(sea_orm))]
pub fn derive_entity_model(input: TokenStream) -> TokenStream { pub fn derive_entity_model(input: TokenStream) -> TokenStream {
let input_ts = input.clone(); let input_ts = input.clone();
@ -216,6 +220,7 @@ pub fn derive_entity_model(input: TokenStream) -> TokenStream {
/// # /// #
/// # impl ActiveModelBehavior for ActiveModel {} /// # impl ActiveModelBehavior for ActiveModel {}
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DerivePrimaryKey, attributes(sea_orm))] #[proc_macro_derive(DerivePrimaryKey, attributes(sea_orm))]
pub fn derive_primary_key(input: TokenStream) -> TokenStream { pub fn derive_primary_key(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -241,6 +246,7 @@ pub fn derive_primary_key(input: TokenStream) -> TokenStream {
/// FillingId, /// FillingId,
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveColumn, attributes(sea_orm))] #[proc_macro_derive(DeriveColumn, attributes(sea_orm))]
pub fn derive_column(input: TokenStream) -> TokenStream { pub fn derive_column(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -274,6 +280,7 @@ pub fn derive_column(input: TokenStream) -> TokenStream {
/// } /// }
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveCustomColumn)] #[proc_macro_derive(DeriveCustomColumn)]
pub fn derive_custom_column(input: TokenStream) -> TokenStream { pub fn derive_custom_column(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -349,6 +356,7 @@ pub fn derive_custom_column(input: TokenStream) -> TokenStream {
/// # /// #
/// # impl ActiveModelBehavior for ActiveModel {} /// # impl ActiveModelBehavior for ActiveModel {}
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveModel, attributes(sea_orm))] #[proc_macro_derive(DeriveModel, attributes(sea_orm))]
pub fn derive_model(input: TokenStream) -> TokenStream { pub fn derive_model(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -421,6 +429,7 @@ pub fn derive_model(input: TokenStream) -> TokenStream {
/// # /// #
/// # impl ActiveModelBehavior for ActiveModel {} /// # impl ActiveModelBehavior for ActiveModel {}
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveActiveModel, attributes(sea_orm))] #[proc_macro_derive(DeriveActiveModel, attributes(sea_orm))]
pub fn derive_active_model(input: TokenStream) -> TokenStream { pub fn derive_active_model(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -432,6 +441,7 @@ pub fn derive_active_model(input: TokenStream) -> TokenStream {
} }
/// Derive into an active model /// Derive into an active model
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveIntoActiveModel, attributes(sea_orm))] #[proc_macro_derive(DeriveIntoActiveModel, attributes(sea_orm))]
pub fn derive_into_active_model(input: TokenStream) -> TokenStream { pub fn derive_into_active_model(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -503,6 +513,7 @@ pub fn derive_into_active_model(input: TokenStream) -> TokenStream {
/// # } /// # }
/// # } /// # }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveActiveModelBehavior)] #[proc_macro_derive(DeriveActiveModelBehavior)]
pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream { pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -552,6 +563,7 @@ pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream {
/// White = 1, /// White = 1,
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveActiveEnum, attributes(sea_orm))] #[proc_macro_derive(DeriveActiveEnum, attributes(sea_orm))]
pub fn derive_active_enum(input: TokenStream) -> TokenStream { pub fn derive_active_enum(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -574,6 +586,7 @@ pub fn derive_active_enum(input: TokenStream) -> TokenStream {
/// num_of_fruits: i32, /// num_of_fruits: i32,
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(FromQueryResult)] #[proc_macro_derive(FromQueryResult)]
pub fn derive_from_query_result(input: TokenStream) -> TokenStream { pub fn derive_from_query_result(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let DeriveInput { ident, data, .. } = parse_macro_input!(input);
@ -608,6 +621,7 @@ pub fn derive_from_query_result(input: TokenStream) -> TokenStream {
/// CakeExpanded, /// CakeExpanded,
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveRelation, attributes(sea_orm))] #[proc_macro_derive(DeriveRelation, attributes(sea_orm))]
pub fn derive_relation(input: TokenStream) -> TokenStream { pub fn derive_relation(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -635,6 +649,7 @@ pub fn derive_relation(input: TokenStream) -> TokenStream {
/// } /// }
/// } /// }
/// ``` /// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveMigrationName)] #[proc_macro_derive(DeriveMigrationName)]
pub fn derive_migration_name(input: TokenStream) -> TokenStream { pub fn derive_migration_name(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
@ -643,6 +658,7 @@ pub fn derive_migration_name(input: TokenStream) -> TokenStream {
.into() .into()
} }
#[cfg(feature = "derive")]
#[proc_macro_derive(FromJsonQueryResult)] #[proc_macro_derive(FromJsonQueryResult)]
pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream { pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream {
let DeriveInput { ident, .. } = parse_macro_input!(input); let DeriveInput { ident, .. } = parse_macro_input!(input);
@ -654,6 +670,7 @@ pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream {
} }
#[doc(hidden)] #[doc(hidden)]
#[cfg(feature = "derive")]
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn test(_: TokenStream, input: TokenStream) -> TokenStream { pub fn test(_: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::ItemFn); let input = parse_macro_input!(input as syn::ItemFn);
@ -676,3 +693,19 @@ pub fn test(_: TokenStream, input: TokenStream) -> TokenStream {
) )
.into() .into()
} }
/// Creates a new type that iterates over the variants of an enum.
///
/// Iterate over the variants of an Enum. Any additional data on your variants will be set to `Default::default()`.
/// The macro implements `strum::IntoEnumIterator` on your enum and creates a new type called `YourEnumIter` that is the iterator object.
/// You cannot derive `EnumIter` on any type with a lifetime bound (`<'a>`) because the iterator would surely
/// create [unbounded lifetimes](https://doc.rust-lang.org/nightly/nomicon/unbounded-lifetimes.html).
#[cfg(feature = "strum")]
#[proc_macro_derive(EnumIter, attributes(strum))]
pub fn enum_iter(input: TokenStream) -> TokenStream {
    // Parse the annotated item and delegate code generation to the vendored
    // strum implementation; any error is emitted as a compile error.
    let ast = parse_macro_input!(input as DeriveInput);
    strum::enum_iter::enum_iter_inner(&ast)
        .unwrap_or_else(Error::into_compile_error)
        .into()
}

View File

@ -0,0 +1,23 @@
> The `strum` module is adapted from https://github.com/Peternator7/strum
MIT License
Copyright (c) 2019 Peter Glotfelty
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,172 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident};
use super::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
/// Expands `#[derive(EnumIter)]` for the given input.
///
/// Generates a `<Name>Iter` struct plus impls of `Iterator`,
/// `DoubleEndedIterator`, `ExactSizeIterator`, `Clone` and `Debug` for it, and
/// an impl of `IntoEnumIterator` (resolved through the strum crate path from
/// the type's `#[strum(...)]` attributes) for the enum itself.
///
/// Variants marked `#[strum(disabled)]` are skipped; fields of the remaining
/// variants are filled with `Default::default()` when iterated.
///
/// # Errors
/// Returns a `syn::Error` when the input is not an enum, or when it has
/// lifetime parameters (the generated iterator could not bound them).
pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
    let name = &ast.ident;
    let gen = &ast.generics;
    let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
    // The iterator struct is emitted with the same visibility as the enum.
    let vis = &ast.vis;
    let type_properties = ast.get_type_properties()?;
    // Path used to reference the strum crate in generated code
    // (overridable via `#[strum(crate = "...")]`).
    let strum_module_path = type_properties.crate_module_path();
    let doc_comment = format!("An iterator over the variants of [{}]", name);
    // Lifetime parameters are rejected outright: the iterator yields owned
    // values built from `Default::default()`, so borrowed data can't exist.
    if gen.lifetimes().count() > 0 {
        return Err(syn::Error::new(
            Span::call_site(),
            "This macro doesn't support enums with lifetimes. \
The resulting enums would be unbounded.",
        ));
    }
    // The iterator struct carries a PhantomData over the enum's type
    // parameters so the generic parameters stay used.
    let phantom_data = if gen.type_params().count() > 0 {
        let g = gen.type_params().map(|param| &param.ident);
        quote! { < ( #(#g),* ) > }
    } else {
        quote! { < () > }
    };
    let variants = match &ast.data {
        Data::Enum(v) => &v.variants,
        _ => return Err(non_enum_error()),
    };
    // Build one `idx => Some(Variant)` match arm per enabled variant.
    // `idx` only advances for enabled variants, so indices stay dense.
    let mut arms = Vec::new();
    let mut idx = 0usize;
    for variant in variants {
        if variant.get_variant_properties()?.disabled.is_some() {
            continue;
        }
        let ident = &variant.ident;
        // Constructor arguments for the variant: nothing for unit variants,
        // `Default::default()` for every tuple or named field.
        let params = match &variant.fields {
            Fields::Unit => quote! {},
            Fields::Unnamed(fields) => {
                let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
                    .take(fields.unnamed.len());
                quote! { (#(#defaults),*) }
            }
            Fields::Named(fields) => {
                let fields = fields
                    .named
                    .iter()
                    .map(|field| field.ident.as_ref().unwrap());
                quote! { {#(#fields: ::core::default::Default::default()),*} }
            }
        };
        arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
        idx += 1;
    }
    let variant_count = arms.len();
    // Out-of-range indices map to None (also covers the all-disabled case).
    arms.push(quote! { _ => ::core::option::Option::None });
    let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
    // Create a string literal "MyEnumIter" to use in the debug impl.
    let iter_name_debug_struct =
        syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
    Ok(quote! {
        #[doc = #doc_comment]
        #[allow(
            missing_copy_implementations,
        )]
        #vis struct #iter_name #ty_generics {
            idx: usize,
            back_idx: usize,
            marker: ::core::marker::PhantomData #phantom_data,
        }
        impl #impl_generics core::fmt::Debug for #iter_name #ty_generics #where_clause {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // We don't know if the variants implement debug themselves so the only thing we
                // can really show is how many elements are left.
                f.debug_struct(#iter_name_debug_struct)
                    .field("len", &self.len())
                    .finish()
            }
        }
        impl #impl_generics #iter_name #ty_generics #where_clause {
            fn get(&self, idx: usize) -> Option<#name #ty_generics> {
                match idx {
                    #(#arms),*
                }
            }
        }
        impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
            type Iterator = #iter_name #ty_generics;
            fn iter() -> #iter_name #ty_generics {
                #iter_name {
                    idx: 0,
                    back_idx: 0,
                    marker: ::core::marker::PhantomData,
                }
            }
        }
        impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
            type Item = #name #ty_generics;
            fn next(&mut self) -> Option<<Self as Iterator>::Item> {
                self.nth(0)
            }
            fn size_hint(&self) -> (usize, Option<usize>) {
                let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
                (t, Some(t))
            }
            fn nth(&mut self, n: usize) -> Option<<Self as Iterator>::Item> {
                let idx = self.idx + n + 1;
                if idx + self.back_idx > #variant_count {
                    // We went past the end of the iterator. Freeze idx at #variant_count
                    // so that it doesn't overflow if the user calls this repeatedly.
                    // See PR #76 for context.
                    self.idx = #variant_count;
                    ::core::option::Option::None
                } else {
                    self.idx = idx;
                    self.get(idx - 1)
                }
            }
        }
        impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
            fn len(&self) -> usize {
                self.size_hint().0
            }
        }
        impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
            fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
                let back_idx = self.back_idx + 1;
                if self.idx + back_idx > #variant_count {
                    // We went past the end of the iterator. Freeze back_idx at #variant_count
                    // so that it doesn't overflow if the user calls this repeatedly.
                    // See PR #76 for context.
                    self.back_idx = #variant_count;
                    ::core::option::Option::None
                } else {
                    self.back_idx = back_idx;
                    self.get(#variant_count - self.back_idx)
                }
            }
        }
        impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
            fn clone(&self) -> #iter_name #ty_generics {
                #iter_name {
                    idx: self.idx,
                    back_idx: self.back_idx,
                    marker: self.marker.clone(),
                }
            }
        }
    })
}

View File

@ -0,0 +1,117 @@
use heck::{
ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToTitleCase, ToUpperCamelCase,
};
use std::str::FromStr;
use syn::{
parse::{Parse, ParseStream},
Ident, LitStr,
};
/// The case conventions accepted by `#[strum(serialize_all = "...")]`.
///
/// NOTE(review): here `CamelCase` means lowerCamelCase and `PascalCase` means
/// UpperCamelCase — see the `FromStr` impl and `convert_case` for the mapping.
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CaseStyle {
    /// `camelCase` (leading character lowercased).
    CamelCase,
    /// `kebab-case`.
    KebabCase,
    /// `mixedCase` (heck's lower camel case).
    MixedCase,
    /// `SCREAMING_SNAKE_CASE`.
    ShoutySnakeCase,
    /// `snake_case`.
    SnakeCase,
    /// `Title Case`.
    TitleCase,
    /// `UPPERCASE` (plain uppercasing, no word splitting).
    UpperCase,
    /// `lowercase` (plain lowercasing, no word splitting).
    LowerCase,
    /// `SCREAMING-KEBAB-CASE`.
    ScreamingKebabCase,
    /// `PascalCase` (upper camel case).
    PascalCase,
}
/// Spellings suggested in the error message when an unknown case style is
/// passed to `serialize_all`. NOTE(review): this list is advisory only — the
/// authoritative set of accepted spellings is the `FromStr` impl below, and
/// the two do not match exactly (e.g. `snek_case` parses but is not listed).
const VALID_CASE_STYLES: &[&str] = &[
    "camelCase",
    "PascalCase",
    "kebab-case",
    "snake_case",
    "SCREAMING_SNAKE_CASE",
    "SCREAMING-KEBAB-CASE",
    "lowercase",
    "UPPERCASE",
    "title_case",
    "mixed_case",
];
impl Parse for CaseStyle {
    /// Parses a case style from a string literal (e.g. `"snake_case"`),
    /// reporting the accepted spellings on failure.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let text = input.parse::<LitStr>()?;
        let val = text.value();
        // Delegate the actual mapping to the FromStr impl; on failure, attach
        // the error to the literal's span so the user sees the bad value.
        val.as_str().parse().map_err(|_| {
            syn::Error::new_spanned(
                &text,
                format!(
                    "Unexpected case style for serialize_all: `{}`. Valid values are: `{:?}`",
                    val, VALID_CASE_STYLES
                ),
            )
        })
    }
}
impl FromStr for CaseStyle {
    type Err = ();
    /// Maps an attribute spelling to a `CaseStyle`.
    ///
    /// Mirrors upstream strum's quirks deliberately: `"camel_case"` (and
    /// `"PascalCase"`) select `PascalCase`, while `"camelCase"` selects
    /// `CamelCase`; `"snek_case"` is accepted as an alias for snake case.
    fn from_str(text: &str) -> Result<Self, ()> {
        Ok(match text {
            "camel_case" | "PascalCase" => CaseStyle::PascalCase,
            "camelCase" => CaseStyle::CamelCase,
            "snake_case" | "snek_case" => CaseStyle::SnakeCase,
            "kebab_case" | "kebab-case" => CaseStyle::KebabCase,
            "SCREAMING-KEBAB-CASE" => CaseStyle::ScreamingKebabCase,
            "shouty_snake_case" | "shouty_snek_case" | "SCREAMING_SNAKE_CASE" => {
                CaseStyle::ShoutySnakeCase
            }
            "title_case" => CaseStyle::TitleCase,
            "mixed_case" => CaseStyle::MixedCase,
            "lowercase" => CaseStyle::LowerCase,
            "UPPERCASE" => CaseStyle::UpperCase,
            _ => return Err(()),
        })
    }
}
/// Helper for rendering identifiers under a given case style.
pub trait CaseStyleHelpers {
    /// Converts the identifier's text using `case_style`, or returns it
    /// unchanged when `None`.
    fn convert_case(&self, case_style: Option<CaseStyle>) -> String;
}
impl CaseStyleHelpers for Ident {
    /// Renders this identifier according to `case_style`; with no style the
    /// identifier text is returned verbatim.
    fn convert_case(&self, case_style: Option<CaseStyle>) -> String {
        let text = self.to_string();
        let style = match case_style {
            Some(style) => style,
            None => return text,
        };
        match style {
            CaseStyle::PascalCase => text.to_upper_camel_case(),
            CaseStyle::KebabCase => text.to_kebab_case(),
            CaseStyle::MixedCase => text.to_lower_camel_case(),
            CaseStyle::ShoutySnakeCase => text.to_shouty_snake_case(),
            CaseStyle::SnakeCase => text.to_snake_case(),
            CaseStyle::TitleCase => text.to_title_case(),
            CaseStyle::UpperCase => text.to_uppercase(),
            CaseStyle::LowerCase => text.to_lowercase(),
            CaseStyle::ScreamingKebabCase => text.to_kebab_case().to_uppercase(),
            CaseStyle::CamelCase => {
                // lowerCamelCase: produce UpperCamelCase first, then lowercase
                // just the leading character.
                let upper = text.to_upper_camel_case();
                let mut lower_camel = String::with_capacity(upper.len());
                let mut chars = upper.chars();
                if let Some(first) = chars.next() {
                    lower_camel.extend(first.to_lowercase());
                }
                lower_camel.extend(chars);
                lower_camel
            }
        }
    }
}
// Sanity check of the two camel-case flavors: CamelCase lowercases the first
// word, PascalCase capitalizes it.
#[test]
fn test_convert_case() {
    let id = Ident::new("test_me", proc_macro2::Span::call_site());
    assert_eq!("testMe", id.convert_case(Some(CaseStyle::CamelCase)));
    assert_eq!("TestMe", id.convert_case(Some(CaseStyle::PascalCase)));
}

View File

@ -0,0 +1,309 @@
use proc_macro2::{Span, TokenStream};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse2, parse_str,
punctuated::Punctuated,
spanned::Spanned,
Attribute, DeriveInput, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token, Variant,
Visibility,
};
use super::case_style::CaseStyle;
/// Custom keywords recognized inside `#[strum(...)]` and
/// `#[strum_discriminants(...)]` attribute argument lists.
pub mod kw {
    use syn::custom_keyword;
    // `crate` is a real Rust keyword, so the existing token type is reused
    // instead of declaring a custom keyword for it.
    pub use syn::token::Crate;
    // enum metadata
    custom_keyword!(serialize_all);
    custom_keyword!(use_phf);
    // enum discriminant metadata
    custom_keyword!(derive);
    custom_keyword!(name);
    custom_keyword!(vis);
    // variant metadata
    custom_keyword!(message);
    custom_keyword!(detailed_message);
    custom_keyword!(serialize);
    custom_keyword!(to_string);
    custom_keyword!(disabled);
    custom_keyword!(default);
    custom_keyword!(props);
}
/// One parsed argument of a type-level `#[strum(...)]` attribute.
/// Each variant keeps its keyword token so errors can point at it.
pub enum EnumMeta {
    /// `serialize_all = "<case style>"`
    SerializeAll {
        kw: kw::serialize_all,
        case_style: CaseStyle,
    },
    /// `ascii_case_insensitive`
    AsciiCaseInsensitive(kw::ascii_case_insensitive),
    /// `crate = "<path>"` — overrides the path used to refer to the strum crate.
    Crate {
        kw: kw::Crate,
        crate_module_path: Path,
    },
    /// `use_phf`
    UsePhf(kw::use_phf),
}
impl Parse for EnumMeta {
    /// Parses a single enum-level strum argument, dispatching on the leading
    /// keyword. An unrecognized keyword yields the lookahead's "expected one
    /// of ..." error.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(kw::serialize_all) {
            // serialize_all = "<case style>"
            let kw = input.parse::<kw::serialize_all>()?;
            input.parse::<Token![=]>()?;
            let case_style = input.parse()?;
            Ok(EnumMeta::SerializeAll { kw, case_style })
        } else if lookahead.peek(kw::Crate) {
            // crate = "<path>" — the path arrives as a string literal and is
            // re-parsed into a syn::Path.
            let kw = input.parse::<kw::Crate>()?;
            input.parse::<Token![=]>()?;
            let path_str: LitStr = input.parse()?;
            let path_tokens = parse_str(&path_str.value())?;
            let crate_module_path = parse2(path_tokens)?;
            Ok(EnumMeta::Crate {
                kw,
                crate_module_path,
            })
        } else if lookahead.peek(kw::ascii_case_insensitive) {
            Ok(EnumMeta::AsciiCaseInsensitive(input.parse()?))
        } else if lookahead.peek(kw::use_phf) {
            Ok(EnumMeta::UsePhf(input.parse()?))
        } else {
            Err(lookahead.error())
        }
    }
}
impl Spanned for EnumMeta {
    /// Returns the span of the argument's keyword, for error reporting.
    fn span(&self) -> Span {
        match self {
            EnumMeta::SerializeAll { kw, .. } => kw.span(),
            EnumMeta::AsciiCaseInsensitive(kw) => kw.span(),
            EnumMeta::Crate { kw, .. } => kw.span(),
            EnumMeta::UsePhf(use_phf) => use_phf.span(),
        }
    }
}
/// One parsed argument of a `#[strum_discriminants(...)]` attribute.
pub enum EnumDiscriminantsMeta {
    /// `derive(Path, ...)` — derives to forward to the discriminants enum.
    Derive { kw: kw::derive, paths: Vec<Path> },
    /// `name(Ident)` — custom name for the discriminants enum.
    Name { kw: kw::name, name: Ident },
    /// `vis(Visibility)` — visibility of the discriminants enum.
    Vis { kw: kw::vis, vis: Visibility },
    /// Any other `path(...)` argument, kept as raw tokens.
    Other { path: Path, nested: TokenStream },
}
impl Parse for EnumDiscriminantsMeta {
    /// Parses one discriminants argument. Every form is parenthesized
    /// (`kw(...)`); unknown paths fall through to `Other` with their
    /// arguments captured verbatim, so arbitrary attributes can be forwarded.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        if input.peek(kw::derive) {
            let kw = input.parse()?;
            let content;
            parenthesized!(content in input);
            let paths = content.parse_terminated::<_, Token![,]>(Path::parse)?;
            Ok(EnumDiscriminantsMeta::Derive {
                kw,
                paths: paths.into_iter().collect(),
            })
        } else if input.peek(kw::name) {
            let kw = input.parse()?;
            let content;
            parenthesized!(content in input);
            let name = content.parse()?;
            Ok(EnumDiscriminantsMeta::Name { kw, name })
        } else if input.peek(kw::vis) {
            let kw = input.parse()?;
            let content;
            parenthesized!(content in input);
            let vis = content.parse()?;
            Ok(EnumDiscriminantsMeta::Vis { kw, vis })
        } else {
            // Catch-all: any other path with parenthesized arguments.
            let path = input.parse()?;
            let content;
            parenthesized!(content in input);
            let nested = content.parse()?;
            Ok(EnumDiscriminantsMeta::Other { path, nested })
        }
    }
}
impl Spanned for EnumDiscriminantsMeta {
    /// Returns the span of the argument's keyword (or path), for error
    /// reporting.
    fn span(&self) -> Span {
        match self {
            EnumDiscriminantsMeta::Derive { kw, .. } => kw.span,
            EnumDiscriminantsMeta::Name { kw, .. } => kw.span,
            EnumDiscriminantsMeta::Vis { kw, .. } => kw.span,
            EnumDiscriminantsMeta::Other { path, .. } => path.span(),
        }
    }
}
/// Extension methods for extracting strum attribute metadata from a type.
pub trait DeriveInputExt {
    /// Get all the strum metadata associated with an enum.
    fn get_metadata(&self) -> syn::Result<Vec<EnumMeta>>;
    /// Get all the `strum_discriminants` metadata associated with an enum.
    fn get_discriminants_metadata(&self) -> syn::Result<Vec<EnumDiscriminantsMeta>>;
}
impl DeriveInputExt for DeriveInput {
    // Collects arguments from every `#[strum(...)]` attribute on the type.
    fn get_metadata(&self) -> syn::Result<Vec<EnumMeta>> {
        get_metadata_inner("strum", &self.attrs)
    }
    // Collects arguments from every `#[strum_discriminants(...)]` attribute.
    fn get_discriminants_metadata(&self) -> syn::Result<Vec<EnumDiscriminantsMeta>> {
        get_metadata_inner("strum_discriminants", &self.attrs)
    }
}
/// One parsed argument of a variant-level `#[strum(...)]` attribute,
/// plus `Documentation` entries synthesized from `#[doc]` comments.
pub enum VariantMeta {
    /// `message = "..."`
    Message {
        kw: kw::message,
        value: LitStr,
    },
    /// `detailed_message = "..."`
    DetailedMessage {
        kw: kw::detailed_message,
        value: LitStr,
    },
    /// `serialize = "..."`
    Serialize {
        kw: kw::serialize,
        value: LitStr,
    },
    /// A `///` doc comment line on the variant (see `VariantExt::get_metadata`).
    Documentation {
        value: LitStr,
    },
    /// `to_string = "..."`
    ToString {
        kw: kw::to_string,
        value: LitStr,
    },
    /// `disabled` — the variant is skipped by the generated code.
    Disabled(kw::disabled),
    /// `default`
    Default(kw::default),
    /// `ascii_case_insensitive` or `ascii_case_insensitive = <bool>`
    AsciiCaseInsensitive {
        kw: kw::ascii_case_insensitive,
        value: bool,
    },
    /// `props(key = "value", ...)` — arbitrary key/value string properties.
    Props {
        kw: kw::props,
        props: Vec<(LitStr, LitStr)>,
    },
}
impl Parse for VariantMeta {
    /// Parses a single variant-level strum argument, dispatching on the
    /// leading keyword. `Documentation` is never produced here — it is
    /// synthesized from `#[doc]` attributes in `VariantExt::get_metadata`.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(kw::message) {
            let kw = input.parse()?;
            let _: Token![=] = input.parse()?;
            let value = input.parse()?;
            Ok(VariantMeta::Message { kw, value })
        } else if lookahead.peek(kw::detailed_message) {
            let kw = input.parse()?;
            let _: Token![=] = input.parse()?;
            let value = input.parse()?;
            Ok(VariantMeta::DetailedMessage { kw, value })
        } else if lookahead.peek(kw::serialize) {
            let kw = input.parse()?;
            let _: Token![=] = input.parse()?;
            let value = input.parse()?;
            Ok(VariantMeta::Serialize { kw, value })
        } else if lookahead.peek(kw::to_string) {
            let kw = input.parse()?;
            let _: Token![=] = input.parse()?;
            let value = input.parse()?;
            Ok(VariantMeta::ToString { kw, value })
        } else if lookahead.peek(kw::disabled) {
            Ok(VariantMeta::Disabled(input.parse()?))
        } else if lookahead.peek(kw::default) {
            Ok(VariantMeta::Default(input.parse()?))
        } else if lookahead.peek(kw::ascii_case_insensitive) {
            let kw = input.parse()?;
            // The `= <bool>` part is optional; a bare keyword means `true`.
            let value = if input.peek(Token![=]) {
                let _: Token![=] = input.parse()?;
                input.parse::<LitBool>()?.value
            } else {
                true
            };
            Ok(VariantMeta::AsciiCaseInsensitive { kw, value })
        } else if lookahead.peek(kw::props) {
            let kw = input.parse()?;
            let content;
            parenthesized!(content in input);
            let props = content.parse_terminated::<_, Token![,]>(Prop::parse)?;
            Ok(VariantMeta::Props {
                kw,
                props: props
                    .into_iter()
                    // Property keys are identifiers in source but stored as
                    // string literals, keeping the original span.
                    .map(|Prop(k, v)| (LitStr::new(&k.to_string(), k.span()), v))
                    .collect(),
            })
        } else {
            Err(lookahead.error())
        }
    }
}
/// A single `key = "value"` pair inside `props(...)`.
struct Prop(Ident, LitStr);
impl Parse for Prop {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        use syn::ext::IdentExt;
        // `parse_any` accepts keywords as property names (e.g. `type = "..."`).
        let k = Ident::parse_any(input)?;
        let _: Token![=] = input.parse()?;
        let v = input.parse()?;
        Ok(Prop(k, v))
    }
}
impl Spanned for VariantMeta {
    /// Returns the span of the argument's keyword (or, for doc comments, the
    /// literal), for error reporting.
    fn span(&self) -> Span {
        match self {
            VariantMeta::Message { kw, .. } => kw.span,
            VariantMeta::DetailedMessage { kw, .. } => kw.span,
            VariantMeta::Documentation { value } => value.span(),
            VariantMeta::Serialize { kw, .. } => kw.span,
            VariantMeta::ToString { kw, .. } => kw.span,
            VariantMeta::Disabled(kw) => kw.span,
            VariantMeta::Default(kw) => kw.span,
            VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span,
            VariantMeta::Props { kw, .. } => kw.span,
        }
    }
}
/// Extension method for extracting strum metadata from an enum variant.
pub trait VariantExt {
    /// Get all the metadata associated with an enum variant.
    fn get_metadata(&self) -> syn::Result<Vec<VariantMeta>>;
}
impl VariantExt for Variant {
    /// Collect the `#[strum(...)]` metadata on this variant, then append one
    /// `Documentation` entry per `#[doc = "..."]` attribute (rustdoc comment).
    fn get_metadata(&self) -> syn::Result<Vec<VariantMeta>> {
        let mut metadata = get_metadata_inner("strum", &self.attrs)?;
        for attr in self.attrs.iter().filter(|attr| attr.path.is_ident("doc")) {
            // Doc comments lower to `#[doc = "..."]`; keep only string values.
            if let Meta::NameValue(MetaNameValue {
                lit: Lit::Str(value),
                ..
            }) = attr.parse_meta()?
            {
                metadata.push(VariantMeta::Documentation { value });
            }
        }
        Ok(metadata)
    }
}
/// Parse every `#[<ident>(...)]` attribute in `it` as a comma-separated list
/// of `T` items and return them flattened into one vector, propagating the
/// first parse error encountered.
fn get_metadata_inner<'a, T: Parse + Spanned>(
    ident: &str,
    it: impl IntoIterator<Item = &'a Attribute>,
) -> syn::Result<Vec<T>> {
    let mut collected = Vec::new();
    for attr in it.into_iter().filter(|attr| attr.path.is_ident(ident)) {
        let items = attr.parse_args_with(Punctuated::<T, Token![,]>::parse_terminated)?;
        collected.extend(items);
    }
    Ok(collected)
}

View File

@ -0,0 +1,24 @@
pub use self::case_style::CaseStyleHelpers;
pub use self::type_props::HasTypeProperties;
pub use self::variant_props::HasStrumVariantProperties;
pub mod case_style;
mod metadata;
pub mod type_props;
pub mod variant_props;
use proc_macro2::Span;
use quote::ToTokens;
pub fn non_enum_error() -> syn::Error {
syn::Error::new(Span::call_site(), "This macro only supports enums.")
}
pub fn occurrence_error<T: ToTokens>(fst: T, snd: T, attr: &str) -> syn::Error {
let mut e = syn::Error::new_spanned(
snd,
format!("Found multiple occurrences of strum({})", attr),
);
e.combine(syn::Error::new_spanned(fst, "first one here"));
e
}

View File

@ -0,0 +1,116 @@
use proc_macro2::TokenStream;
use quote::quote;
use std::default::Default;
use syn::{parse_quote, DeriveInput, Ident, Path, Visibility};
use super::case_style::CaseStyle;
use super::metadata::{DeriveInputExt, EnumDiscriminantsMeta, EnumMeta};
use super::occurrence_error;
/// Extracts enum-wide strum configuration from a [`DeriveInput`].
pub trait HasTypeProperties {
    /// Parse all type-level strum attributes into a [`StrumTypeProperties`].
    fn get_type_properties(&self) -> syn::Result<StrumTypeProperties>;
}
/// Enum-wide configuration collected from `#[strum(...)]` and
/// `#[strum_discriminants(...)]` attributes on the type.
#[derive(Debug, Clone, Default)]
pub struct StrumTypeProperties {
    /// Case style from `serialize_all`, applied to variant names.
    pub case_style: Option<CaseStyle>,
    /// Set by the type-level `ascii_case_insensitive` keyword.
    pub ascii_case_insensitive: bool,
    /// Override from `crate = ...`; falls back to `sea_orm::strum` when absent.
    pub crate_module_path: Option<Path>,
    /// Derive paths gathered from `strum_discriminants(derive(...))`.
    pub discriminant_derives: Vec<Path>,
    /// Discriminants enum name from `strum_discriminants(name(...))`.
    pub discriminant_name: Option<Ident>,
    /// Any other `strum_discriminants` attributes, re-emitted verbatim.
    pub discriminant_others: Vec<TokenStream>,
    /// Visibility from `strum_discriminants(vis(...))`.
    pub discriminant_vis: Option<Visibility>,
    /// Set by the type-level `use_phf` keyword.
    pub use_phf: bool,
}
impl HasTypeProperties for DeriveInput {
fn get_type_properties(&self) -> syn::Result<StrumTypeProperties> {
let mut output = StrumTypeProperties::default();
let strum_meta = self.get_metadata()?;
let discriminants_meta = self.get_discriminants_metadata()?;
let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
let mut use_phf_kw = None;
let mut crate_module_path_kw = None;
for meta in strum_meta {
match meta {
EnumMeta::SerializeAll { case_style, kw } => {
if let Some(fst_kw) = serialize_all_kw {
return Err(occurrence_error(fst_kw, kw, "serialize_all"));
}
serialize_all_kw = Some(kw);
output.case_style = Some(case_style);
}
EnumMeta::AsciiCaseInsensitive(kw) => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}
ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = true;
}
EnumMeta::UsePhf(kw) => {
if let Some(fst_kw) = use_phf_kw {
return Err(occurrence_error(fst_kw, kw, "use_phf"));
}
use_phf_kw = Some(kw);
output.use_phf = true;
}
EnumMeta::Crate {
crate_module_path,
kw,
} => {
if let Some(fst_kw) = crate_module_path_kw {
return Err(occurrence_error(fst_kw, kw, "Crate"));
}
crate_module_path_kw = Some(kw);
output.crate_module_path = Some(crate_module_path);
}
}
}
let mut name_kw = None;
let mut vis_kw = None;
for meta in discriminants_meta {
match meta {
EnumDiscriminantsMeta::Derive { paths, .. } => {
output.discriminant_derives.extend(paths);
}
EnumDiscriminantsMeta::Name { name, kw } => {
if let Some(fst_kw) = name_kw {
return Err(occurrence_error(fst_kw, kw, "name"));
}
name_kw = Some(kw);
output.discriminant_name = Some(name);
}
EnumDiscriminantsMeta::Vis { vis, kw } => {
if let Some(fst_kw) = vis_kw {
return Err(occurrence_error(fst_kw, kw, "vis"));
}
vis_kw = Some(kw);
output.discriminant_vis = Some(vis);
}
EnumDiscriminantsMeta::Other { path, nested } => {
output.discriminant_others.push(quote! { #path(#nested) });
}
}
}
Ok(output)
}
}
impl StrumTypeProperties {
    /// Path to the strum module used in generated code: the user-supplied
    /// `crate = ...` override when present, otherwise `sea_orm::strum`.
    pub fn crate_module_path(&self) -> Path {
        match &self.crate_module_path {
            Some(path) => parse_quote!(#path),
            None => parse_quote!(sea_orm::strum),
        }
    }
}

View File

@ -0,0 +1,102 @@
use std::default::Default;
use syn::{Ident, LitStr, Variant};
use super::metadata::{kw, VariantExt, VariantMeta};
use super::occurrence_error;
/// Extracts per-variant strum configuration from a [`Variant`].
pub trait HasStrumVariantProperties {
    /// Parse all variant-level strum attributes into a [`StrumVariantProperties`].
    fn get_variant_properties(&self) -> syn::Result<StrumVariantProperties>;
}
/// Per-variant configuration collected from `#[strum(...)]` attributes and
/// rustdoc comments on an enum variant.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct StrumVariantProperties {
    /// Present when the variant carries the `disabled` keyword.
    pub disabled: Option<kw::disabled>,
    /// Present when the variant carries the `default` keyword.
    pub default: Option<kw::default>,
    /// Variant-level `ascii_case_insensitive` override, when given.
    pub ascii_case_insensitive: Option<bool>,
    /// Value of `message = "..."`, at most one per variant.
    pub message: Option<LitStr>,
    /// Value of `detailed_message = "..."`, at most one per variant.
    pub detailed_message: Option<LitStr>,
    /// One entry per `#[doc = "..."]` attribute (rustdoc comment).
    pub documentation: Vec<LitStr>,
    /// Key/value pairs accumulated from `props(...)` lists.
    pub string_props: Vec<(LitStr, LitStr)>,
    // Values of `serialize = "..."`; multiple occurrences accumulate.
    serialize: Vec<LitStr>,
    // Value of `to_string = "..."`, at most one per variant.
    to_string: Option<LitStr>,
    // The variant's identifier, filled in by `get_variant_properties`.
    ident: Option<Ident>,
}
impl HasStrumVariantProperties for Variant {
fn get_variant_properties(&self) -> syn::Result<StrumVariantProperties> {
let mut output = StrumVariantProperties {
ident: Some(self.ident.clone()),
..Default::default()
};
let mut message_kw = None;
let mut detailed_message_kw = None;
let mut to_string_kw = None;
let mut disabled_kw = None;
let mut default_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in self.get_metadata()? {
match meta {
VariantMeta::Message { value, kw } => {
if let Some(fst_kw) = message_kw {
return Err(occurrence_error(fst_kw, kw, "message"));
}
message_kw = Some(kw);
output.message = Some(value);
}
VariantMeta::DetailedMessage { value, kw } => {
if let Some(fst_kw) = detailed_message_kw {
return Err(occurrence_error(fst_kw, kw, "detailed_message"));
}
detailed_message_kw = Some(kw);
output.detailed_message = Some(value);
}
VariantMeta::Documentation { value } => {
output.documentation.push(value);
}
VariantMeta::Serialize { value, .. } => {
output.serialize.push(value);
}
VariantMeta::ToString { value, kw } => {
if let Some(fst_kw) = to_string_kw {
return Err(occurrence_error(fst_kw, kw, "to_string"));
}
to_string_kw = Some(kw);
output.to_string = Some(value);
}
VariantMeta::Disabled(kw) => {
if let Some(fst_kw) = disabled_kw {
return Err(occurrence_error(fst_kw, kw, "disabled"));
}
disabled_kw = Some(kw);
output.disabled = Some(kw);
}
VariantMeta::Default(kw) => {
if let Some(fst_kw) = default_kw {
return Err(occurrence_error(fst_kw, kw, "default"));
}
default_kw = Some(kw);
output.default = Some(kw);
}
VariantMeta::AsciiCaseInsensitive { kw, value } => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}
ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = Some(value);
}
VariantMeta::Props { props, .. } => {
output.string_props.extend(props);
}
}
}
Ok(output)
}
}

View File

@ -0,0 +1,4 @@
//! Tailored `EnumIter` support for SeaORM.
//!
//! Source code adapted from <https://github.com/Peternator7/strum>
pub mod enum_iter;
pub mod helpers;

View File

@ -359,5 +359,5 @@ pub use sea_query::Iden;
#[cfg(feature = "macros")] #[cfg(feature = "macros")]
pub use sea_query::Iden as DeriveIden; pub use sea_query::Iden as DeriveIden;
pub use sea_orm_macros::EnumIter;
pub use strum; pub use strum;
pub use strum::EnumIter;