Rewrite DeriveModel

This commit is contained in:
Ari Seyhun 2021-09-07 22:07:29 +08:00 committed by Chris Tsang
parent 46d9fd30e6
commit 290f78454b
2 changed files with 126 additions and 60 deletions

View File

@ -1,57 +1,125 @@
use std::iter::FromIterator;
use heck::CamelCase; use heck::CamelCase;
use proc_macro2::{Ident, TokenStream}; use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
use syn::{Data, DataStruct, Field, Fields}; use crate::attributes::derive_attr;
pub fn expand_derive_model(ident: Ident, data: Data) -> syn::Result<TokenStream> { enum Error {
let fields = match data { InputNotStruct,
Data::Struct(DataStruct { Syn(syn::Error),
fields: Fields::Named(named), }
..
}) => named.named, struct DeriveModel {
_ => { column_idents: Vec<syn::Ident>,
return Ok(quote_spanned! { entity_ident: syn::Ident,
ident.span() => compile_error!("you can only derive DeriveModel on structs"); field_idents: Vec<syn::Ident>,
}) ident: syn::Ident,
} }
};
impl DeriveModel {
let field: Vec<Ident> = fields fn new(input: syn::DeriveInput) -> Result<Self, Error> {
.clone() let fields = match input.data {
.into_iter() syn::Data::Struct(syn::DataStruct {
.map(|Field { ident, .. }| format_ident!("{}", ident.unwrap().to_string())) fields: syn::Fields::Named(syn::FieldsNamed { named, .. }),
.collect(); ..
}) => named,
let name: Vec<Ident> = fields _ => return Err(Error::InputNotStruct),
.into_iter() };
.map(|Field { ident, .. }| format_ident!("{}", ident.unwrap().to_string().to_camel_case()))
.collect(); let sea_attr = derive_attr::SeaOrm::try_from_attributes(&input.attrs)
.map_err(Error::Syn)?
Ok(quote!( .unwrap_or_default();
impl sea_orm::ModelTrait for #ident {
type Entity = Entity; let ident = input.ident;
let entity_ident = sea_attr.entity.unwrap_or_else(|| format_ident!("Entity"));
fn get(&self, c: <Self::Entity as EntityTrait>::Column) -> sea_orm::Value {
match c { let field_idents = fields
#(<Self::Entity as EntityTrait>::Column::#name => self.#field.clone().into(),)* .iter()
_ => panic!("This Model does not have this field"), .map(|field| field.ident.as_ref().unwrap().clone())
} .collect();
}
let column_idents = fields
fn set(&mut self, c: <Self::Entity as EntityTrait>::Column, v: sea_orm::Value) { .iter()
match c { .map(|field| {
#(<Self::Entity as EntityTrait>::Column::#name => self.#field = v.unwrap(),)* format_ident!(
_ => panic!("This Model does not have this field"), "{}",
} field.ident.as_ref().unwrap().to_string().to_camel_case()
} )
} })
.collect();
impl sea_orm::FromQueryResult for #ident {
fn from_query_result(row: &sea_orm::QueryResult, pre: &str) -> Result<Self, sea_orm::DbErr> { Ok(DeriveModel {
Ok(Self { column_idents,
#(#field: row.try_get(pre, <<Self as ModelTrait>::Entity as EntityTrait>::Column::#name.as_str().into())?),* entity_ident,
}) field_idents,
} ident,
} })
)) }
fn expand(&self) -> syn::Result<TokenStream> {
let expanded_impl_from_query_result = self.impl_from_query_result();
let expanded_impl_model_trait = self.impl_model_trait();
Ok(TokenStream::from_iter([
expanded_impl_from_query_result,
expanded_impl_model_trait,
]))
}
fn impl_from_query_result(&self) -> TokenStream {
let ident = &self.ident;
let field_idents = &self.field_idents;
let column_idents = &self.column_idents;
quote!(
impl sea_orm::FromQueryResult for #ident {
fn from_query_result(row: &sea_orm::QueryResult, pre: &str) -> Result<Self, sea_orm::DbErr> {
Ok(Self {
#(#field_idents: row.try_get(pre, sea_orm::IdenStatic::as_str(&<<Self as sea_orm::ModelTrait>::Entity as sea_orm::entity::EntityTrait>::Column::#column_idents).into())?),*
})
}
}
)
}
fn impl_model_trait(&self) -> TokenStream {
let ident = &self.ident;
let entity_ident = &self.entity_ident;
let field_idents = &self.field_idents;
let column_idents = &self.column_idents;
let missing_field_msg = format!("field does not exist on {}", ident);
quote!(
impl sea_orm::ModelTrait for #ident {
type Entity = #entity_ident;
fn get(&self, c: <Self::Entity as sea_orm::entity::EntityTrait>::Column) -> sea_orm::Value {
match c {
#(<Self::Entity as sea_orm::entity::EntityTrait>::Column::#column_idents => self.#field_idents.clone().into(),)*
_ => panic!(#missing_field_msg),
}
}
fn set(&mut self, c: <Self::Entity as sea_orm::entity::EntityTrait>::Column, v: sea_orm::Value) {
match c {
#(<Self::Entity as sea_orm::entity::EntityTrait>::Column::#column_idents => self.#field_idents = v.unwrap(),)*
_ => panic!(#missing_field_msg),
}
}
}
)
}
}
pub fn expand_derive_model(input: syn::DeriveInput) -> syn::Result<TokenStream> {
let ident_span = input.ident.span();
match DeriveModel::new(input) {
Ok(model) => model.expand(),
Err(Error::InputNotStruct) => Ok(quote_spanned! {
ident_span => compile_error!("you can only derive DeriveModel on structs");
}),
Err(Error::Syn(err)) => Err(err),
}
} }

View File

@ -60,14 +60,12 @@ pub fn derive_custom_column(input: TokenStream) -> TokenStream {
} }
} }
#[proc_macro_derive(DeriveModel)] #[proc_macro_derive(DeriveModel, attributes(sea_orm))]
pub fn derive_model(input: TokenStream) -> TokenStream { pub fn derive_model(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input); let input = parse_macro_input!(input as DeriveInput);
derives::expand_derive_model(input)
match derives::expand_derive_model(ident, data) { .unwrap_or_else(Error::into_compile_error)
Ok(ts) => ts.into(), .into()
Err(e) => e.to_compile_error().into(),
}
} }
#[proc_macro_derive(DeriveActiveModel)] #[proc_macro_derive(DeriveActiveModel)]