diff --git a/sea-orm-macros/Cargo.toml b/sea-orm-macros/Cargo.toml index 312adeaa..64dcd0af 100644 --- a/sea-orm-macros/Cargo.toml +++ b/sea-orm-macros/Cargo.toml @@ -20,3 +20,4 @@ syn = { version = "^1", default-features = false, features = [ "full", "derive", quote = "^1" heck = "^0.3" proc-macro2 = "^1" +convert_case = "0.4" diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index a2217aa8..61551173 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -1,7 +1,12 @@ extern crate proc_macro; use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput}; + +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{Data, DeriveInput, Fields, Lit, Meta, parse_macro_input, punctuated::Punctuated, token::Comma}; + +use convert_case::{Case, Casing}; mod derives; @@ -104,3 +109,172 @@ pub fn test(_: TokenStream, input: TokenStream) -> TokenStream { ) .into() } + +#[proc_macro_derive(EntityModel, attributes(sea_orm))] +pub fn derive_entity_model(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + if input.ident != "Model" { + panic!("Struct name must be Model"); + } + + // if #[sea_orm(table_name = "foo")] specified, create Entity struct + let table_name = input.attrs.iter().filter_map(|attr| { + if attr.path.get_ident()? != "sea_orm" { + return None; + } + + let list = attr.parse_args_with(Punctuated::::parse_terminated).ok()?; + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if nv.path.get_ident()? == "table_name" { + let table_name = &nv.lit; + return Some(quote! { +#[derive(Copy, Clone, Default, Debug, sea_orm::prelude::DeriveEntity)] +pub struct Entity; + +impl sea_orm::prelude::EntityName for Entity { + fn table_name(&self) -> &str { + #table_name + } +} + }); + } + } + } + + None + }).next().unwrap_or_default(); + + // generate Column enum and it's ColumnTrait impl + let mut columns_enum: Punctuated<_, Comma> = Punctuated::new(); + let mut columns_trait: Punctuated<_, Comma> = Punctuated::new(); + let mut primary_keys: Punctuated<_, Comma> = Punctuated::new(); + if let Data::Struct(item_struct) = input.data { + if let Fields::Named(fields) = item_struct.fields { + for field in fields.named { + if let Some(ident) = &field.ident { + let field_name = Ident::new(&ident.to_string().to_case(Case::Pascal), Span::call_site()); + columns_enum.push(quote! { #field_name }); + + let mut nullable = false; + let mut sql_type = None; + // search for #[sea_orm(primary_key, type = "String", nullable)] + field.attrs.iter().for_each(|attr| { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + return; + } + } + else { + return; + } + + // single param + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) { + for meta in list.iter() { + match meta { + Meta::NameValue(nv) => { + if let Some(name) = nv.path.get_ident() { + if name == "type" { + if let Lit::Str(litstr) = &nv.lit { + let ty: TokenStream2 = syn::parse_str(&litstr.value()).unwrap(); + sql_type = Some(ty); + } + } + } + }, + Meta::Path(p) => { + if let Some(name) = p.get_ident() { + if name == "primary_key" { + primary_keys.push(quote! { #field_name }); + } + else if name == "nullable" { + nullable = true; + } + } + }, + _ => {}, + } + } + } + }); + let field_type = sql_type.unwrap_or_else(|| { + let field_type = &field.ty; + let temp = quote! { #field_type } + .to_string()//Example: "Option < String >" + .replace(" ", ""); + let temp = if temp.starts_with("Option<") { + nullable = true; + &temp[7..(temp.len() - 1)] + } + else { + temp.as_str() + }; + match temp { + "char" => quote! { Char(None) }, + "String" | "&str" => quote! { String(None) }, + "u8" | "i8" => quote! { TinyInteger }, + "u16" | "i16" => quote! { SmallInteger }, + "u32" | "u64" | "i32" | "i64" => quote! { Integer }, + "u128" | "i128" => quote! { BigInteger }, + "f32" => quote! { Float }, + "f64" => quote! { Double }, + "bool" => quote! { Boolean }, + "NaiveDate" => quote! { Date }, + "NaiveTime" => quote! { Time }, + "NaiveDateTime" => quote! { DateTime }, + "Uuid" => quote! { Uuid }, + "Decimal" => quote! { BigInteger }, + _ => panic!("unrecognized type {}", temp), + } + }); + + if nullable { + columns_trait.push(quote! { Self::#field_name => sea_orm::prelude::ColumnType::#field_type.def().null() }); + } + else { + columns_trait.push(quote! { Self::#field_name => sea_orm::prelude::ColumnType::#field_type.def() }); + } + } + } + } + } + + let primary_key = (!primary_keys.is_empty()).then(|| { + let auto_increment = primary_keys.len() == 1; + quote! { +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + #primary_keys +} + +impl PrimaryKeyTrait for PrimaryKey { + fn auto_increment() -> bool { + #auto_increment + } +} + } + }).unwrap_or_default(); + + return quote! { +#[derive(Copy, Clone, Debug, sea_orm::prelude::EnumIter, sea_orm::prelude::DeriveColumn)] +pub enum Column { + #columns_enum +} + +impl sea_orm::prelude::ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> sea_orm::prelude::ColumnDef { + match self { + #columns_trait + } + } +} + +#table_name + +#primary_key + }.into(); +}