From 290f78454b588097f302ff68757883f53d081963 Mon Sep 17 00:00:00 2001 From: Ari Seyhun Date: Tue, 7 Sep 2021 22:07:29 +0800 Subject: [PATCH] Rewrite DeriveModel --- sea-orm-macros/src/derives/model.rs | 174 +++++++++++++++++++--------- sea-orm-macros/src/lib.rs | 12 +- 2 files changed, 126 insertions(+), 60 deletions(-) diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index 1faf3642..4c3b0a03 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -1,57 +1,125 @@ +use std::iter::FromIterator; use heck::CamelCase; -use proc_macro2::{Ident, TokenStream}; +use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; -use syn::{Data, DataStruct, Field, Fields}; +use crate::attributes::derive_attr; -pub fn expand_derive_model(ident: Ident, data: Data) -> syn::Result { - let fields = match data { - Data::Struct(DataStruct { - fields: Fields::Named(named), - .. - }) => named.named, - _ => { - return Ok(quote_spanned! { - ident.span() => compile_error!("you can only derive DeriveModel on structs"); - }) - } - }; - - let field: Vec = fields - .clone() - .into_iter() - .map(|Field { ident, .. }| format_ident!("{}", ident.unwrap().to_string())) - .collect(); - - let name: Vec = fields - .into_iter() - .map(|Field { ident, .. }| format_ident!("{}", ident.unwrap().to_string().to_camel_case())) - .collect(); - - Ok(quote!( - impl sea_orm::ModelTrait for #ident { - type Entity = Entity; - - fn get(&self, c: ::Column) -> sea_orm::Value { - match c { - #(::Column::#name => self.#field.clone().into(),)* - _ => panic!("This Model does not have this field"), - } - } - - fn set(&mut self, c: ::Column, v: sea_orm::Value) { - match c { - #(::Column::#name => self.#field = v.unwrap(),)* - _ => panic!("This Model does not have this field"), - } - } - } - - impl sea_orm::FromQueryResult for #ident { - fn from_query_result(row: &sea_orm::QueryResult, pre: &str) -> Result { - Ok(Self { - #(#field: row.try_get(pre, <::Entity as EntityTrait>::Column::#name.as_str().into())?),* - }) - } - } - )) +enum Error { + InputNotStruct, + Syn(syn::Error), +} + +struct DeriveModel { + column_idents: Vec, + entity_ident: syn::Ident, + field_idents: Vec, + ident: syn::Ident, +} + +impl DeriveModel { + fn new(input: syn::DeriveInput) -> Result { + let fields = match input.data { + syn::Data::Struct(syn::DataStruct { + fields: syn::Fields::Named(syn::FieldsNamed { named, .. }), + .. + }) => named, + _ => return Err(Error::InputNotStruct), + }; + + let sea_attr = derive_attr::SeaOrm::try_from_attributes(&input.attrs) + .map_err(Error::Syn)? + .unwrap_or_default(); + + let ident = input.ident; + let entity_ident = sea_attr.entity.unwrap_or_else(|| format_ident!("Entity")); + + let field_idents = fields + .iter() + .map(|field| field.ident.as_ref().unwrap().clone()) + .collect(); + + let column_idents = fields + .iter() + .map(|field| { + format_ident!( + "{}", + field.ident.as_ref().unwrap().to_string().to_camel_case() + ) + }) + .collect(); + + Ok(DeriveModel { + column_idents, + entity_ident, + field_idents, + ident, + }) + } + + fn expand(&self) -> syn::Result { + let expanded_impl_from_query_result = self.impl_from_query_result(); + let expanded_impl_model_trait = self.impl_model_trait(); + + Ok(TokenStream::from_iter([ + expanded_impl_from_query_result, + expanded_impl_model_trait, + ])) + } + + fn impl_from_query_result(&self) -> TokenStream { + let ident = &self.ident; + let field_idents = &self.field_idents; + let column_idents = &self.column_idents; + + quote!( + impl sea_orm::FromQueryResult for #ident { + fn from_query_result(row: &sea_orm::QueryResult, pre: &str) -> Result { + Ok(Self { + #(#field_idents: row.try_get(pre, sea_orm::IdenStatic::as_str(&<::Entity as sea_orm::entity::EntityTrait>::Column::#column_idents).into())?),* + }) + } + } + ) + } + + fn impl_model_trait(&self) -> TokenStream { + let ident = &self.ident; + let entity_ident = &self.entity_ident; + let field_idents = &self.field_idents; + let column_idents = &self.column_idents; + + let missing_field_msg = format!("field does not exist on {}", ident); + + quote!( + impl sea_orm::ModelTrait for #ident { + type Entity = #entity_ident; + + fn get(&self, c: ::Column) -> sea_orm::Value { + match c { + #(::Column::#column_idents => self.#field_idents.clone().into(),)* + _ => panic!(#missing_field_msg), + } + } + + fn set(&mut self, c: ::Column, v: sea_orm::Value) { + match c { + #(::Column::#column_idents => self.#field_idents = v.unwrap(),)* + _ => panic!(#missing_field_msg), + } + } + } + ) + } +} + +pub fn expand_derive_model(input: syn::DeriveInput) -> syn::Result { + let ident_span = input.ident.span(); + + match DeriveModel::new(input) { + Ok(model) => model.expand(), + Err(Error::InputNotStruct) => Ok(quote_spanned! { + ident_span => compile_error!("you can only derive DeriveModel on structs"); + }), + Err(Error::Syn(err)) => Err(err), + } } diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index fe132e5f..22b66661 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -60,14 +60,12 @@ pub fn derive_custom_column(input: TokenStream) -> TokenStream { } } -#[proc_macro_derive(DeriveModel)] +#[proc_macro_derive(DeriveModel, attributes(sea_orm))] pub fn derive_model(input: TokenStream) -> TokenStream { - let DeriveInput { ident, data, .. } = parse_macro_input!(input); - - match derives::expand_derive_model(ident, data) { - Ok(ts) => ts.into(), - Err(e) => e.to_compile_error().into(), - } + let input = parse_macro_input!(input as DeriveInput); + derives::expand_derive_model(input) + .unwrap_or_else(Error::into_compile_error) + .into() } #[proc_macro_derive(DeriveActiveModel)]