use crate::{Entity, Error}; use proc_macro2::TokenStream; use quote::quote; use std::{ fs::{self, File}, io::{self, Write}, path::Path, process::Command, }; #[derive(Clone, Debug)] pub struct EntityWriter { pub(crate) entities: Vec, } impl EntityWriter { pub fn generate(self, output_dir: &str) -> Result<(), Error> { for entity in self.entities.iter() { let code_blocks = Self::gen_code_blocks(entity); Self::write(output_dir, entity, code_blocks)?; } for entity in self.entities.iter() { Self::format_entity(output_dir, entity)?; } self.write_mod(output_dir)?; self.write_prelude(output_dir)?; Ok(()) } pub fn write_mod(&self, output_dir: &str) -> io::Result<()> { let file_name = "mod.rs"; let dir = Self::create_dir(output_dir)?; let file_path = dir.join(file_name); let mut file = fs::File::create(file_path)?; Self::write_doc_comment(&mut file)?; for entity in self.entities.iter() { let code_block = Self::gen_mod(entity); file.write_all(code_block.to_string().as_bytes())?; } Self::format_file(output_dir, file_name)?; Ok(()) } pub fn write_prelude(&self, output_dir: &str) -> io::Result<()> { let file_name = "prelude.rs"; let dir = Self::create_dir(output_dir)?; let file_path = dir.join(file_name); let mut file = fs::File::create(file_path)?; Self::write_doc_comment(&mut file)?; for entity in self.entities.iter() { let code_block = Self::gen_prelude_use(entity); file.write_all(code_block.to_string().as_bytes())?; } Self::format_file(output_dir, file_name)?; Ok(()) } pub fn write( output_dir: &str, entity: &Entity, code_blocks: Vec, ) -> io::Result<()> { let dir = Self::create_dir(output_dir)?; let file_path = dir.join(format!("{}.rs", entity.table_name)); let mut file = fs::File::create(file_path)?; Self::write_doc_comment(&mut file)?; for code_block in code_blocks { file.write_all(code_block.to_string().as_bytes())?; file.write_all(b"\n\n")?; } Ok(()) } pub fn write_doc_comment(file: &mut File) -> io::Result<()> { let ver = env!("CARGO_PKG_VERSION"); let comments = vec![format!( "//! SeaORM Entity. Generated by sea-orm-codegen {}", ver )]; for comment in comments { file.write_all(comment.as_bytes())?; file.write_all(b"\n\n")?; } Ok(()) } pub fn create_dir(output_dir: &str) -> io::Result<&Path> { let dir = Path::new(output_dir); fs::create_dir_all(dir)?; Ok(dir) } pub fn format_entity(output_dir: &str, entity: &Entity) -> io::Result<()> { Self::format_file(output_dir, &format!("{}.rs", entity.table_name)) } pub fn format_file(output_dir: &str, file_name: &str) -> io::Result<()> { Command::new("rustfmt") .arg(Path::new(output_dir).join(file_name)) .spawn()? .wait()?; Ok(()) } pub fn gen_code_blocks(entity: &Entity) -> Vec { let mut code_blocks = vec![ Self::gen_import(), Self::gen_entity_struct(), Self::gen_impl_entity_name(entity), Self::gen_model_struct(entity), Self::gen_column_enum(entity), Self::gen_primary_key_enum(entity), Self::gen_impl_primary_key(entity), Self::gen_relation_enum(entity), Self::gen_impl_column_trait(entity), Self::gen_impl_relation_trait(entity), ]; code_blocks.extend(Self::gen_impl_related(entity)); code_blocks.extend(vec![Self::gen_impl_active_model_behavior()]); code_blocks } pub fn gen_import() -> TokenStream { quote! { use sea_orm::entity::prelude::*; } } pub fn gen_entity_struct() -> TokenStream { quote! { #[derive(Copy, Clone, Default, Debug, DeriveEntity)] pub struct Entity; } } pub fn gen_impl_entity_name(entity: &Entity) -> TokenStream { let table_name_snake_case = entity.get_table_name_snake_case(); quote! { impl EntityName for Entity { fn table_name(&self) -> &str { #table_name_snake_case } } } } pub fn gen_model_struct(entity: &Entity) -> TokenStream { let column_names_snake_case = entity.get_column_names_snake_case(); let column_rs_types = entity.get_column_rs_types(); quote! { #[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] pub struct Model { #(pub #column_names_snake_case: #column_rs_types,)* } } } pub fn gen_column_enum(entity: &Entity) -> TokenStream { let column_names_camel_case = entity.get_column_names_camel_case(); quote! { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] pub enum Column { #(#column_names_camel_case,)* } } } pub fn gen_primary_key_enum(entity: &Entity) -> TokenStream { let primary_key_names_camel_case = entity.get_primary_key_names_camel_case(); quote! { #[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] pub enum PrimaryKey { #(#primary_key_names_camel_case,)* } } } pub fn gen_impl_primary_key(entity: &Entity) -> TokenStream { let primary_key_auto_increment = entity.get_primary_key_auto_increment(); quote! { impl PrimaryKeyTrait for PrimaryKey { fn auto_increment() -> bool { #primary_key_auto_increment } } } } pub fn gen_relation_enum(entity: &Entity) -> TokenStream { let relation_ref_tables_camel_case = entity.get_relation_ref_tables_camel_case(); quote! { #[derive(Copy, Clone, Debug, EnumIter)] pub enum Relation { #(#relation_ref_tables_camel_case,)* } } } pub fn gen_impl_column_trait(entity: &Entity) -> TokenStream { let column_names_camel_case = entity.get_column_names_camel_case(); let column_defs = entity.get_column_defs(); quote! { impl ColumnTrait for Column { type EntityName = Entity; fn def(&self) -> ColumnDef { match self { #(Self::#column_names_camel_case => #column_defs,)* } } } } } pub fn gen_impl_relation_trait(entity: &Entity) -> TokenStream { let relation_ref_tables_camel_case = entity.get_relation_ref_tables_camel_case(); let relation_defs = entity.get_relation_defs(); let quoted = if relation_ref_tables_camel_case.is_empty() { quote! { _ => panic!("No RelationDef"), } } else { quote! { #(Self::#relation_ref_tables_camel_case => #relation_defs,)* } }; quote! { impl RelationTrait for Relation { fn def(&self) -> RelationDef { match self { #quoted } } } } } pub fn gen_impl_related(entity: &Entity) -> Vec { let camel = entity.get_relation_ref_tables_camel_case(); let snake = entity.get_relation_ref_tables_snake_case(); camel .iter() .zip(snake) .map(|(c, s)| { quote! { impl Related for Entity { fn to() -> RelationDef { Relation::#c.def() } } } }) .collect() } pub fn gen_impl_active_model_behavior() -> TokenStream { quote! { impl ActiveModelBehavior for ActiveModel {} } } pub fn gen_mod(entity: &Entity) -> TokenStream { let table_name_snake_case_ident = entity.get_table_name_snake_case_ident(); quote! { pub mod #table_name_snake_case_ident; } } pub fn gen_prelude_use(entity: &Entity) -> TokenStream { let table_name_snake_case_ident = entity.get_table_name_snake_case_ident(); let table_name_camel_case_ident = entity.get_table_name_camel_case_ident(); quote! { pub use super::#table_name_snake_case_ident::Entity as #table_name_camel_case_ident; } } } #[cfg(test)] mod tests { use crate::{Column, Entity, EntityWriter, PrimaryKey, Relation, RelationType}; use proc_macro2::TokenStream; use sea_query::ColumnType; use std::io::{self, BufRead, BufReader}; const ENTITY_FILES: [&'static str; 5] = [ include_str!("../../tests/entity/cake.rs"), include_str!("../../tests/entity/cake_filling.rs"), include_str!("../../tests/entity/filling.rs"), include_str!("../../tests/entity/fruit.rs"), include_str!("../../tests/entity/vendor.rs"), ]; fn setup() -> Vec { vec![ Entity { table_name: "cake".to_owned(), columns: vec![ Column { name: "id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: true, not_null: true, unique: false, }, Column { name: "name".to_owned(), col_type: ColumnType::String(Some(255)), auto_increment: false, not_null: true, unique: false, }, ], relations: vec![ Relation { ref_table: "cake_filling".to_owned(), columns: vec![], ref_columns: vec![], rel_type: RelationType::HasMany, }, Relation { ref_table: "fruit".to_owned(), columns: vec![], ref_columns: vec![], rel_type: RelationType::HasMany, }, ], primary_keys: vec![PrimaryKey { name: "id".to_owned(), }], }, Entity { table_name: "cake_filling".to_owned(), columns: vec![ Column { name: "cake_id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: false, not_null: true, unique: false, }, Column { name: "filling_id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: false, not_null: true, unique: false, }, ], relations: vec![ Relation { ref_table: "cake".to_owned(), columns: vec!["cake_id".to_owned()], ref_columns: vec!["id".to_owned()], rel_type: RelationType::BelongsTo, }, Relation { ref_table: "filling".to_owned(), columns: vec!["filling_id".to_owned()], ref_columns: vec!["id".to_owned()], rel_type: RelationType::BelongsTo, }, ], primary_keys: vec![ PrimaryKey { name: "cake_id".to_owned(), }, PrimaryKey { name: "filling_id".to_owned(), }, ], }, Entity { table_name: "filling".to_owned(), columns: vec![ Column { name: "id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: true, not_null: true, unique: false, }, Column { name: "name".to_owned(), col_type: ColumnType::String(Some(255)), auto_increment: false, not_null: true, unique: false, }, ], relations: vec![Relation { ref_table: "cake_filling".to_owned(), columns: vec![], ref_columns: vec![], rel_type: RelationType::HasMany, }], primary_keys: vec![PrimaryKey { name: "id".to_owned(), }], }, Entity { table_name: "fruit".to_owned(), columns: vec![ Column { name: "id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: true, not_null: true, unique: false, }, Column { name: "name".to_owned(), col_type: ColumnType::String(Some(255)), auto_increment: false, not_null: true, unique: false, }, Column { name: "cake_id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: false, not_null: false, unique: false, }, ], relations: vec![ Relation { ref_table: "cake".to_owned(), columns: vec!["cake_id".to_owned()], ref_columns: vec!["id".to_owned()], rel_type: RelationType::BelongsTo, }, Relation { ref_table: "vendor".to_owned(), columns: vec![], ref_columns: vec![], rel_type: RelationType::HasMany, }, ], primary_keys: vec![PrimaryKey { name: "id".to_owned(), }], }, Entity { table_name: "vendor".to_owned(), columns: vec![ Column { name: "id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: true, not_null: true, unique: false, }, Column { name: "name".to_owned(), col_type: ColumnType::String(Some(255)), auto_increment: false, not_null: true, unique: false, }, Column { name: "fruit_id".to_owned(), col_type: ColumnType::Integer(Some(11)), auto_increment: false, not_null: false, unique: false, }, ], relations: vec![Relation { ref_table: "fruit".to_owned(), columns: vec!["fruit_id".to_owned()], ref_columns: vec!["id".to_owned()], rel_type: RelationType::BelongsTo, }], primary_keys: vec![PrimaryKey { name: "id".to_owned(), }], }, ] } #[test] fn test_gen_code_blocks() -> io::Result<()> { let entities = setup(); assert_eq!(entities.len(), ENTITY_FILES.len()); for (i, entity) in entities.iter().enumerate() { let mut reader = BufReader::new(ENTITY_FILES[i].as_bytes()); let mut lines: Vec = Vec::new(); reader.read_until(b';', &mut Vec::new())?; let mut line = String::new(); while reader.read_line(&mut line)? > 0 { lines.push(line.to_owned()); line.clear(); } let content = lines.join(""); let expected: TokenStream = content.parse().unwrap(); let generated = EntityWriter::gen_code_blocks(entity) .into_iter() .skip(1) .fold(TokenStream::new(), |mut acc, tok| { acc.extend(tok); acc }); assert_eq!(expected.to_string(), generated.to_string()); } Ok(()) } }