Derive value type (#1720)

* progress (WIP)

* WIP

* WIP, finished structure except impl programming

* WIP

* revert event_trigger modification

* adding tests and mods

* completed derive value type

* fixed tests, adjusted code position and completed error messages

* column type commit

* added attribute array_type and column_type for specification

* renamed items and files, and removed debug messages

* move attributes outside of the wrapper struct

* refactored code for type matching, and restructured code in test cases

* clippy fix

* fix(doc): fix salvo framework name (#1731)

Co-authored-by: 黄景祥 <jingxiang.huang@baishancloud.com>

* fmt

* changed json_vec_test to use DeriveValueType

* fmt

* Revert "changed json_vec_test to use DeriveValueType"

This reverts commit 92bbf3b6e4eca72e0af0af35776aeec3ee035602.

* added test cases for inserting StringVec in a model

* fmt

* Try non-public wrapped type

* Refactoring

---------

Co-authored-by: joelhy <joelhy@gmail.com>
Co-authored-by: 黄景祥 <jingxiang.huang@baishancloud.com>
Co-authored-by: Billy Chan <ccw.billy.123@gmail.com>
This commit is contained in:
darkmmon 2023-07-10 10:48:20 +08:00 committed by GitHub
parent 207c008e5b
commit f5a7311794
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 495 additions and 52 deletions

View File

@ -1,10 +1,9 @@
use super::util::{escape_rust_keyword, trim_starting_raw_identifier}; use super::util::{escape_rust_keyword, trim_starting_raw_identifier};
use heck::{ToSnakeCase, ToUpperCamelCase}; use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Ident, Span, TokenStream}; use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned}; use quote::quote;
use syn::{ use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, Expr, Fields, Lit, punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, Expr, Fields, Lit,
LitStr, Type,
}; };
/// Method to derive an Model /// Method to derive an Model
@ -245,57 +244,12 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec<Attribute>) -> syn::Res
} else { } else {
field_type.as_str() field_type.as_str()
}; };
let sea_query_col_type = match sql_type {
Some(t) => quote! { sea_orm::prelude::ColumnType::#t },
None => {
let col_type = match field_type {
"char" => quote! { Char(None) },
"String" | "&str" => quote! { String(None) },
"i8" => quote! { TinyInteger },
"u8" => quote! { TinyUnsigned },
"i16" => quote! { SmallInteger },
"u16" => quote! { SmallUnsigned },
"i32" => quote! { Integer },
"u32" => quote! { Unsigned },
"i64" => quote! { BigInteger },
"u64" => quote! { BigUnsigned },
"f32" => quote! { Float },
"f64" => quote! { Double },
"bool" => quote! { Boolean },
"Date" | "NaiveDate" => quote! { Date },
"Time" | "NaiveTime" => quote! { Time },
"DateTime" | "NaiveDateTime" => {
quote! { DateTime }
}
"DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => {
quote! { TimestampWithTimeZone }
}
"Uuid" => quote! { Uuid },
"Json" => quote! { Json },
"Decimal" => quote! { Decimal(None) },
"Vec<u8>" => {
quote! { Binary(sea_orm::sea_query::BlobSize::Blob(None)) }
}
_ => {
// Assumed it's ActiveEnum if none of the above type matches
quote! {}
}
};
if col_type.is_empty() {
let field_span = field.span(); let field_span = field.span();
let ty: Type = LitStr::new(field_type, field_span).parse()?;
let def = quote_spanned! { field_span => let sea_query_col_type = crate::derives::sql_type_match::col_type_match(
std::convert::Into::<sea_orm::ColumnType>::into( sql_type, field_type, field_span,
<#ty as sea_orm::sea_query::ValueType>::column_type() );
)
};
quote! { #def }
} else {
quote! { sea_orm::prelude::ColumnType::#col_type }
}
}
};
let col_def = let col_def =
quote! { sea_orm::prelude::ColumnTypeTrait::def(#sea_query_col_type) }; quote! { sea_orm::prelude::ColumnTypeTrait::def(#sea_query_col_type) };

View File

@ -13,8 +13,10 @@ mod partial_model;
mod primary_key; mod primary_key;
mod related_entity; mod related_entity;
mod relation; mod relation;
mod sql_type_match;
mod try_getable_from_json; mod try_getable_from_json;
mod util; mod util;
mod value_type;
pub use active_enum::*; pub use active_enum::*;
pub use active_model::*; pub use active_model::*;
@ -31,3 +33,4 @@ pub use primary_key::*;
pub use related_entity::*; pub use related_entity::*;
pub use relation::*; pub use relation::*;
pub use try_getable_from_json::*; pub use try_getable_from_json::*;
pub use value_type::*;

View File

@ -0,0 +1,116 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{LitStr, Type};
pub fn col_type_match(
col_type: Option<TokenStream>,
field_type: &str,
field_span: Span,
) -> TokenStream {
match col_type {
Some(t) => quote! { sea_orm::prelude::ColumnType::#t },
None => {
let col_type = match field_type {
"char" => quote! { Char(None) },
"String" | "&str" => quote! { String(None) },
"i8" => quote! { TinyInteger },
"u8" => quote! { TinyUnsigned },
"i16" => quote! { SmallInteger },
"u16" => quote! { SmallUnsigned },
"i32" => quote! { Integer },
"u32" => quote! { Unsigned },
"i64" => quote! { BigInteger },
"u64" => quote! { BigUnsigned },
"f32" => quote! { Float },
"f64" => quote! { Double },
"bool" => quote! { Boolean },
"Date" | "NaiveDate" => quote! { Date },
"Time" | "NaiveTime" => quote! { Time },
"DateTime" | "NaiveDateTime" => {
quote! { DateTime }
}
"DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => {
quote! { TimestampWithTimeZone }
}
"Uuid" => quote! { Uuid },
"Json" => quote! { Json },
"Decimal" => quote! { Decimal(None) },
"Vec<u8>" => {
quote! { Binary(sea_orm::sea_query::BlobSize::Blob(None)) }
}
_ => {
// Assumed it's ActiveEnum if none of the above type matches
quote! {}
}
};
if col_type.is_empty() {
let ty: Type = LitStr::new(field_type, field_span)
.parse()
.expect("field type error");
let def = quote_spanned! { field_span =>
std::convert::Into::<sea_orm::sea_query::ColumnType>::into(
<#ty as sea_orm::sea_query::ValueType>::column_type()
)
};
quote! { #def }
} else {
quote! { sea_orm::prelude::ColumnType::#col_type }
}
}
}
}
pub fn arr_type_match(
arr_type: Option<TokenStream>,
field_type: &str,
field_span: Span,
) -> TokenStream {
match arr_type {
Some(t) => quote! { sea_orm::sea_query::ArrayType::#t },
None => {
let arr_type = match field_type {
"char" => quote! { Char },
"String" | "&str" => quote! { String },
"i8" => quote! { TinyInt },
"u8" => quote! { TinyUnsigned },
"i16" => quote! { SmallInt },
"u16" => quote! { SmallUnsigned },
"i32" => quote! { Int },
"u32" => quote! { Unsigned },
"i64" => quote! { BigInt },
"u64" => quote! { BigUnsigned },
"f32" => quote! { Float },
"f64" => quote! { Double },
"bool" => quote! { Bool },
"Date" | "NaiveDate" => quote! { ChronoDate },
"Time" | "NaiveTime" => quote! { ChronoTime },
"DateTime" | "NaiveDateTime" => {
quote! { ChronoDateTime }
}
"DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => {
quote! { ChronoDateTimeWithTimeZone }
}
"Uuid" => quote! { Uuid },
"Json" => quote! { Json },
"Decimal" => quote! { Decimal },
_ => {
// Assumed it's ActiveEnum if none of the above type matches
quote! {}
}
};
if arr_type.is_empty() {
let ty: Type = LitStr::new(field_type, field_span)
.parse()
.expect("field type error");
let def = quote_spanned! { field_span =>
std::convert::Into::<sea_orm::sea_query::ArrayType>::into(
<#ty as sea_orm::sea_query::ValueType>::array_type()
)
};
quote! { #def }
} else {
quote! { sea_orm::sea_query::ArrayType::#arr_type }
}
}
}
}

View File

@ -0,0 +1,144 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{spanned::Spanned, Lit, Type};
struct DeriveValueType {
name: syn::Ident,
ty: Type,
column_type: TokenStream,
array_type: TokenStream,
}
impl DeriveValueType {
pub fn new(input: syn::DeriveInput) -> Option<Self> {
let dat = input.data;
let fields: Option<syn::punctuated::Punctuated<syn::Field, syn::token::Comma>> = match dat {
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }),
..
}) => Some(unnamed),
_ => None,
};
let field = fields
.expect("This derive accept only struct")
.first()
.expect("The struct should contain one value field")
.to_owned();
let name = input.ident;
let mut col_type = None;
let mut arr_type = None;
for attr in input.attrs.iter() {
if !attr.path().is_ident("sea_orm") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("column_type") {
let lit = meta.value()?.parse()?;
if let Lit::Str(litstr) = lit {
let ty: TokenStream = syn::parse_str(&litstr.value())?;
col_type = Some(ty);
} else {
return Err(meta.error(format!("Invalid column_type {:?}", lit)));
}
} else if meta.path.is_ident("array_type") {
let lit = meta.value()?.parse()?;
if let Lit::Str(litstr) = lit {
let ty: TokenStream = syn::parse_str(&litstr.value())?;
arr_type = Some(ty);
} else {
return Err(meta.error(format!("Invalid array_type {:?}", lit)));
}
} else {
// received other attribute
return Err(meta.error(format!("Invalid attribute {:?}", meta.path)));
}
Ok(())
})
.unwrap_or(());
}
let ty = field.clone().ty;
let field_type = quote! { #ty }
.to_string() //E.g.: "Option < String >"
.replace(' ', ""); // Remove spaces
let field_type = if field_type.starts_with("Option<") {
&field_type[7..(field_type.len() - 1)] // Extract `T` out of `Option<T>`
} else {
field_type.as_str()
};
let field_span = field.span();
let column_type =
crate::derives::sql_type_match::col_type_match(col_type, field_type, field_span);
let array_type =
crate::derives::sql_type_match::arr_type_match(arr_type, field_type, field_span);
Some(DeriveValueType {
name,
ty,
column_type,
array_type,
})
}
fn expand(&self) -> syn::Result<TokenStream> {
let expanded_impl_value_type: TokenStream = self.impl_value_type();
Ok(expanded_impl_value_type)
}
fn impl_value_type(&self) -> TokenStream {
let name = &self.name;
let field_type = &self.ty;
let column_type = &self.column_type;
let array_type = &self.array_type;
quote!(
#[automatically_derived]
impl From<#name> for Value {
fn from(source: #name) -> Self {
source.0.into()
}
}
#[automatically_derived]
impl sea_orm::TryGetable for #name {
fn try_get_by<I: sea_orm::ColIdx>(res: &QueryResult, idx: I) -> Result<Self, sea_orm::TryGetError> {
<#field_type as sea_orm::TryGetable>::try_get_by(res, idx).map(|v| #name(v))
}
}
#[automatically_derived]
impl sea_orm::sea_query::ValueType for #name {
fn try_from(v: Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<#field_type as sea_orm::sea_query::ValueType>::try_from(v).map(|v| #name(v))
}
fn type_name() -> String {
stringify!(#name).to_owned()
}
fn array_type() -> sea_orm::sea_query::ArrayType {
#array_type
}
fn column_type() -> sea_orm::sea_query::ColumnType {
#column_type
}
}
)
}
}
pub fn expand_derive_value_type(input: syn::DeriveInput) -> syn::Result<TokenStream> {
let input_span = input.span();
match DeriveValueType::new(input) {
Some(model) => model.expand(),
None => Err(syn::Error::new(input_span, "error")),
}
}

View File

@ -832,3 +832,13 @@ pub fn enum_iter(input: TokenStream) -> TokenStream {
.unwrap_or_else(Error::into_compile_error) .unwrap_or_else(Error::into_compile_error)
.into() .into()
} }
#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveValueType, attributes(sea_orm))]
pub fn derive_value_type(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
match derives::expand_derive_value_type(derive_input) {
Ok(token_stream) => token_stream.into(),
Err(e) => e.to_compile_error().into(),
}
}

View File

@ -23,6 +23,7 @@ pub mod self_join;
pub mod teas; pub mod teas;
pub mod transaction_log; pub mod transaction_log;
pub mod uuid_fmt; pub mod uuid_fmt;
pub mod value_type;
pub use active_enum::Entity as ActiveEnum; pub use active_enum::Entity as ActiveEnum;
pub use active_enum_child::Entity as ActiveEnumChild; pub use active_enum_child::Entity as ActiveEnumChild;

View File

@ -48,8 +48,10 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> {
create_binary_table(db).await?; create_binary_table(db).await?;
create_bits_table(db).await?; create_bits_table(db).await?;
create_dyn_table_name_lazy_static_table(db).await?; create_dyn_table_name_lazy_static_table(db).await?;
create_value_type_table(db).await?;
if DbBackend::Postgres == db_backend { if DbBackend::Postgres == db_backend {
create_value_type_postgres_table(db).await?;
create_collection_table(db).await?; create_collection_table(db).await?;
create_event_trigger_table(db).await?; create_event_trigger_table(db).await?;
} }
@ -634,3 +636,47 @@ pub async fn create_dyn_table_name_lazy_static_table(db: &DbConn) -> Result<(),
Ok(()) Ok(())
} }
pub async fn create_value_type_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let general_stmt = sea_query::Table::create()
.table(value_type::value_type_general::Entity)
.col(
ColumnDef::new(value_type::value_type_general::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(value_type::value_type_general::Column::Number)
.integer()
.not_null(),
)
.to_owned();
create_table(db, &general_stmt, value_type::value_type_general::Entity).await
}
pub async fn create_value_type_postgres_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let postgres_stmt = sea_query::Table::create()
.table(value_type::value_type_pg::Entity)
.col(
ColumnDef::new(value_type::value_type_pg::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(value_type::value_type_pg::Column::Number)
.integer()
.not_null(),
)
.col(
ColumnDef::new(json_vec::Column::StrVec)
.array(sea_query::ColumnType::String(None))
.not_null(),
)
.to_owned();
create_table(db, &postgres_stmt, value_type::value_type_pg::Entity).await
}

View File

@ -0,0 +1,59 @@
pub mod value_type_general {
use super::*;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "value_type")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub number: Integer,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}
}
pub mod value_type_pg {
use super::*;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "value_type_postgres")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub number: Integer,
pub str_vec: StringVec,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}
}
use sea_orm::entity::prelude::*;
use sea_orm_macros::DeriveValueType;
#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
#[sea_orm(array_type = "Int")]
pub struct Integer(i32);
impl<T> From<T> for Integer
where
T: Into<i32>,
{
fn from(v: T) -> Integer {
Integer(v.into())
}
}
#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
#[sea_orm(column_type = "Boolean", array_type = "Bool")]
pub struct Boolbean(pub String);
#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
pub struct StringVec(pub Vec<String>);

110
tests/value_type_tests.rs Normal file
View File

@ -0,0 +1,110 @@
pub mod common;
use std::sync::Arc;
use std::vec;
pub use common::{
features::{
value_type::{value_type_general, value_type_pg, Boolbean, Integer, StringVec},
*,
},
setup::*,
TestContext,
};
use pretty_assertions::assert_eq;
use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection};
use sea_query::{ArrayType, ColumnType, Value, ValueType, ValueTypeErr};
#[sea_orm_macros::test]
#[cfg(any(
feature = "sqlx-mysql",
feature = "sqlx-sqlite",
feature = "sqlx-postgres"
))]
async fn main() -> Result<(), DbErr> {
let ctx = TestContext::new("value_type_tests").await;
create_tables(&ctx.db).await?;
insert_value(&ctx.db).await?;
ctx.delete().await;
if cfg!(feature = "sqlx-postgres") {
let ctx = TestContext::new("value_type_postgres_tests").await;
create_tables(&ctx.db).await?;
postgres_insert_value(&ctx.db).await?;
ctx.delete().await;
}
type_test();
conversion_test();
Ok(())
}
pub async fn insert_value(db: &DatabaseConnection) -> Result<(), DbErr> {
let model = value_type_general::Model {
id: 1,
number: 48.into(),
};
let result = model.clone().into_active_model().insert(db).await?;
assert_eq!(result, model);
Ok(())
}
pub async fn postgres_insert_value(db: &DatabaseConnection) -> Result<(), DbErr> {
let model = value_type_pg::Model {
id: 1,
number: 48.into(),
str_vec: StringVec(vec!["ab".to_string(), "cd".to_string()]),
};
let result = model.clone().into_active_model().insert(db).await?;
assert_eq!(result, model);
Ok(())
}
pub fn type_test() {
assert_eq!(StringVec::type_name(), "StringVec");
// custom types
assert_eq!(Integer::array_type(), ArrayType::Int);
assert_eq!(Integer::array_type(), ArrayType::Int);
assert_eq!(Boolbean::column_type(), ColumnType::Boolean);
assert_eq!(Boolbean::array_type(), ArrayType::Bool);
// self implied
assert_eq!(
StringVec::column_type(),
ColumnType::Array(Arc::new(ColumnType::String(None)))
);
assert_eq!(StringVec::array_type(), ArrayType::String);
}
pub fn conversion_test() {
let stringvec = StringVec(vec!["ab".to_string(), "cd".to_string()]);
let string: Value = stringvec.into();
assert_eq!(
string,
Value::Array(
ArrayType::String,
Some(Box::new(vec![
"ab".to_string().into(),
"cd".to_string().into()
]))
)
);
let value_random_int = Value::Int(Some(523));
let unwrap_int = Integer::unwrap(value_random_int.clone());
let try_from_int =
<Integer as ValueType>::try_from(value_random_int).expect("should be ok to convert");
// tests for unwrap and try_from
let direct_int: Integer = 523.into();
assert_eq!(direct_int, unwrap_int);
assert_eq!(direct_int, try_from_int);
// test for error
let try_from_string_vec = <StringVec as ValueType>::try_from(Value::Char(Some('a')))
.expect_err("should not be ok to convert char to stringvec");
assert_eq!(try_from_string_vec.to_string(), ValueTypeErr.to_string());
}