From a793c59e5d6ee22ecd7d9cf6f7514a1800e5a17c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Mon, 2 Dec 2024 07:17:25 +0100 Subject: [PATCH] feat(orm): add foreign key support --- Cargo.lock | 2 + flareon-cli/Cargo.toml | 2 +- flareon-cli/src/migration_generator.rs | 558 +++--------------- flareon-cli/tests/migration_generator.rs | 65 +- .../tests/migration_generator/create_model.rs | 3 +- flareon-codegen/Cargo.toml | 5 + flareon-codegen/src/expr.rs | 280 +++++++-- flareon-codegen/src/lib.rs | 16 + flareon-codegen/src/model.rs | 203 ++++++- flareon-codegen/src/symbol_resolver.rs | 479 +++++++++++++++ flareon-macros/src/model.rs | 61 +- flareon-macros/src/query.rs | 38 +- flareon-macros/tests/compile_tests.rs | 3 + .../tests/ui/attr_model_multiple_pks.rs | 11 + .../tests/ui/attr_model_multiple_pks.stderr | 5 + flareon-macros/tests/ui/attr_model_no_pk.rs | 8 + .../tests/ui/attr_model_no_pk.stderr | 5 + .../ui/func_query_method_call_on_db_field.rs | 14 + .../func_query_method_call_on_db_field.stderr | 5 + flareon/Cargo.toml | 1 + flareon/src/auth.rs | 3 +- flareon/src/db.rs | 299 +++++++--- flareon/src/db/fields.rs | 162 ++++- flareon/src/db/impl_mysql.rs | 8 + flareon/src/db/impl_postgres.rs | 8 + flareon/src/db/impl_sqlite.rs | 19 +- flareon/src/db/migrations.rs | 52 +- flareon/src/db/query.rs | 68 ++- flareon/src/db/relations.rs | 109 ++++ flareon/src/db/sea_query_db.rs | 5 +- flareon/tests/db.rs | 178 +++++- 31 files changed, 1966 insertions(+), 709 deletions(-) create mode 100644 flareon-codegen/src/symbol_resolver.rs create mode 100644 flareon-macros/tests/ui/attr_model_multiple_pks.rs create mode 100644 flareon-macros/tests/ui/attr_model_multiple_pks.stderr create mode 100644 flareon-macros/tests/ui/attr_model_no_pk.rs create mode 100644 flareon-macros/tests/ui/attr_model_no_pk.stderr create mode 100644 flareon-macros/tests/ui/func_query_method_call_on_db_field.rs create mode 100644 flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr create mode 100644 flareon/src/db/relations.rs diff --git a/Cargo.lock b/Cargo.lock index 8c4fd88..cf78bf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -876,6 +876,7 @@ dependencies = [ "chrono", "derive_builder", "derive_more", + "env_logger", "fake", "flareon_macros", "form_urlencoded", @@ -936,6 +937,7 @@ version = "0.1.0" dependencies = [ "convert_case", "darling", + "log", "proc-macro2", "quote", "syn", diff --git a/flareon-cli/Cargo.toml b/flareon-cli/Cargo.toml index 58b6180..4b4a827 100644 --- a/flareon-cli/Cargo.toml +++ b/flareon-cli/Cargo.toml @@ -17,7 +17,7 @@ clap-verbosity-flag.workspace = true darling.workspace = true env_logger.workspace = true flareon.workspace = true -flareon_codegen.workspace = true +flareon_codegen = { workspace = true, features = ["symbol-resolver"] } glob.workspace = true log.workspace = true prettyplease.workspace = true diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index 7b2e976..9d76d97 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -10,6 +10,7 @@ use cargo_toml::Manifest; use darling::FromMeta; use flareon::db::migrations::{DynMigration, MigrationEngine}; use flareon_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType}; +use flareon_codegen::symbol_resolver::{ModulePath, SymbolResolver, VisibleSymbol}; use log::{debug, info, warn}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; @@ -73,34 +74,47 @@ impl MigrationGenerator { fn generate_and_write_migrations(&mut self) -> anyhow::Result<()> { let source_files = self.get_source_files()?; - if let Some(migration) = self.generate_migrations(source_files)? { + if let Some(migration) = self.generate_migrations_to_write(source_files)? { self.write_migration(migration)?; } Ok(()) } + pub fn generate_migrations_to_write( + &mut self, + source_files: Vec, + ) -> anyhow::Result> { + if let Some(migration) = self.generate_migrations(source_files)? { + let migration_name = migration.migration_name.clone(); + let content = self.generate_migration_file_content(migration); + Ok(Some(MigrationAsSource::new(migration_name, content))) + } else { + Ok(None) + } + } + pub fn generate_migrations( &mut self, source_files: Vec, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let AppState { models, migrations } = self.process_source_files(source_files)?; let migration_processor = MigrationProcessor::new(migrations)?; let migration_models = migration_processor.latest_models(); - let (modified_models, operations) = self.generate_operations(&models, &migration_models); + let (modified_models, operations) = self.generate_operations(&models, &migration_models); if operations.is_empty() { Ok(None) } else { let migration_name = migration_processor.next_migration_name()?; - let dependencies = migration_processor.dependencies(); - let content = self.generate_migration_file_content( - &migration_name, - &modified_models, + let dependencies = migration_processor.base_dependencies(); + + Ok(Some(GeneratedMigration { + migration_name, + modified_models, dependencies, operations, - ); - Ok(Some(MigrationToWrite::new(migration_name, content))) + })) } } @@ -173,18 +187,18 @@ impl MigrationGenerator { }: SourceFile, app_state: &mut AppState, ) -> anyhow::Result<()> { - let imports = Self::get_imports(&file, &ModulePath::from_fs_path(&path)); - let import_resolver = SymbolResolver::new(imports); + let symbol_resolver = SymbolResolver::from_file(&file, &path); let mut migration_models = Vec::new(); for item in file.items { if let syn::Item::Struct(mut item) = item { for attr in &item.attrs.clone() { if is_model_attr(attr) { - import_resolver.resolve_struct(&mut item); + symbol_resolver.resolve_struct(&mut item); let args = Self::args_from_attr(&path, attr)?; - let model_in_source = ModelInSource::from_item(item, &args)?; + let model_in_source = + ModelInSource::from_item(item, &args, &symbol_resolver)?; match args.model_type { ModelType::Application => app_state.models.push(model_in_source), @@ -214,29 +228,6 @@ impl MigrationGenerator { Ok(()) } - /// Return the list of top-level `use` statements, structs, and constants as - /// a list of [`VisibleSymbol`]s from the file. - fn get_imports(file: &syn::File, module_path: &ModulePath) -> Vec { - let mut imports = Vec::new(); - - for item in &file.items { - match item { - syn::Item::Use(item) => { - imports.append(&mut VisibleSymbol::from_item_use(item, module_path)); - } - syn::Item::Struct(item_struct) => { - imports.push(VisibleSymbol::from_item_struct(item_struct, module_path)); - } - syn::Item::Const(item_const) => { - imports.push(VisibleSymbol::from_item_const(item_const, module_path)); - } - _ => {} - } - } - - imports - } - fn args_from_attr(path: &Path, attr: &Attribute) -> Result { match attr.meta { Meta::Path(_) => { @@ -398,23 +389,20 @@ impl MigrationGenerator { todo!() } - fn generate_migration_file_content( - &self, - migration_name: &str, - modified_models: &[ModelInSource], - dependencies: Vec, - operations: Vec, - ) -> String { - let operations: Vec<_> = operations + fn generate_migration_file_content(&self, migration: GeneratedMigration) -> String { + let operations: Vec<_> = migration + .operations .into_iter() .map(|operation| operation.repr()) .collect(); - let dependencies: Vec<_> = dependencies + let dependencies: Vec<_> = migration + .dependencies .into_iter() .map(|dependency| dependency.repr()) .collect(); let app_name = self.options.app_name.as_ref().unwrap_or(&self.crate_name); + let migration_name = &migration.migration_name; let migration_def = quote! { #[derive(Debug, Copy, Clone)] pub(super) struct Migration; @@ -431,7 +419,8 @@ impl MigrationGenerator { } }; - let models = modified_models + let models = migration + .modified_models .iter() .map(Self::model_to_migration_model) .collect::>(); @@ -442,7 +431,7 @@ impl MigrationGenerator { Self::generate_migration(migration_def, models_def) } - fn write_migration(&self, migration: MigrationToWrite) -> anyhow::Result<()> { + fn write_migration(&self, migration: MigrationAsSource) -> anyhow::Result<()> { let src_path = self .options .output_dir @@ -547,290 +536,6 @@ impl AppState { } } -/// Represents a symbol visible in the current module. This might mean there is -/// a `use` statement for a given type, but also, for instance, the type is -/// defined in the current module. -/// -/// For instance, for `use std::collections::HashMap;` the `VisibleSymbol ` -/// would be: -/// ```ignore -/// # /* -/// VisibleSymbol { -/// alias: "HashMap", -/// full_path: "std::collections::HashMap", -/// kind: VisibleSymbolKind::Use, -/// } -/// # */ -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -struct VisibleSymbol { - alias: String, - full_path: String, - kind: VisibleSymbolKind, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -enum VisibleSymbolKind { - Use, - Struct, - Const, -} - -impl VisibleSymbol { - #[must_use] - fn new(alias: &str, full_path: &str, kind: VisibleSymbolKind) -> Self { - Self { - alias: alias.to_string(), - full_path: full_path.to_string(), - kind, - } - } - - fn full_path_parts(&self) -> impl Iterator { - self.full_path.split("::") - } - - fn new_use(alias: &str, full_path: &str) -> Self { - Self::new(alias, full_path, VisibleSymbolKind::Use) - } - - fn from_item_use(item: &syn::ItemUse, module_path: &ModulePath) -> Vec { - Self::from_tree(&item.tree, module_path) - } - - fn from_item_struct(item: &syn::ItemStruct, module_path: &ModulePath) -> Self { - let ident = item.ident.to_string(); - let full_path = Self::module_path(module_path, &ident); - - Self { - alias: ident, - full_path, - kind: VisibleSymbolKind::Struct, - } - } - - fn from_item_const(item: &syn::ItemConst, module_path: &ModulePath) -> Self { - let ident = item.ident.to_string(); - let full_path = Self::module_path(module_path, &ident); - - Self { - alias: ident, - full_path, - kind: VisibleSymbolKind::Const, - } - } - - fn module_path(module_path: &ModulePath, ident: &str) -> String { - format!("{module_path}::{ident}") - } - - fn from_tree(tree: &UseTree, current_module: &ModulePath) -> Vec { - match tree { - UseTree::Path(path) => { - let ident = path.ident.to_string(); - let resolved_path = if ident == "crate" { - current_module.crate_name().to_string() - } else if ident == "self" { - current_module.to_string() - } else if ident == "super" { - current_module.parent().to_string() - } else { - ident - }; - - return Self::from_tree(&path.tree, current_module) - .into_iter() - .map(|import| { - Self::new_use( - &import.alias, - &format!("{}::{}", resolved_path, import.full_path), - ) - }) - .collect(); - } - UseTree::Name(name) => { - let ident = name.ident.to_string(); - return vec![Self::new_use(&ident, &ident)]; - } - UseTree::Rename(rename) => { - return vec![Self::new_use( - &rename.rename.to_string(), - &rename.ident.to_string(), - )]; - } - UseTree::Glob(_) => { - warn!("Glob imports are not supported"); - } - UseTree::Group(group) => { - return group - .items - .iter() - .flat_map(|tree| Self::from_tree(tree, current_module)) - .collect(); - } - } - - vec![] - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct ModulePath { - parts: Vec, -} - -impl ModulePath { - #[must_use] - fn from_fs_path(path: &Path) -> Self { - let mut parts = vec![String::from("crate")]; - - if path == Path::new("lib.rs") || path == Path::new("main.rs") { - return Self { parts }; - } - - parts.append( - &mut path - .components() - .map(|c| { - let component_str = c.as_os_str().to_string_lossy(); - component_str - .strip_suffix(".rs") - .unwrap_or(&component_str) - .to_string() - }) - .collect::>(), - ); - - if parts - .last() - .expect("parts must have at least one component") - == "mod" - { - parts.pop(); - } - - Self { parts } - } - - #[must_use] - fn parent(&self) -> Self { - let mut parts = self.parts.clone(); - parts.pop(); - Self { parts } - } - - #[must_use] - fn crate_name(&self) -> &str { - &self.parts[0] - } -} - -impl Display for ModulePath { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.parts.join("::")) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct SymbolResolver { - /// List of imports in the format `"HashMap" -> VisibleSymbol` - symbols: HashMap, -} - -impl SymbolResolver { - #[must_use] - fn new(symbols: Vec) -> Self { - let mut symbol_map = HashMap::new(); - for symbol in symbols { - symbol_map.insert(symbol.alias.clone(), symbol); - } - - Self { - symbols: symbol_map, - } - } - - fn resolve_struct(&self, item: &mut syn::ItemStruct) { - for field in &mut item.fields { - if let syn::Type::Path(path) = &mut field.ty { - self.resolve(path); - } - } - } - - /// Checks the provided `TypePath` and resolves the full type path, if - /// available. - fn resolve(&self, path: &mut syn::TypePath) { - let first_segment = path.path.segments.first(); - - if let Some(first_segment) = first_segment { - if let Some(symbol) = self.symbols.get(&first_segment.ident.to_string()) { - let mut new_segments: Vec<_> = symbol - .full_path_parts() - .map(|s| syn::PathSegment { - ident: syn::Ident::new(s, first_segment.ident.span()), - arguments: syn::PathArguments::None, - }) - .collect(); - - let first_arguments = first_segment.arguments.clone(); - new_segments - .last_mut() - .expect("new_segments must have at least one element") - .arguments = first_arguments; - - new_segments.extend(path.path.segments.iter().skip(1).cloned()); - path.path.segments = syn::punctuated::Punctuated::from_iter(new_segments); - } - - for segment in &mut path.path.segments { - self.resolve_path_arguments(&mut segment.arguments); - } - } - } - - fn resolve_path_arguments(&self, arguments: &mut syn::PathArguments) { - if let syn::PathArguments::AngleBracketed(args) = arguments { - for arg in &mut args.args { - self.resolve_generic_argument(arg); - } - } - } - - fn resolve_generic_argument(&self, arg: &mut syn::GenericArgument) { - if let syn::GenericArgument::Type(syn::Type::Path(path)) = arg { - if let Some(new_arg) = self.try_resolve_generic_const(path) { - *arg = new_arg; - } else { - self.resolve(path); - } - } - } - - fn try_resolve_generic_const(&self, path: &syn::TypePath) -> Option { - if path.qself.is_none() && path.path.segments.len() == 1 { - let segment = path - .path - .segments - .first() - .expect("segments have exactly one element"); - if segment.arguments.is_none() { - let ident = segment.ident.to_string(); - if let Some(symbol) = self.symbols.get(&ident) { - if symbol.kind == VisibleSymbolKind::Const { - let path = &symbol.full_path; - return Some(syn::GenericArgument::Const( - syn::parse_str(path).expect("full_path should be a valid path"), - )); - } - } - } - } - - None - } -} - /// Helper struct to process already existing migrations. #[derive(Debug, Clone)] struct MigrationProcessor { @@ -886,7 +591,9 @@ impl MigrationProcessor { Ok(format!("m_{migration_number:04}_auto_{date_time}")) } - fn dependencies(&self) -> Vec { + /// Returns the list of dependencies for the next migration, based on the + /// already existing and processed migrations. + fn base_dependencies(&self) -> Vec { if self.migrations.is_empty() { return Vec::new(); } @@ -899,18 +606,22 @@ impl MigrationProcessor { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct ModelInSource { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ModelInSource { model_item: syn::ItemStruct, model: Model, } impl ModelInSource { - fn from_item(item: syn::ItemStruct, args: &ModelArgs) -> anyhow::Result { + fn from_item( + item: syn::ItemStruct, + args: &ModelArgs, + symbol_resolver: &SymbolResolver, + ) -> anyhow::Result { let input: syn::DeriveInput = item.clone().into(); let opts = ModelOpts::new_from_derive_input(&input) .map_err(|e| anyhow::anyhow!("cannot parse model: {}", e))?; - let model = opts.as_model(args)?; + let model = opts.as_model(args, Some(symbol_resolver))?; Ok(Self { model_item: item, @@ -919,13 +630,24 @@ impl ModelInSource { } } +/// A migration generated by the CLI and before converting to a Rust +/// source code and writing to a file. #[derive(Debug, Clone)] -pub struct MigrationToWrite { +pub struct GeneratedMigration { + pub migration_name: String, + pub modified_models: Vec, + pub dependencies: Vec, + pub operations: Vec, +} + +/// A migration represented as a generated and ready to write Rust source code. +#[derive(Debug, Clone)] +pub struct MigrationAsSource { pub name: String, pub content: String, } -impl MigrationToWrite { +impl MigrationAsSource { #[must_use] pub fn new(name: String, content: String) -> Self { Self { name, content } @@ -954,7 +676,10 @@ impl Repr for Field { let mut tokens = quote! { ::flareon::db::migrations::Field::new(::flareon::db::Identifier::new(#column_name), <#ty as ::flareon::db::DatabaseField>::TYPE) }; - if self.auto_value { + if self + .auto_value + .expect("auto_value is expected to be present when parsing the entire file") + { tokens = quote! { #tokens.auto() } } if self.primary_key { @@ -998,7 +723,7 @@ impl DynMigration for Migration { /// /// This is used to generate migration files. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum DynDependency { +pub enum DynDependency { Migration { app: String, migration: String }, Model { app: String, model_name: String }, } @@ -1024,8 +749,8 @@ impl Repr for DynDependency { /// runtime and is using codegen types. /// /// This is used to generate migration files. -#[derive(Debug, Clone)] -enum DynOperation { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum DynOperation { CreateModel { table_name: String, fields: Vec, @@ -1101,8 +826,6 @@ impl Error for ParsingError {} #[cfg(test)] mod tests { - use quote::ToTokens; - use super::*; #[test] @@ -1119,7 +842,7 @@ mod tests { let migrations = vec![]; let processor = MigrationProcessor::new(migrations).unwrap(); - let next_migration_name = processor.dependencies(); + let next_migration_name = processor.base_dependencies(); assert_eq!(next_migration_name, vec![]); } @@ -1132,7 +855,7 @@ mod tests { }]; let processor = MigrationProcessor::new(migrations).unwrap(); - let next_migration_name = processor.dependencies(); + let next_migration_name = processor.base_dependencies(); assert_eq!( next_migration_name, vec![DynDependency::Migration { @@ -1141,147 +864,4 @@ mod tests { }] ); } - - #[test] - fn imports() { - let source = r" -use std::collections::HashMap; -use std::error::Error as StdError; -use std::fmt::{Debug, Display, Formatter}; -use std::fs::*; -use rand as r; -use super::MyModel; -use crate::MyOtherModel; -use self::MyThirdModel; - -struct MyFourthModel {} - -const MY_CONSTANT: u8 = 42; - "; - - let file = SourceFile::parse(PathBuf::from("foo/bar.rs").clone(), source).unwrap(); - let imports = - MigrationGenerator::get_imports(&file.content, &ModulePath::from_fs_path(&file.path)); - - let expected = vec![ - VisibleSymbol { - alias: "HashMap".to_string(), - full_path: "std::collections::HashMap".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "StdError".to_string(), - full_path: "std::error::Error".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Debug".to_string(), - full_path: "std::fmt::Debug".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Display".to_string(), - full_path: "std::fmt::Display".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Formatter".to_string(), - full_path: "std::fmt::Formatter".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "r".to_string(), - full_path: "rand".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyModel".to_string(), - full_path: "crate::foo::MyModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyOtherModel".to_string(), - full_path: "crate::MyOtherModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyThirdModel".to_string(), - full_path: "crate::foo::bar::MyThirdModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyFourthModel".to_string(), - full_path: "crate::foo::bar::MyFourthModel".to_string(), - kind: VisibleSymbolKind::Struct, - }, - VisibleSymbol { - alias: "MY_CONSTANT".to_string(), - full_path: "crate::foo::bar::MY_CONSTANT".to_string(), - kind: VisibleSymbolKind::Const, - }, - ]; - assert_eq!(imports, expected); - } - - #[test] - fn import_resolver() { - let resolver = SymbolResolver::new(vec![ - VisibleSymbol::new_use("MyType", "crate::models::MyType"), - VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), - ]); - - let path = &mut parse_quote!(MyType); - resolver.resolve(path); - assert_eq!( - quote!(crate::models::MyType).to_string(), - path.into_token_stream().to_string() - ); - - let path = &mut parse_quote!(HashMap); - resolver.resolve(path); - assert_eq!( - quote!(std::collections::HashMap).to_string(), - path.into_token_stream().to_string() - ); - - let path = &mut parse_quote!(Option); - resolver.resolve(path); - assert_eq!( - quote!(Option).to_string(), - path.into_token_stream().to_string() - ); - } - - #[test] - fn import_resolver_resolve_struct() { - let resolver = SymbolResolver::new(vec![ - VisibleSymbol::new_use("MyType", "crate::models::MyType"), - VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), - VisibleSymbol::new_use("LimitedString", "flareon::db::LimitedString"), - VisibleSymbol::new( - "MY_CONSTANT", - "crate::constants::MY_CONSTANT", - VisibleSymbolKind::Const, - ), - ]); - - let mut actual = parse_quote! { - struct Example { - field_1: MyType, - field_2: HashMap, - field_3: Option, - field_4: LimitedString, - } - }; - resolver.resolve_struct(&mut actual); - let expected = quote! { - struct Example { - field_1: crate::models::MyType, - field_2: std::collections::HashMap, - field_3: Option, - field_4: flareon::db::LimitedString<{ crate::constants::MY_CONSTANT }>, - } - }; - assert_eq!(actual.into_token_stream().to_string(), expected.to_string()); - } } diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index fadbc33..4efefff 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -1,24 +1,61 @@ use std::path::PathBuf; use flareon_cli::migration_generator::{ - MigrationGenerator, MigrationGeneratorOptions, MigrationToWrite, SourceFile, + DynOperation, MigrationAsSource, MigrationGenerator, MigrationGeneratorOptions, SourceFile, }; -/// Test that the migration generator can generate a create model migration for -/// a given model which compiles successfully. +/// Test that the migration generator can generate a "create model" migration +/// for a given model that has an expected state. +#[test] +fn create_model_state_test() { + let mut generator = test_generator(); + let src = include_str!("migration_generator/create_model.rs"); + let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; + + let migration = generator + .generate_migrations(source_files) + .unwrap() + .unwrap(); + + assert_eq!(migration.migration_name, "m_0001_initial"); + assert!(migration.dependencies.is_empty()); + if let DynOperation::CreateModel { table_name, fields } = &migration.operations[0] { + assert_eq!(table_name, "my_model"); + assert_eq!(fields.len(), 3); + + let field = &fields[0]; + assert_eq!(field.column_name, "id"); + assert!(field.primary_key); + assert!(field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + + let field = &fields[1]; + assert_eq!(field.column_name, "field_1"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + + let field = &fields[2]; + assert_eq!(field.column_name, "field_2"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + } +} + +/// Test that the migration generator can generate a "create model" migration +/// for a given model which compiles successfully. #[test] #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri fn create_model_compile_test() { - let mut generator = MigrationGenerator::new( - PathBuf::from("Cargo.toml"), - String::from("my_crate"), - MigrationGeneratorOptions::default(), - ); + let mut generator = test_generator(); let src = include_str!("migration_generator/create_model.rs"); let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; - let migration_opt = generator.generate_migrations(source_files).unwrap(); - let MigrationToWrite { + let migration_opt = generator + .generate_migrations_to_write(source_files) + .unwrap(); + let MigrationAsSource { name: migration_name, content: migration_content, } = migration_opt.unwrap(); @@ -41,3 +78,11 @@ mod migrations {{ let t = trybuild::TestCases::new(); t.pass(&test_path); } + +fn test_generator() -> MigrationGenerator { + MigrationGenerator::new( + PathBuf::from("Cargo.toml"), + String::from("my_crate"), + MigrationGeneratorOptions::default(), + ) +} diff --git a/flareon-cli/tests/migration_generator/create_model.rs b/flareon-cli/tests/migration_generator/create_model.rs index a249d4d..fd19eab 100644 --- a/flareon-cli/tests/migration_generator/create_model.rs +++ b/flareon-cli/tests/migration_generator/create_model.rs @@ -1,9 +1,10 @@ -use flareon::db::{model, LimitedString}; +use flareon::db::{model, Auto, LimitedString}; pub const FIELD_LEN: u32 = 64; #[model] struct MyModel { + id: Auto, field_1: String, field_2: LimitedString, } diff --git a/flareon-codegen/Cargo.toml b/flareon-codegen/Cargo.toml index 84f0ca4..97e7437 100644 --- a/flareon-codegen/Cargo.toml +++ b/flareon-codegen/Cargo.toml @@ -11,9 +11,14 @@ workspace = true [dependencies] convert_case.workspace = true darling.workspace = true +log = { workspace = true, optional = true } proc-macro2.workspace = true quote.workspace = true syn.workspace = true [dev-dependencies] proc-macro2 = { workspace = true, features = ["span-locations"] } + +[features] +default = [] +symbol-resolver = ["dep:log"] diff --git a/flareon-codegen/src/expr.rs b/flareon-codegen/src/expr.rs index f946b6e..806142f 100644 --- a/flareon-codegen/src/expr.rs +++ b/flareon-codegen/src/expr.rs @@ -9,7 +9,9 @@ enum ItemToken { Field(FieldParser), Literal(syn::Lit), Ident(syn::Ident), - MethodCall(MethodCallParser), + MemberAccess(MemberAccessParser), + FunctionCall(FunctionCallParser), + Reference(ReferenceParser), Op(OpParser), } @@ -23,8 +25,12 @@ impl Parse for ItemToken { if lookahead.peek(Token![$]) { input.parse().map(ItemToken::Field) + } else if lookahead.peek(Token![&]) { + input.parse().map(ItemToken::Reference) } else if lookahead.peek(Token![.]) { - input.parse().map(ItemToken::MethodCall) + input.parse().map(ItemToken::MemberAccess) + } else if lookahead.peek(syn::token::Paren) { + input.parse().map(ItemToken::FunctionCall) } else if lookahead.peek(syn::Lit) { input.parse().map(ItemToken::Literal) } else if lookahead.peek(syn::Ident) { @@ -41,7 +47,9 @@ impl ItemToken { ItemToken::Field(field) => field.span(), ItemToken::Literal(lit) => lit.span(), ItemToken::Ident(ident) => ident.span(), - ItemToken::MethodCall(method_call) => method_call.span(), + ItemToken::MemberAccess(member_access) => member_access.span(), + ItemToken::FunctionCall(function_call) => function_call.span(), + ItemToken::Reference(reference) => reference.span(), ItemToken::Op(op) => op.span(), } } @@ -49,7 +57,7 @@ impl ItemToken { #[derive(Debug)] struct FieldParser { - _field_token: Token![$], + field_token: Token![$], name: syn::Ident, } @@ -63,34 +71,74 @@ impl FieldParser { impl Parse for FieldParser { fn parse(input: ParseStream) -> syn::Result { Ok(FieldParser { - _field_token: input.parse()?, + field_token: input.parse()?, name: input.parse()?, }) } } #[derive(Debug)] -struct MethodCallParser { - _dot: Token![.], - method_name: syn::Ident, - _paren_token: syn::token::Paren, +struct ReferenceParser { + reference_token: Token![&], + expr: syn::Expr, +} + +impl ReferenceParser { + #[must_use] + fn span(&self) -> proc_macro2::Span { + self.expr.span() + } +} + +impl Parse for ReferenceParser { + fn parse(input: ParseStream) -> syn::Result { + Ok(ReferenceParser { + reference_token: input.parse()?, + expr: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct MemberAccessParser { + dot: Token![.], + member_name: syn::Ident, +} + +impl MemberAccessParser { + #[must_use] + fn span(&self) -> proc_macro2::Span { + self.member_name.span() + } +} + +impl Parse for MemberAccessParser { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + dot: input.parse()?, + member_name: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct FunctionCallParser { + paren_token: syn::token::Paren, args: syn::punctuated::Punctuated, } -impl MethodCallParser { +impl FunctionCallParser { #[must_use] fn span(&self) -> proc_macro2::Span { - self.method_name.span() + self.args.span() } } -impl Parse for MethodCallParser { +impl Parse for FunctionCallParser { fn parse(input: ParseStream) -> syn::Result { let args_content; Ok(Self { - _dot: input.parse()?, - method_name: input.parse()?, - _paren_token: syn::parenthesized!(args_content in input), + paren_token: syn::parenthesized!(args_content in input), args: args_content.parse_terminated(syn::Expr::parse, Token![,])?, }) } @@ -202,18 +250,25 @@ type InfixBindingPriority = BindingPriority; /// assert_eq!( /// expr, /// Expr::Eq( -/// Box::new(Expr::FieldRef(parse_quote!(field))), +/// Box::new(Expr::FieldRef { field_name: parse_quote!(field), field_token: parse_quote!($)}), /// Box::new(Expr::Value(parse_quote!(42))) /// ) /// ); /// ``` #[derive(Debug, PartialEq, Eq)] pub enum Expr { - FieldRef(syn::Ident), + FieldRef { + field_name: syn::Ident, + field_token: Token![$], + }, Value(syn::Expr), - MethodCall { - called_on: Box, - method_name: syn::Ident, + MemberAccess { + parent: Box, + member_name: syn::Ident, + member_access_token: Token![.], + }, + FunctionCall { + function: Box, args: Vec, }, And(Box, Box), @@ -247,7 +302,18 @@ impl Expr { let lhs_item = input.parse::()?; match lhs_item { - ItemToken::Field(field) => Expr::FieldRef(field.name), + ItemToken::Field(field) => Expr::FieldRef { + field_name: field.name, + field_token: field.field_token, + }, + ItemToken::Reference(reference) => { + Expr::Value(syn::Expr::Reference(syn::ExprReference { + attrs: Vec::new(), + and_token: reference.reference_token, + mutability: None, + expr: Box::new(reference.expr), + })) + } ItemToken::Ident(ident) => Expr::Value(syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, @@ -273,12 +339,19 @@ impl Expr { let op_item = input.fork().parse::()?; match op_item { - ItemToken::MethodCall(call) => { + ItemToken::MemberAccess(member_access) => { + input.parse::()?; + lhs = Expr::MemberAccess { + parent: Box::new(lhs), + member_name: member_access.member_name, + member_access_token: member_access.dot, + }; + } + ItemToken::FunctionCall(call) => { input.parse::()?; let args = call.args.into_iter().collect::>(); - lhs = Expr::MethodCall { - called_on: Box::new(lhs), - method_name: call.method_name, + lhs = Expr::FunctionCall { + function: Box::new(lhs), args, }; } @@ -321,61 +394,88 @@ impl Expr { #[must_use] pub fn as_tokens(&self) -> Option { + self.as_tokens_impl(ExprAsTokensMode::FieldRefAsNone) + } + + #[must_use] + pub fn as_tokens_full(&self) -> TokenStream { + self.as_tokens_impl(ExprAsTokensMode::Full) + .expect("Full mode should never return None") + } + + #[must_use] + fn as_tokens_impl(&self, mode: ExprAsTokensMode) -> Option { match self { - Expr::FieldRef(_) => None, + Expr::FieldRef { + field_name, + field_token, + } => match mode { + ExprAsTokensMode::FieldRefAsNone => None, + ExprAsTokensMode::Full => Some(quote! {#field_token #field_name}), + }, Expr::Value(expr) => Some(quote! {#expr}), - Expr::MethodCall { - called_on, - method_name, - args, + Expr::MemberAccess { + parent, + member_name, + member_access_token, } => { - let called_on_tokens = called_on.as_tokens()?; - Some(quote! {#called_on_tokens.#method_name(#(#args),*)}) + let parent_tokens = parent.as_tokens_impl(mode)?; + Some(quote! {#parent_tokens #member_access_token #member_name}) + } + Expr::FunctionCall { function, args } => { + let function_tokens = function.as_tokens_impl(mode)?; + Some(quote! {#function_tokens(#(#args),*)}) } Expr::And(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens && #rhs_tokens}) } Expr::Or(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens || #rhs_tokens}) } Expr::Eq(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens == #rhs_tokens}) } Expr::Ne(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens != #rhs_tokens}) } Expr::Add(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens + #rhs_tokens}) } Expr::Sub(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens - #rhs_tokens}) } Expr::Mul(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens * #rhs_tokens}) } Expr::Div(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens / #rhs_tokens}) } } } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ExprAsTokensMode { + FieldRefAsNone, + Full, +} + impl Parse for Expr { fn parse(input: ParseStream) -> syn::Result { Self::parse_impl(input, 0) @@ -393,7 +493,7 @@ mod tests { #[test] fn field_ref() { let input = quote! { $field }; - let expected = Expr::FieldRef(syn::Ident::new("field", span())); + let expected = field("field"); assert_eq!(expected, unwrap_syn(Expr::parse(input))); } @@ -410,7 +510,7 @@ mod tests { fn field_eq() { let input = quote! { $field == 42 }; let expected = Expr::Eq( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), ); @@ -439,11 +539,11 @@ mod tests { let input = quote! { $field == 42 && $field != 42 }; let expected = Expr::And( Box::new(Expr::Eq( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), )), Box::new(Expr::Ne( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), )), ); @@ -470,18 +570,52 @@ mod tests { assert_eq!(expected, unwrap_syn(Expr::parse(input))); } + #[test] + fn function_call() { + let input = quote! { $a == bar() }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(Expr::FunctionCall { + function: Box::new(value("bar")), + args: Vec::new(), + }), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + + #[test] + fn parse_member_access() { + let input = quote! { $a == foo.bar }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(member_access(value("foo"), "bar")), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + + #[test] + fn parse_reference() { + let input = quote! { &foo }; + let expected = reference("foo"); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + #[test] fn method_call() { let input = quote! { $a == foo.bar().baz() }; let expected = Expr::Eq( Box::new(field("a")), - Box::new(Expr::MethodCall { - called_on: Box::new(Expr::MethodCall { - called_on: Box::new(value("foo")), - method_name: syn::Ident::new("bar", span()), - args: Vec::new(), - }), - method_name: syn::Ident::new("baz", span()), + Box::new(Expr::FunctionCall { + function: Box::new(member_access( + Expr::FunctionCall { + function: Box::new(member_access(value("foo"), "bar")), + args: Vec::new(), + }, + "baz", + )), args: Vec::new(), }), ); @@ -569,9 +703,35 @@ mod tests { assert_eq!(input.to_string(), expr.as_tokens().unwrap().to_string()); } + #[test] + fn tokens_full() { + let input = quote! { $name.len() }; + let expr = unwrap_syn(Expr::parse(input.clone())); + + assert_eq!(input.to_string(), expr.as_tokens_full().to_string()); + } + #[must_use] fn field(name: &str) -> Expr { - Expr::FieldRef(syn::Ident::new(name, span())) + Expr::FieldRef { + field_name: syn::Ident::new(name, span()), + field_token: Token![$](span()), + } + } + + #[must_use] + fn member_access(parent: Expr, member_name: &str) -> Expr { + Expr::MemberAccess { + parent: Box::new(parent), + member_name: syn::Ident::new(member_name, span()), + member_access_token: Token![.](span()), + } + } + + #[must_use] + fn reference(ident: &str) -> Expr { + let ident = syn::Ident::new(ident, span()); + Expr::Value(parse_quote!(&#ident)) } #[must_use] diff --git a/flareon-codegen/src/lib.rs b/flareon-codegen/src/lib.rs index db4772e..8668603 100644 --- a/flareon-codegen/src/lib.rs +++ b/flareon-codegen/src/lib.rs @@ -2,3 +2,19 @@ extern crate self as flareon_codegen; pub mod expr; pub mod model; +#[cfg(feature = "symbol-resolver")] +pub mod symbol_resolver; +#[cfg(not(feature = "symbol-resolver"))] +pub mod symbol_resolver { + /// Dummy SymbolResolver for use in contexts when it's not useful (e.g. + /// macros which do not have access to the entire source tree to look + /// for `use` statements anyway). + /// + /// This is defined as an empty enum so that it's entirely optimized out by + /// the compiler, along with all functions that reference it. + pub enum SymbolResolver {} + + impl SymbolResolver { + pub fn resolve(&self, _: &mut syn::Type) {} + } +} diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index a9e4e8f..f9adb39 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -1,6 +1,8 @@ use convert_case::{Case, Casing}; use darling::{FromDeriveInput, FromField, FromMeta}; +use crate::symbol_resolver::SymbolResolver; + #[allow(clippy::module_name_repetitions)] #[derive(Debug, Default, FromMeta)] pub struct ModelArgs { @@ -59,8 +61,16 @@ impl ModelOpts { /// /// Returns an error if the model name does not start with an underscore /// when the model type is [`ModelType::Migration`]. - pub fn as_model(&self, args: &ModelArgs) -> Result { - let fields = self.fields().iter().map(|field| field.as_field()).collect(); + pub fn as_model( + &self, + args: &ModelArgs, + symbol_resolver: Option<&SymbolResolver>, + ) -> Result { + let fields: Vec<_> = self + .fields() + .iter() + .map(|field| field.as_field(symbol_resolver)) + .collect(); let mut original_name = self.ident.to_string(); if args.model_type == ModelType::Migration { @@ -80,14 +90,36 @@ impl ModelOpts { original_name.to_string().to_case(Case::Snake) }; + let primary_key_field = self.get_primary_key_field(&fields)?; + Ok(Model { name: self.ident.clone(), original_name, model_type: args.model_type, table_name, + pk_field: primary_key_field.clone(), fields, }) } + + fn get_primary_key_field<'a>(&self, fields: &'a [Field]) -> Result<&'a Field, syn::Error> { + let pks: Vec<_> = fields.iter().filter(|field| field.primary_key).collect(); + if pks.is_empty() { + return Err(syn::Error::new( + self.ident.span(), + "models must have a primary key field, either named `id` \ + or annotated with the `#[model(primary_key)]` attribute", + )); + } + if pks.len() > 1 { + return Err(syn::Error::new( + pks[1].field_name.span(), + "composite primary keys are not supported; only one primary key field is allowed", + )); + } + + Ok(pks[0]) + } } #[derive(Debug, Clone, FromField)] @@ -95,10 +127,50 @@ impl ModelOpts { pub struct FieldOpts { pub ident: Option, pub ty: syn::Type, + pub primary_key: darling::util::Flag, pub unique: darling::util::Flag, } impl FieldOpts { + #[must_use] + fn find_type(&self, type_to_check: &str, symbol_resolver: &SymbolResolver) -> bool { + let mut ty = self.ty.clone(); + symbol_resolver.resolve(&mut ty); + Self::inner_type_names(&ty) + .iter() + .any(|name| name == type_to_check) + } + + #[must_use] + fn inner_type_names(ty: &syn::Type) -> Vec { + let mut names = Vec::new(); + Self::inner_type_names_impl(ty, &mut names); + names + } + + fn inner_type_names_impl(ty: &syn::Type, names: &mut Vec) { + if let syn::Type::Path(type_path) = ty { + let name = type_path + .path + .segments + .iter() + .map(|s| s.ident.to_string()) + .collect::>() + .join("::"); + names.push(name); + + for arg in &type_path.path.segments { + if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments { + for arg in &arg.args { + if let syn::GenericArgument::Type(ty) = arg { + Self::inner_type_names_impl(ty, names); + } + } + } + } + } + } + /// Convert the field options into a field. /// /// # Panics @@ -106,32 +178,37 @@ impl FieldOpts { /// Panics if the field does not have an identifier (i.e. it is a tuple /// struct). #[must_use] - pub fn as_field(&self) -> Field { + pub fn as_field(&self, symbol_resolver: Option<&SymbolResolver>) -> Field { let name = self.ident.as_ref().unwrap(); let column_name = name.to_string(); - // TODO define a separate type for auto fields - let is_auto = column_name == "id"; - // TODO define #[model(primary_key)] attribute - let is_primary_key = column_name == "id"; + let (auto_value, foreign_key) = match symbol_resolver { + Some(resolver) => ( + Some(self.find_type("flareon::db::Auto", resolver)), + Some(self.find_type("flareon::db::ForeignKey", resolver)), + ), + None => (None, None), + }; + let is_primary_key = column_name == "id" || self.primary_key.is_present(); Field { field_name: name.clone(), column_name, ty: self.ty.clone(), - auto_value: is_auto, + auto_value, primary_key: is_primary_key, - null: false, + foreign_key, unique: self.unique.is_present(), } } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Model { pub name: syn::Ident, pub original_name: String, pub model_type: ModelType, pub table_name: String, + pub pk_field: Field, pub fields: Vec, } @@ -147,9 +224,13 @@ pub struct Field { pub field_name: syn::Ident, pub column_name: String, pub ty: syn::Type, - pub auto_value: bool, + /// Whether the field is an auto field (e.g. `id`); `None` if it could not + /// be determined. + pub auto_value: Option, pub primary_key: bool, - pub null: bool, + /// Whether the field is a foreign key; `None` if it could not be + /// determined. + pub foreign_key: Option, pub unique: bool, } @@ -158,6 +239,8 @@ mod tests { use syn::parse_quote; use super::*; + #[cfg(feature = "symbol-resolver")] + use crate::symbol_resolver::{VisibleSymbol, VisibleSymbolKind}; #[test] fn model_args_default() { @@ -197,7 +280,7 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let model = opts.as_model(&args).unwrap(); + let model = opts.as_model(&args, None).unwrap(); assert_eq!(model.name.to_string(), "TestModel"); assert_eq!(model.table_name, "test_model"); assert_eq!(model.fields.len(), 2); @@ -215,13 +298,67 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap(); - let err = opts.as_model(&args).unwrap_err(); + let err = opts.as_model(&args, None).unwrap_err(); assert_eq!( err.to_string(), "migration model names must start with an underscore" ); } + #[test] + fn model_opts_as_model_pk_attr() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + #[model(primary_key)] + name: i32, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let model = opts.as_model(&args, None).unwrap(); + assert_eq!(model.fields.len(), 1); + assert!(model.fields[0].primary_key); + } + + #[test] + fn model_opts_as_model_no_pk() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + name: String, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let err = opts.as_model(&args, None).unwrap_err(); + assert_eq!( + err.to_string(), + "models must have a primary key field, either named `id` \ + or annotated with the `#[model(primary_key)]` attribute" + ); + } + + #[test] + fn model_opts_as_model_multiple_pks() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + id: i64, + #[model(primary_key)] + id_2: i64, + name: String, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let err = opts.as_model(&args, None).unwrap_err(); + assert_eq!( + err.to_string(), + "composite primary keys are not supported; only one primary key field is allowed" + ); + } + #[test] fn field_opts_as_field() { let input: syn::Field = parse_quote! { @@ -229,10 +366,46 @@ mod tests { name: String }; let field_opts = FieldOpts::from_field(&input).unwrap(); - let field = field_opts.as_field(); + let field = field_opts.as_field(None); assert_eq!(field.field_name.to_string(), "name"); assert_eq!(field.column_name, "name"); assert_eq!(field.ty, parse_quote!(String)); assert!(field.unique); + assert_eq!(field.auto_value, None); + assert_eq!(field.foreign_key, None); + } + + #[test] + fn inner_type_names() { + let input: syn::Type = + parse_quote! { ::my_crate::MyContainer<'a, Vec> }; + let names = FieldOpts::inner_type_names(&input); + assert_eq!( + names, + vec!["my_crate::MyContainer", "Vec", "std::string::String"] + ); + } + + #[cfg(feature = "symbol-resolver")] + #[test] + fn contains_type() { + let symbols = vec![VisibleSymbol::new( + "MyContainer", + "my_crate::MyContainer", + VisibleSymbolKind::Use, + )]; + let resolver = SymbolResolver::new(symbols); + + let opts = FieldOpts { + ident: None, + ty: parse_quote! { MyContainer }, + primary_key: Default::default(), + unique: Default::default(), + }; + + assert!(opts.find_type("my_crate::MyContainer", &resolver)); + assert!(opts.find_type("std::string::String", &resolver)); + assert!(!opts.find_type("MyContainer", &resolver)); + assert!(!opts.find_type("String", &resolver)); } } diff --git a/flareon-codegen/src/symbol_resolver.rs b/flareon-codegen/src/symbol_resolver.rs new file mode 100644 index 0000000..9434320 --- /dev/null +++ b/flareon-codegen/src/symbol_resolver.rs @@ -0,0 +1,479 @@ +#![cfg(feature = "symbol-resolver")] + +use std::collections::HashMap; +use std::fmt::Display; +use std::iter::FromIterator; +use std::path::{Path, PathBuf}; + +use log::warn; +use quote::quote; +use syn::{parse_quote, UseTree}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SymbolResolver { + /// List of imports in the format `"HashMap" -> VisibleSymbol` + symbols: HashMap, +} + +impl SymbolResolver { + #[must_use] + pub fn new(symbols: Vec) -> Self { + let mut symbol_map = HashMap::new(); + for symbol in symbols { + symbol_map.insert(symbol.alias.clone(), symbol); + } + + Self { + symbols: symbol_map, + } + } + + pub fn from_file(file: &syn::File, module_path: &Path) -> Self { + let imports = Self::get_imports(file, &ModulePath::from_fs_path(module_path)); + Self::new(imports) + } + + /// Return the list of top-level `use` statements, structs, and constants as + /// a list of [`VisibleSymbol`]s from the file. + fn get_imports(file: &syn::File, module_path: &ModulePath) -> Vec { + let mut imports = Vec::new(); + + for item in &file.items { + match item { + syn::Item::Use(item) => { + imports.append(&mut VisibleSymbol::from_item_use(item, module_path)); + } + syn::Item::Struct(item_struct) => { + imports.push(VisibleSymbol::from_item_struct(item_struct, module_path)); + } + syn::Item::Const(item_const) => { + imports.push(VisibleSymbol::from_item_const(item_const, module_path)); + } + _ => {} + } + } + + imports + } + + pub fn resolve_struct(&self, item: &mut syn::ItemStruct) { + for field in &mut item.fields { + self.resolve(&mut field.ty); + } + } + + pub fn resolve(&self, ty: &mut syn::Type) { + if let syn::Type::Path(path) = ty { + self.resolve_type_path(path); + } + } + + /// Checks the provided `TypePath` and resolves the full type path, if + /// available. + fn resolve_type_path(&self, path: &mut syn::TypePath) { + let first_segment = path.path.segments.first(); + + if let Some(first_segment) = first_segment { + if let Some(symbol) = self.symbols.get(&first_segment.ident.to_string()) { + let mut new_segments: Vec<_> = symbol + .full_path_parts() + .map(|s| syn::PathSegment { + ident: syn::Ident::new(s, first_segment.ident.span()), + arguments: syn::PathArguments::None, + }) + .collect(); + + let first_arguments = first_segment.arguments.clone(); + new_segments + .last_mut() + .expect("new_segments must have at least one element") + .arguments = first_arguments; + + new_segments.extend(path.path.segments.iter().skip(1).cloned()); + path.path.segments = syn::punctuated::Punctuated::from_iter(new_segments); + } + + for segment in &mut path.path.segments { + self.resolve_path_arguments(&mut segment.arguments); + } + } + } + + fn resolve_path_arguments(&self, arguments: &mut syn::PathArguments) { + if let syn::PathArguments::AngleBracketed(args) = arguments { + for arg in &mut args.args { + self.resolve_generic_argument(arg); + } + } + } + + fn resolve_generic_argument(&self, arg: &mut syn::GenericArgument) { + if let syn::GenericArgument::Type(syn::Type::Path(path)) = arg { + if let Some(new_arg) = self.try_resolve_generic_const(path) { + *arg = new_arg; + } else { + self.resolve_type_path(path); + } + } + } + + fn try_resolve_generic_const(&self, path: &syn::TypePath) -> Option { + if path.qself.is_none() && path.path.segments.len() == 1 { + let segment = path + .path + .segments + .first() + .expect("segments have exactly one element"); + if segment.arguments.is_none() { + let ident = segment.ident.to_string(); + if let Some(symbol) = self.symbols.get(&ident) { + if symbol.kind == VisibleSymbolKind::Const { + let path = &symbol.full_path; + return Some(syn::GenericArgument::Const( + syn::parse_str(path).expect("full_path should be a valid path"), + )); + } + } + } + } + + None + } +} + +/// Represents a symbol visible in the current module. This might mean there is +/// a `use` statement for a given type, but also, for instance, the type is +/// defined in the current module. +/// +/// For instance, for `use std::collections::HashMap;` the `VisibleSymbol ` +/// would be: +/// ``` +/// use flareon_codegen::symbol_resolver::{VisibleSymbol, VisibleSymbolKind}; +/// +/// let _ = VisibleSymbol { +/// alias: String::from("HashMap"), +/// full_path: String::from("std::collections::HashMap"), +/// kind: VisibleSymbolKind::Use, +/// }; +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct VisibleSymbol { + pub alias: String, + pub full_path: String, + pub kind: VisibleSymbolKind, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum VisibleSymbolKind { + Use, + Struct, + Const, +} + +impl VisibleSymbol { + #[must_use] + pub fn new(alias: &str, full_path: &str, kind: VisibleSymbolKind) -> Self { + assert_ne!(alias, "", "alias must not be empty"); + assert!(!alias.contains("::"), "alias must not contain '::'"); + Self { + alias: alias.to_string(), + full_path: full_path.to_string(), + kind, + } + } + + fn full_path_parts(&self) -> impl Iterator { + self.full_path.split("::") + } + + fn new_use(alias: &str, full_path: &str) -> Self { + Self::new(alias, full_path, VisibleSymbolKind::Use) + } + + fn from_item_use(item: &syn::ItemUse, module_path: &ModulePath) -> Vec { + Self::from_tree(&item.tree, module_path) + } + + fn from_item_struct(item: &syn::ItemStruct, module_path: &ModulePath) -> Self { + let ident = item.ident.to_string(); + let full_path = Self::module_path(module_path, &ident); + + Self { + alias: ident, + full_path, + kind: VisibleSymbolKind::Struct, + } + } + + fn from_item_const(item: &syn::ItemConst, module_path: &ModulePath) -> Self { + let ident = item.ident.to_string(); + let full_path = Self::module_path(module_path, &ident); + + Self { + alias: ident, + full_path, + kind: VisibleSymbolKind::Const, + } + } + + fn module_path(module_path: &ModulePath, ident: &str) -> String { + format!("{module_path}::{ident}") + } + + fn from_tree(tree: &UseTree, current_module: &ModulePath) -> Vec { + match tree { + UseTree::Path(path) => { + let ident = path.ident.to_string(); + let resolved_path = if ident == "crate" { + current_module.crate_name().to_string() + } else if ident == "self" { + current_module.to_string() + } else if ident == "super" { + current_module.parent().to_string() + } else { + ident + }; + + return Self::from_tree(&path.tree, current_module) + .into_iter() + .map(|import| { + Self::new_use( + &import.alias, + &format!("{}::{}", resolved_path, import.full_path), + ) + }) + .collect(); + } + UseTree::Name(name) => { + let ident = name.ident.to_string(); + return vec![Self::new_use(&ident, &ident)]; + } + UseTree::Rename(rename) => { + return vec![Self::new_use( + &rename.rename.to_string(), + &rename.ident.to_string(), + )]; + } + UseTree::Glob(_) => { + warn!("Glob imports are not supported"); + } + UseTree::Group(group) => { + return group + .items + .iter() + .flat_map(|tree| Self::from_tree(tree, current_module)) + .collect(); + } + } + + vec![] + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModulePath { + parts: Vec, +} + +impl ModulePath { + #[must_use] + pub fn from_fs_path(path: &Path) -> Self { + let mut parts = vec![String::from("crate")]; + + if path == Path::new("lib.rs") || path == Path::new("main.rs") { + return Self { parts }; + } + + parts.append( + &mut path + .components() + .map(|c| { + let component_str = c.as_os_str().to_string_lossy(); + component_str + .strip_suffix(".rs") + .unwrap_or(&component_str) + .to_string() + }) + .collect::>(), + ); + + if parts + .last() + .expect("parts must have at least one component") + == "mod" + { + parts.pop(); + } + + Self { parts } + } + + #[must_use] + fn parent(&self) -> Self { + let mut parts = self.parts.clone(); + parts.pop(); + Self { parts } + } + + #[must_use] + fn crate_name(&self) -> &str { + &self.parts[0] + } +} + +impl Display for ModulePath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.parts.join("::")) + } +} + +#[cfg(test)] +mod tests { + use flareon_codegen::symbol_resolver::VisibleSymbolKind; + use quote::ToTokens; + + use super::*; + + #[test] + fn imports() { + let source = r" +use std::collections::HashMap; +use std::error::Error as StdError; +use std::fmt::{Debug, Display, Formatter}; +use std::fs::*; +use rand as r; +use super::MyModel; +use crate::MyOtherModel; +use self::MyThirdModel; + +struct MyFourthModel {} + +const MY_CONSTANT: u8 = 42; + "; + + let file = syn::parse_file(source).unwrap(); + let imports = + SymbolResolver::get_imports(&file, &ModulePath::from_fs_path(Path::new("foo/bar.rs"))); + + let expected = vec![ + VisibleSymbol { + alias: "HashMap".to_string(), + full_path: "std::collections::HashMap".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "StdError".to_string(), + full_path: "std::error::Error".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Debug".to_string(), + full_path: "std::fmt::Debug".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Display".to_string(), + full_path: "std::fmt::Display".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Formatter".to_string(), + full_path: "std::fmt::Formatter".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "r".to_string(), + full_path: "rand".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyModel".to_string(), + full_path: "crate::foo::MyModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyOtherModel".to_string(), + full_path: "crate::MyOtherModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyThirdModel".to_string(), + full_path: "crate::foo::bar::MyThirdModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyFourthModel".to_string(), + full_path: "crate::foo::bar::MyFourthModel".to_string(), + kind: VisibleSymbolKind::Struct, + }, + VisibleSymbol { + alias: "MY_CONSTANT".to_string(), + full_path: "crate::foo::bar::MY_CONSTANT".to_string(), + kind: VisibleSymbolKind::Const, + }, + ]; + assert_eq!(imports, expected); + } + + #[test] + fn import_resolver() { + let resolver = SymbolResolver::new(vec![ + VisibleSymbol::new_use("MyType", "crate::models::MyType"), + VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), + ]); + + let path = &mut parse_quote!(MyType); + resolver.resolve_type_path(path); + assert_eq!( + quote!(crate::models::MyType).to_string(), + path.into_token_stream().to_string() + ); + + let path = &mut parse_quote!(HashMap); + resolver.resolve_type_path(path); + assert_eq!( + quote!(std::collections::HashMap).to_string(), + path.into_token_stream().to_string() + ); + + let path = &mut parse_quote!(Option); + resolver.resolve_type_path(path); + assert_eq!( + quote!(Option).to_string(), + path.into_token_stream().to_string() + ); + } + + #[test] + fn import_resolver_resolve_struct() { + let resolver = SymbolResolver::new(vec![ + VisibleSymbol::new_use("MyType", "crate::models::MyType"), + VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), + VisibleSymbol::new_use("LimitedString", "flareon::db::LimitedString"), + VisibleSymbol::new( + "MY_CONSTANT", + "crate::constants::MY_CONSTANT", + VisibleSymbolKind::Const, + ), + ]); + + let mut actual = parse_quote! { + struct Example { + field_1: MyType, + field_2: HashMap, + field_3: Option, + field_4: LimitedString, + } + }; + resolver.resolve_struct(&mut actual); + let expected = quote! { + struct Example { + field_1: crate::models::MyType, + field_2: std::collections::HashMap, + field_3: Option, + field_4: flareon::db::LimitedString<{ crate::constants::MY_CONSTANT }>, + } + }; + assert_eq!(actual.into_token_stream().to_string(), expected.to_string()); + } +} diff --git a/flareon-macros/src/model.rs b/flareon-macros/src/model.rs index fc0ebc9..ce32829 100644 --- a/flareon-macros/src/model.rs +++ b/flareon-macros/src/model.rs @@ -27,7 +27,7 @@ pub(super) fn impl_model_for_struct( } }; - let model = match opts.as_model(&args) { + let model = match opts.as_model(&args, None) { Ok(val) => val, Err(err) => { return err.to_compile_error(); @@ -71,9 +71,11 @@ fn remove_helper_field_attributes(fields: &mut syn::Fields) -> &Punctuated, fields_as_from_db: Vec, + fields_as_update_from_db: Vec, fields_as_get_values: Vec, fields_as_field_refs: Vec, } @@ -91,9 +93,11 @@ impl ModelBuilder { let mut model_builder = Self { name: model.name.clone(), table_name: model.table_name, + pk_field: model.pk_field.clone(), fields_struct_name: format_ident!("{}Fields", model.name), fields_as_columns: Vec::with_capacity(field_count), fields_as_from_db: Vec::with_capacity(field_count), + fields_as_update_from_db: Vec::with_capacity(field_count), fields_as_get_values: Vec::with_capacity(field_count), fields_as_field_refs: Vec::with_capacity(field_count), }; @@ -113,18 +117,9 @@ impl ModelBuilder { let column_name = &field.column_name; { - let mut field_as_column = quote!(#orm_ident::Column::new( + let field_as_column = quote!(#orm_ident::Column::new( #orm_ident::Identifier::new(#column_name) )); - if field.auto_value { - field_as_column.append_all(quote!(.auto())); - } - if field.null { - field_as_column.append_all(quote!(.null())); - } - if field.unique { - field_as_column.append_all(quote!(.unique())); - } self.fields_as_columns.push(field_as_column); } @@ -132,8 +127,12 @@ impl ModelBuilder { #name: db_row.get::<#ty>(#index)? )); + self.fields_as_update_from_db.push(quote!( + #index => { self.#name = db_row.get::<#ty>(row_field_id)?; } + )); + self.fields_as_get_values.push(quote!( - #index => &self.#name as &dyn #orm_ident::ToDbValue + #index => &self.#name as &dyn #orm_ident::ToDbFieldValue )); self.fields_as_field_refs.push(quote!( @@ -144,24 +143,40 @@ impl ModelBuilder { #[must_use] fn build_model_impl(&self) -> TokenStream { + let crate_ident = flareon_ident(); let orm_ident = orm_ident(); let name = &self.name; let table_name = &self.table_name; let fields_struct_name = &self.fields_struct_name; let fields_as_columns = &self.fields_as_columns; + let pk_field_name = &self.pk_field.field_name; + let pk_column_name = &self.pk_field.column_name; + let pk_type = &self.pk_field.ty; let fields_as_from_db = &self.fields_as_from_db; + let fields_as_update_from_db = &self.fields_as_update_from_db; let fields_as_get_values = &self.fields_as_get_values; quote! { + #[#crate_ident::__private::async_trait] #[automatically_derived] impl #orm_ident::Model for #name { type Fields = #fields_struct_name; + type PrimaryKey = #pk_type; const COLUMNS: &'static [#orm_ident::Column] = &[ #(#fields_as_columns,)* ]; const TABLE_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#table_name); + const PRIMARY_KEY_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#pk_column_name); + + fn primary_key(&self) -> &Self::PrimaryKey { + &self.#pk_field_name + } + + fn set_primary_key(&mut self, primary_key: Self::PrimaryKey) { + self.#pk_field_name = primary_key; + } fn from_db(db_row: #orm_ident::Row) -> #orm_ident::Result { Ok(Self { @@ -169,7 +184,18 @@ impl ModelBuilder { }) } - fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ToDbValue> { + fn update_from_db(&mut self, db_row: #orm_ident::Row, columns: &[usize]) -> #orm_ident::Result<()> { + for (row_field_id, column_id) in columns.into_iter().enumerate() { + match *column_id { + #(#fields_as_update_from_db,)* + _ => panic!("Unknown column index: {}", column_id), + } + } + + Ok(()) + } + + fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ToDbFieldValue> { columns .iter() .map(|&column| match column { @@ -178,6 +204,15 @@ impl ModelBuilder { }) .collect() } + + async fn get_by_primary_key( + db: &DB, + pk: Self::PrimaryKey, + ) -> #orm_ident::Result> { + #orm_ident::query!(Self, $#pk_field_name == pk) + .get(db) + .await + } } } } diff --git a/flareon-macros/src/query.rs b/flareon-macros/src/query.rs index 32a3f1b..5e0b455 100644 --- a/flareon-macros/src/query.rs +++ b/flareon-macros/src/query.rs @@ -40,23 +40,33 @@ pub(super) fn expr_to_tokens(model_name: &syn::Type, expr: Expr) -> TokenStream let crate_name = flareon_ident(); match expr { - Expr::FieldRef(name) => { - quote!(<#model_name as #crate_name::db::Model>::Fields::#name.as_expr()) + Expr::FieldRef { field_name, .. } => { + quote!(<#model_name as #crate_name::db::Model>::Fields::#field_name.as_expr()) } Expr::Value(value) => { quote!(#crate_name::db::query::Expr::value(#value)) } - Expr::MethodCall { - called_on, - method_name, - args, - } => match *called_on { - Expr::Value(syn_expr) => { - quote!(#crate_name::db::query::Expr::value(#syn_expr.#method_name(#(#args),*))) + Expr::MemberAccess { + parent, + member_name, + .. + } => match parent.as_tokens() { + Some(tokens) => { + quote!(#crate_name::db::query::Expr::value(#tokens.#member_name)) } - _ => syn::Error::new( - method_name.span(), - "only method calls on values are supported", + None => syn::Error::new_spanned( + parent.as_tokens_full(), + "accessing members of values that reference database fields is unsupported", + ) + .to_compile_error(), + }, + Expr::FunctionCall { function, args } => match function.as_tokens() { + Some(tokens) => { + quote!(#crate_name::db::query::Expr::value(#tokens(#(#args),*))) + } + None => syn::Error::new_spanned( + function.as_tokens_full(), + "calling functions that reference database fields is unsupported", ) .to_compile_error(), }, @@ -90,9 +100,9 @@ fn handle_binary_comparison( let bin_fn = format_ident!("{}", bin_fn); let bin_trait = format_ident!("{}", bin_trait); - if let Expr::FieldRef(ref field) = lhs { + if let Expr::FieldRef { ref field_name, .. } = lhs { if let Some(rhs_tokens) = rhs.as_tokens() { - return quote!(#crate_name::db::query::#bin_trait::#bin_fn(<#model_name as #crate_name::db::Model>::Fields::#field, #rhs_tokens)); + return quote!(#crate_name::db::query::#bin_trait::#bin_fn(<#model_name as #crate_name::db::Model>::Fields::#field_name, #rhs_tokens)); } } diff --git a/flareon-macros/tests/compile_tests.rs b/flareon-macros/tests/compile_tests.rs index e51a809..782b47a 100644 --- a/flareon-macros/tests/compile_tests.rs +++ b/flareon-macros/tests/compile_tests.rs @@ -16,6 +16,8 @@ fn attr_model() { t.compile_fail("tests/ui/attr_model_tuple.rs"); t.compile_fail("tests/ui/attr_model_enum.rs"); t.compile_fail("tests/ui/attr_model_generic.rs"); + t.compile_fail("tests/ui/attr_model_no_pk.rs"); + t.compile_fail("tests/ui/attr_model_multiple_pks.rs"); } #[rustversion::attr(not(nightly), ignore)] @@ -28,6 +30,7 @@ fn func_query() { t.compile_fail("tests/ui/func_query_starting_op.rs"); t.compile_fail("tests/ui/func_query_double_field.rs"); t.compile_fail("tests/ui/func_query_invalid_field.rs"); + t.compile_fail("tests/ui/func_query_method_call_on_db_field.rs"); } #[rustversion::attr(not(nightly), ignore)] diff --git a/flareon-macros/tests/ui/attr_model_multiple_pks.rs b/flareon-macros/tests/ui/attr_model_multiple_pks.rs new file mode 100644 index 0000000..614ca7f --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_multiple_pks.rs @@ -0,0 +1,11 @@ +use flareon::db::model; + +#[model] +struct MyModel { + id: i64, + #[model(primary_key)] + id_2: i64, + name: String, +} + +fn main() {} diff --git a/flareon-macros/tests/ui/attr_model_multiple_pks.stderr b/flareon-macros/tests/ui/attr_model_multiple_pks.stderr new file mode 100644 index 0000000..c21fb84 --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_multiple_pks.stderr @@ -0,0 +1,5 @@ +error: composite primary keys are not supported; only one primary key field is allowed + --> tests/ui/attr_model_multiple_pks.rs:7:5 + | +7 | id_2: i64, + | ^^^^ diff --git a/flareon-macros/tests/ui/attr_model_no_pk.rs b/flareon-macros/tests/ui/attr_model_no_pk.rs new file mode 100644 index 0000000..4c8114d --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_no_pk.rs @@ -0,0 +1,8 @@ +use flareon::db::model; + +#[model] +struct MyModel { + name: std::string::String, +} + +fn main() {} diff --git a/flareon-macros/tests/ui/attr_model_no_pk.stderr b/flareon-macros/tests/ui/attr_model_no_pk.stderr new file mode 100644 index 0000000..528251a --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_no_pk.stderr @@ -0,0 +1,5 @@ +error: models must have a primary key field, either named `id` or annotated with the `#[model(primary_key)]` attribute + --> tests/ui/attr_model_no_pk.rs:4:8 + | +4 | struct MyModel { + | ^^^^^^^ diff --git a/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs b/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs new file mode 100644 index 0000000..51dc7a7 --- /dev/null +++ b/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs @@ -0,0 +1,14 @@ +use flareon::db::{model, query}; + +#[derive(Debug)] +#[model] +struct MyModel { + id: i32, + name: std::string::String, + description: String, + visits: i32, +} + +fn main() { + query!(MyModel, $name.len); +} diff --git a/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr b/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr new file mode 100644 index 0000000..c784f04 --- /dev/null +++ b/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr @@ -0,0 +1,5 @@ +error: accessing members of values that reference database fields is unsupported + --> tests/ui/func_query_method_call_on_db_field.rs:13:21 + | +13 | query!(MyModel, $name.len); + | ^^^^^ diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index 777fa70..33ae87c 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -46,6 +46,7 @@ time.workspace = true tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tower = { workspace = true, features = ["util"] } tower-sessions = { workspace = true, features = ["memory-store"] } +env_logger = "0.11.5" [dev-dependencies] async-stream.workspace = true diff --git a/flareon/src/auth.rs b/flareon/src/auth.rs index 28efb14..6fe88d7 100644 --- a/flareon/src/auth.rs +++ b/flareon/src/auth.rs @@ -23,6 +23,7 @@ use subtle::ConstantTimeEq; use thiserror::Error; use crate::config::SecretKey; +use crate::db::DbValue; #[cfg(feature = "db")] use crate::db::{ColumnType, DatabaseField, FromDbValue, SqlxValueRef, ToDbValue}; use crate::request::{Request, RequestExt}; @@ -433,7 +434,7 @@ impl FromDbValue for PasswordHash { #[cfg(feature = "db")] impl ToDbValue for PasswordHash { - fn to_sea_query_value(&self) -> sea_query::Value { + fn to_db_value(&self) -> DbValue { self.0.clone().into() } } diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 7e2c835..e89c3d1 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -12,6 +12,7 @@ pub mod impl_postgres; pub mod impl_sqlite; pub mod migrations; pub mod query; +mod relations; mod sea_query_db; use std::fmt::Write; @@ -24,7 +25,8 @@ use log::debug; #[cfg(test)] use mockall::automock; use query::Query; -use sea_query::{Iden, SchemaStatementBuilder, SimpleExpr}; +pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy}; +use sea_query::{Iden, IntoColumnRef, ReturningClause, SchemaStatementBuilder, SimpleExpr}; use sea_query_binder::{SqlxBinder, SqlxValues}; use sqlx::{Type, TypeInfo}; use thiserror::Error; @@ -57,6 +59,13 @@ pub enum DatabaseError { /// Error when applying migrations. #[error("Error when applying migrations: {0}")] MigrationError(#[from] migrations::MigrationEngineError), + /// Foreign Key could not be retrieved from the database because the record + /// was not found. + #[error("Error retrieving a Foreign Key from the database: record not found")] + ForeignKeyNotFound, + /// Primary key could not be converted from i64 using [`TryFromI64`] trait. + #[error("Primary key could not be converted from i64")] + PrimaryKeyFromI64Error, } impl DatabaseError { @@ -100,9 +109,15 @@ pub trait Model: Sized + Send + 'static { /// Rust. type Fields; + /// The primary key type of the model. + type PrimaryKey: PrimaryKey; + /// The name of the table in the database. const TABLE_NAME: Identifier; + /// The name of the primary key column in the database. + const PRIMARY_KEY_NAME: Identifier; + /// The columns of the model. const COLUMNS: &'static [Column]; @@ -114,8 +129,17 @@ pub trait Model: Sized + Send + 'static { /// with the model. fn from_db(db_row: Row) -> Result; + fn update_from_db(&mut self, db_row: Row, columns: &[usize]) -> Result<()>; + + /// Returns the primary key of the model. + fn primary_key(&self) -> &Self::PrimaryKey; + + /// Used by the ORM to set the primary key of the model after it has been + /// saved to the database. + fn set_primary_key(&mut self, primary_key: Self::PrimaryKey); + /// Gets the values of the model for the given columns. - fn get_values(&self, columns: &[usize]) -> Vec<&dyn ToDbValue>; + fn get_values(&self, columns: &[usize]) -> Vec<&dyn ToDbFieldValue>; /// Returns a query for all objects of this model. #[must_use] @@ -123,6 +147,11 @@ pub trait Model: Sized + Send + 'static { Query::new() } + async fn get_by_primary_key( + db: &DB, + pk: Self::PrimaryKey, + ) -> Result>; + /// Saves the model to the database. /// /// # Errors @@ -175,45 +204,18 @@ impl Iden for &Identifier { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Column { name: Identifier, - auto_value: bool, - unique: bool, - null: bool, } impl Column { /// Creates a new column with the given name. #[must_use] pub const fn new(name: Identifier) -> Self { - Self { - name, - auto_value: false, - unique: false, - null: false, - } - } - - /// Marks the column as auto-increment. - #[must_use] - pub const fn auto(mut self) -> Self { - self.auto_value = true; - self - } - - /// Marks the column unique. - #[must_use] - pub const fn unique(mut self) -> Self { - self.unique = true; - self - } - - /// Marks the column as nullable. - #[must_use] - pub const fn null(mut self) -> Self { - self.null = true; - self + Self { name } } } +pub trait PrimaryKey: DatabaseField + Clone {} + /// A row structure that holds the data of a single row retrieved from the /// database. #[non_exhaustive] @@ -259,7 +261,7 @@ impl Row { } /// A trait denoting that some type can be used as a field in a database. -pub trait DatabaseField: FromDbValue + ToDbValue { +pub trait DatabaseField: FromDbValue + ToDbFieldValue { const NULLABLE: bool = false; /// The type of the column in the database as one of the variants of @@ -318,18 +320,70 @@ pub trait FromDbValue { Self: Sized; } +pub type DbValue = sea_query::Value; + /// A trait for converting a Rust value to a database value. pub trait ToDbValue: Send + Sync { /// Converts the Rust value to a `sea_query` value. /// /// This method is used to convert the Rust value to a value that can be /// used in a query. - fn to_sea_query_value(&self) -> sea_query::Value; + fn to_db_value(&self) -> DbValue; +} + +pub trait ToDbFieldValue { + fn to_db_field_value(&self) -> DbFieldValue; +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DbFieldValue { + /// The value should be automatically generated by the database and not + /// included in the query. + Auto, + /// A value that should be included in the query. + Value(DbValue), +} + +impl DbFieldValue { + #[must_use] + pub fn is_auto(&self) -> bool { + matches!(self, Self::Auto) + } + + #[must_use] + pub fn is_value(&self) -> bool { + matches!(self, Self::Value(_)) + } + + #[must_use] + pub fn unwrap_value(self) -> sea_query::Value { + self.expect_value("called DbValue::unwrap_value() on a wrong DbValue variant") + } + + #[must_use] + pub fn expect_value(self, message: &str) -> sea_query::Value { + match self { + Self::Value(value) => value, + _ => panic!("{message}"), + } + } +} + +impl ToDbFieldValue for T { + fn to_db_field_value(&self) -> DbFieldValue { + DbFieldValue::Value(self.to_db_value()) + } +} + +impl> From for DbFieldValue { + fn from(value: T) -> Self { + Self::Value(value.into()) + } } impl ToDbValue for &T { - fn to_sea_query_value(&self) -> sea_query::Value { - (*self).to_sea_query_value() + fn to_db_value(&self) -> DbValue { + (*self).to_db_value() } } @@ -483,40 +537,75 @@ impl Database { /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. pub async fn insert(&self, data: &mut T) -> Result<()> { - let non_auto_column_identifiers = T::COLUMNS + let column_identifiers = T::COLUMNS .iter() - .filter_map(|column| { - if column.auto_value { - None - } else { - Some(Identifier::from(column.name.as_str())) - } - }) - .collect::>(); - let value_indices = T::COLUMNS + .map(|column| Identifier::from(column.name.as_str())); + let value_indices: Vec<_> = T::COLUMNS .iter() .enumerate() - .filter_map(|(i, column)| if column.auto_value { None } else { Some(i) }) - .collect::>(); - let values = data.get_values(&value_indices); + .map(|(i, _column)| i) + .collect(); + let values = data + .get_values(&value_indices) + .into_iter() + .map(ToDbFieldValue::to_db_field_value); + + let mut auto_col_ids = Vec::new(); + let mut auto_col_identifiers = Vec::new(); + let mut value_identifiers = Vec::new(); + let mut filtered_values = Vec::new(); + std::iter::zip(std::iter::zip(value_indices, column_identifiers), values).for_each( + |((index, identifier), value)| match value { + DbFieldValue::Auto => { + auto_col_ids.push(index); + auto_col_identifiers.push(identifier.into_column_ref()); + } + DbFieldValue::Value(value) => { + value_identifiers.push(identifier); + filtered_values.push(value); + } + }, + ); - let insert_statement = sea_query::Query::insert() + let mut insert_statement = sea_query::Query::insert() .into_table(T::TABLE_NAME) - .columns(non_auto_column_identifiers) + .columns(value_identifiers) .values( - values + filtered_values .into_iter() - .map(|value| SimpleExpr::Value(value.to_sea_query_value())) + .map(|value| SimpleExpr::Value(value)) .collect::>(), )? + .or_default_values() .to_owned(); - let statement_result = self.execute_statement(&insert_statement).await?; + if !auto_col_ids.is_empty() { + let row = if self.supports_returning() { + insert_statement.returning(ReturningClause::Columns(auto_col_identifiers)); + + self.fetch_option(&insert_statement) + .await? + .expect("query should return the primary key") + } else { + let result = self.execute_statement(&insert_statement).await?; + let row_id = result + .last_inserted_row_id + .expect("expected last inserted row ID if RETURNING clause is not supported"); + let query = sea_query::Query::select() + .from(T::TABLE_NAME) + .columns(auto_col_identifiers) + .and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).eq(row_id)) + .to_owned(); + self.fetch_option(&query).await?.expect( + "expected a row returned from a SELECT if RETURNING clause is not supported", + ) + }; + data.update_from_db(row, &auto_col_ids)?; + } else { + self.execute_statement(&insert_statement).await?; + } - debug!( - "Inserted row; rows affected: {}", - statement_result.rows_affected() - ); + debug!("Inserted row"); Ok(()) } @@ -625,7 +714,7 @@ impl Database { ) -> Result { let values = values .iter() - .map(ToDbValue::to_sea_query_value) + .map(ToDbValue::to_db_value) .collect::>(); let values = SqlxValues(sea_query::Values(values)); @@ -659,6 +748,14 @@ impl Database { Ok(result) } + fn supports_returning(&self) -> bool { + match self.inner { + DatabaseImpl::Sqlite(_) => true, + DatabaseImpl::Postgres(_) => true, + DatabaseImpl::MySql(_) => false, + } + } + async fn fetch_all(&self, statement: &T) -> Result> where T: SqlxBinder, @@ -777,14 +874,30 @@ impl DatabaseBackend for Database { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct StatementResult { rows_affected: RowsNum, + last_inserted_row_id: Option, } impl StatementResult { /// Creates a new statement result with the given number of rows affected. - #[cfg(test)] #[must_use] pub(crate) fn new(rows_affected: RowsNum) -> Self { - Self { rows_affected } + Self { + rows_affected, + last_inserted_row_id: None, + } + } + + /// Creates a new statement result with the given number of rows affected + /// and last inserted row ID. + #[must_use] + pub(crate) fn new_with_last_inserted_row_id( + rows_affected: RowsNum, + last_inserted_row_id: u64, + ) -> Self { + Self { + rows_affected, + last_inserted_row_id: Some(last_inserted_row_id), + } } /// Returns the number of rows affected by the query. @@ -792,12 +905,68 @@ impl StatementResult { pub fn rows_affected(&self) -> RowsNum { self.rows_affected } + + /// Returns the ID of the last inserted row. + #[must_use] + pub fn last_inserted_row_id(&self) -> Option { + self.last_inserted_row_id + } } /// A structure that holds the number of rows. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Display)] pub struct RowsNum(pub u64); +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Auto { + Fixed(T), + Auto, +} + +impl Auto { + #[must_use] + pub const fn auto() -> Self { + Self::Auto + } + + #[must_use] + pub const fn fixed(value: T) -> Self { + Self::Fixed(value) + } +} + +impl Default for Auto { + fn default() -> Self { + Self::Auto + } +} + +impl From for Auto { + fn from(value: T) -> Self { + Self::fixed(value) + } +} + +trait TryFromI64 { + fn try_from_i64(value: i64) -> Result + where + Self: Sized; +} + +impl TryFromI64 for i64 { + fn try_from_i64(value: i64) -> Result { + Ok(value) + } +} + +impl TryFromI64 for i32 { + fn try_from_i64(value: i64) -> Result { + value + .try_into() + .map_err(|_| DatabaseError::PrimaryKeyFromI64Error) + } +} + /// A wrapper over a string that has a limited length. /// /// This type is used to represent a string that has a limited length in the @@ -935,14 +1104,6 @@ mod tests { fn column() { let column = Column::new(Identifier::new("test")); assert_eq!(column.name.as_str(), "test"); - assert!(!column.auto_value); - assert!(!column.null); - - let column_auto = column.auto(); - assert!(column_auto.auto_value); - - let column_null = column.null(); - assert!(column_null.null); } #[test] diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs index 47abc69..be50409 100644 --- a/flareon/src/db/fields.rs +++ b/flareon/src/db/fields.rs @@ -1,5 +1,7 @@ +//! `DatabaseField` implementations for common types. + use flareon::db::DatabaseField; -use sea_query::Value; +use log::debug; #[cfg(feature = "mysql")] use crate::db::impl_mysql::MySqlValueRef; @@ -8,7 +10,8 @@ use crate::db::impl_postgres::PostgresValueRef; #[cfg(feature = "sqlite")] use crate::db::impl_sqlite::SqliteValueRef; use crate::db::{ - ColumnType, DatabaseError, FromDbValue, LimitedString, Result, SqlxValueRef, ToDbValue, + Auto, ColumnType, DatabaseError, DbFieldValue, DbValue, ForeignKey, FromDbValue, LimitedString, + Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue, }; macro_rules! impl_from_sqlite_default { @@ -41,13 +44,13 @@ macro_rules! impl_from_mysql_default { macro_rules! impl_to_db_value_default { ($ty:ty) => { impl ToDbValue for $ty { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().into() } } impl ToDbValue for Option<$ty> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().into() } } @@ -136,7 +139,7 @@ impl_db_field!(String, Text); impl_db_field!(Vec, Blob); impl ToDbValue for &str { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { (*self).to_string().into() } } @@ -171,14 +174,14 @@ impl FromDbValue for Option> { impl_to_db_value_default!(chrono::DateTime); impl ToDbValue for Option<&str> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.map(ToString::to_string).into() } } impl DatabaseField for Option where - Option: ToDbValue + FromDbValue, + Option: ToDbFieldValue + FromDbValue, { const NULLABLE: bool = true; const TYPE: ColumnType = T::TYPE; @@ -209,13 +212,154 @@ impl FromDbValue for LimitedString { } impl ToDbValue for LimitedString { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.0.clone().into() } } impl ToDbValue for Option> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().map(|s| s.0).into() } } + +impl DatabaseField for ForeignKey { + const NULLABLE: bool = T::PrimaryKey::NULLABLE; + const TYPE: ColumnType = T::PrimaryKey::TYPE; +} + +impl FromDbValue for ForeignKey { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef) -> Result { + T::PrimaryKey::from_sqlite(value).map(ForeignKey::PrimaryKey) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + T::PrimaryKey::from_postgres(value).map(ForeignKey::PrimaryKey) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + T::PrimaryKey::from_mysql(value).map(ForeignKey::PrimaryKey) + } +} + +impl ToDbFieldValue for ForeignKey { + fn to_db_field_value(&self) -> DbFieldValue { + self.primary_key().to_db_field_value() + } +} + +impl FromDbValue for Option> +where + Option: FromDbValue, +{ + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef) -> Result { + Ok(>::from_sqlite(value)?.map(ForeignKey::PrimaryKey)) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + Ok(>::from_postgres(value)?.map(ForeignKey::PrimaryKey)) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + Ok(>::from_mysql(value)?.map(ForeignKey::PrimaryKey)) + } +} + +impl ToDbFieldValue for Option> +where + Option: ToDbFieldValue, +{ + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Some(foreign_key) => foreign_key.to_db_field_value(), + None => >::None.to_db_field_value(), + } + } +} + +impl DatabaseField for Auto { + const NULLABLE: bool = T::NULLABLE; + const TYPE: ColumnType = T::TYPE; +} + +impl FromDbValue for Auto { + fn from_sqlite(value: SqliteValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_sqlite(value)?)) + } + + fn from_postgres(value: PostgresValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_postgres(value)?)) + } + + fn from_mysql(value: MySqlValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_mysql(value)?)) + } +} + +impl ToDbFieldValue for Auto { + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Self::Fixed(value) => value.to_db_field_value(), + Self::Auto => DbFieldValue::Auto, + } + } +} + +impl FromDbValue for Option> +where + Option: FromDbValue, +{ + fn from_sqlite(value: SqliteValueRef) -> Result + where + Self: Sized, + { + >::from_sqlite(value).map(|value| value.map(Auto::fixed)) + } + + fn from_postgres(value: PostgresValueRef) -> Result + where + Self: Sized, + { + >::from_postgres(value).map(|value| value.map(Auto::fixed)) + } + + fn from_mysql(value: MySqlValueRef) -> Result + where + Self: Sized, + { + >::from_mysql(value).map(|value| value.map(Auto::fixed)) + } +} + +impl ToDbFieldValue for Option> +where + Option: ToDbFieldValue, +{ + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Some(auto) => auto.to_db_field_value(), + None => >::None.to_db_field_value(), + } + } +} + +impl PrimaryKey for Auto {} + +impl PrimaryKey for i32 {} + +impl PrimaryKey for i64 {} diff --git a/flareon/src/db/impl_mysql.rs b/flareon/src/db/impl_mysql.rs index 2314104..d63aef9 100644 --- a/flareon/src/db/impl_mysql.rs +++ b/flareon/src/db/impl_mysql.rs @@ -4,10 +4,18 @@ use crate::db::ColumnType; impl_sea_query_db_backend!(DatabaseMySql: sqlx::mysql::MySql, sqlx::mysql::MySqlPool, MySqlRow, MySqlValueRef, sea_query::MysqlQueryBuilder); impl DatabaseMySql { + async fn init(&self) -> crate::db::Result<()> { + Ok(()) + } + fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { // No changes are needed for MySQL } + fn last_inserted_row_id_for(result: &sqlx::mysql::MySqlQueryResult) -> Option { + Some(result.last_insert_id()) + } + pub(super) fn sea_query_column_type_for( &self, column_type: ColumnType, diff --git a/flareon/src/db/impl_postgres.rs b/flareon/src/db/impl_postgres.rs index 5ade464..eca4550 100644 --- a/flareon/src/db/impl_postgres.rs +++ b/flareon/src/db/impl_postgres.rs @@ -3,6 +3,10 @@ use crate::db::sea_query_db::impl_sea_query_db_backend; impl_sea_query_db_backend!(DatabasePostgres: sqlx::postgres::Postgres, sqlx::postgres::PgPool, PostgresRow, PostgresValueRef, sea_query::PostgresQueryBuilder); impl DatabasePostgres { + async fn init(&self) -> crate::db::Result<()> { + Ok(()) + } + fn prepare_values(values: &mut sea_query_binder::SqlxValues) { for value in &mut values.0 .0 { Self::tinyint_to_smallint(value); @@ -34,6 +38,10 @@ impl DatabasePostgres { } } + fn last_inserted_row_id_for(result: &sqlx::postgres::PgQueryResult) -> Option { + None + } + pub(super) fn sea_query_column_type_for( &self, column_type: crate::db::ColumnType, diff --git a/flareon/src/db/impl_sqlite.rs b/flareon/src/db/impl_sqlite.rs index 5f228b1..705f517 100644 --- a/flareon/src/db/impl_sqlite.rs +++ b/flareon/src/db/impl_sqlite.rs @@ -1,12 +1,29 @@ +use sea_query_binder::SqlxValues; +use sqlx::Executor; + use crate::db::sea_query_db::impl_sea_query_db_backend; impl_sea_query_db_backend!(DatabaseSqlite: sqlx::sqlite::Sqlite, sqlx::sqlite::SqlitePool, SqliteRow, SqliteValueRef, sea_query::SqliteQueryBuilder); impl DatabaseSqlite { - fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { + async fn init(&self) -> crate::db::Result<()> { + self.raw("PRAGMA foreign_keys = ON").await?; + Ok(()) + } + + async fn raw(&self, sql: &str) -> crate::db::Result { + self.raw_with(sql, SqlxValues(sea_query::Values(Vec::new()))) + .await + } + + fn prepare_values(_values: &mut SqlxValues) { // No changes are needed for SQLite } + fn last_inserted_row_id_for(result: &sqlx::sqlite::SqliteQueryResult) -> Option { + Some(result.last_insert_rowid() as u64) + } + pub(super) fn sea_query_column_type_for( &self, column_type: crate::db::ColumnType, diff --git a/flareon/src/db/migrations.rs b/flareon/src/db/migrations.rs index 62c13a4..f17dba8 100644 --- a/flareon/src/db/migrations.rs +++ b/flareon/src/db/migrations.rs @@ -3,13 +3,14 @@ mod sorter; use std::fmt; use std::fmt::{Debug, Formatter}; -use flareon_macros::{model, query}; +use flareon::db::relations::ForeignKeyOnUpdatePolicy; use log::info; use sea_query::{ColumnDef, StringLen}; use thiserror::Error; use crate::db::migrations::sorter::{MigrationSorter, MigrationSorterError}; -use crate::db::{ColumnType, Database, DatabaseField, Identifier, Result}; +use crate::db::relations::ForeignKeyOnDeletePolicy; +use crate::db::{model, query, ColumnType, Database, DatabaseField, Identifier, Model, Result}; #[derive(Debug, Clone, Error)] #[non_exhaustive] @@ -244,6 +245,17 @@ impl Operation { let mut query = sea_query::Table::create().table(*table_name).to_owned(); for field in *fields { query.col(field.as_column_def(database)); + if let Some(foreign_key) = field.foreign_key { + query.foreign_key( + sea_query::ForeignKeyCreateStatement::new() + .from_tbl(*table_name) + .from_col(field.name) + .to_tbl(foreign_key.model) + .to_col(foreign_key.field) + .on_delete(foreign_key.on_delete.into()) + .on_update(foreign_key.on_update.into()), + ); + } } if *if_not_exists { query.if_not_exists(); @@ -345,6 +357,7 @@ pub struct Field { pub null: bool, /// Whether the column has a unique constraint pub unique: bool, + foreign_key: Option, } impl Field { @@ -357,9 +370,36 @@ impl Field { auto_value: false, null: false, unique: false, + foreign_key: None, } } + #[must_use] + pub const fn foreign_key( + mut self, + to_model: Identifier, + to_field: Identifier, + on_delete: ForeignKeyOnDeletePolicy, + on_update: ForeignKeyOnUpdatePolicy, + ) -> Self { + assert!( + self.null || !matches!(on_delete, ForeignKeyOnDeletePolicy::SetNone), + "`ForeignKey` must be inside `Option` if `on_delete` is set to `SetNone`" + ); + assert!( + self.null || !matches!(on_update, ForeignKeyOnUpdatePolicy::SetNone), + "`ForeignKey` must be inside `Option` if `on_update` is set to `SetNone`" + ); + + self.foreign_key = Some(ForeignKeyReference { + model: to_model, + field: to_field, + on_delete, + on_update, + }); + self + } + #[must_use] pub const fn primary_key(mut self) -> Self { self.primary_key = true; @@ -411,6 +451,14 @@ impl Field { } } +#[derive(Debug, Copy, Clone)] +struct ForeignKeyReference { + model: Identifier, + field: Identifier, + on_delete: ForeignKeyOnDeletePolicy, + on_update: ForeignKeyOnUpdatePolicy, +} + #[cfg_attr(test, mockall::automock)] pub(super) trait ColumnTypeMapper { fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType; diff --git a/flareon/src/db/query.rs b/flareon/src/db/query.rs index 72ec7b1..2f097a7 100644 --- a/flareon/src/db/query.rs +++ b/flareon/src/db/query.rs @@ -4,7 +4,10 @@ use derive_more::Debug; use sea_query::IntoColumnRef; use crate::db; -use crate::db::{DatabaseBackend, FromDbValue, Identifier, Model, StatementResult, ToDbValue}; +use crate::db::{ + Auto, DatabaseBackend, DbFieldValue, DbValue, ForeignKey, FromDbValue, Identifier, Model, + StatementResult, ToDbFieldValue, +}; /// A query that can be executed on a database. Can be used to filter, update, /// or delete rows. @@ -131,7 +134,7 @@ impl Query { #[derive(Debug)] pub enum Expr { Field(Identifier), - Value(#[debug("{}", _0.to_sea_query_value())] Box), + Value(DbValue), And(Box, Box), Or(Box, Box), Eq(Box, Box), @@ -169,8 +172,11 @@ impl Expr { /// let expr = Expr::value(30); /// ``` #[must_use] - pub fn value(value: T) -> Self { - Self::Value(Box::new(value)) + pub fn value(value: T) -> Self { + match value.to_db_field_value() { + DbFieldValue::Value(value) => Self::Value(value), + _ => panic!("Cannot create query with a non-value field"), + } } /// Create a new `AND` expression. @@ -299,7 +305,7 @@ impl Expr { pub fn as_sea_query_expr(&self) -> sea_query::SimpleExpr { match self { Self::Field(identifier) => (*identifier).into_column_ref().into(), - Self::Value(value) => value.to_sea_query_value().into(), + Self::Value(value) => (*value).clone().into(), Self::And(lhs, rhs) => lhs.as_sea_query_expr().and(rhs.as_sea_query_expr()), Self::Or(lhs, rhs) => lhs.as_sea_query_expr().or(rhs.as_sea_query_expr()), Self::Eq(lhs, rhs) => lhs.as_sea_query_expr().eq(rhs.as_sea_query_expr()), @@ -323,7 +329,7 @@ pub struct FieldRef { phantom_data: PhantomData, } -impl FieldRef { +impl FieldRef { /// Create a new field reference. #[must_use] pub const fn new(identifier: Identifier) -> Self { @@ -344,18 +350,18 @@ impl FieldRef { /// A trait for types that can be compared in database expressions. pub trait ExprEq { - fn eq>(self, other: V) -> Expr; + fn eq>(self, other: V) -> Expr; - fn ne>(self, other: V) -> Expr; + fn ne>(self, other: V) -> Expr; } -impl ExprEq for FieldRef { - fn eq>(self, other: V) -> Expr { - Expr::eq(self.as_expr(), Expr::value(other.into())) +impl ExprEq for FieldRef { + fn eq>(self, other: V) -> Expr { + Expr::eq(self.as_expr(), Expr::value(other.into_field())) } - fn ne>(self, other: V) -> Expr { - Expr::ne(self.as_expr(), Expr::value(other.into())) + fn ne>(self, other: V) -> Expr { + Expr::ne(self.as_expr(), Expr::value(other.into_field())) } } @@ -409,6 +415,40 @@ impl_num_expr!(u64); impl_num_expr!(f32); impl_num_expr!(f64); +trait IntoField { + fn into_field(self) -> T; +} + +impl IntoField for T { + fn into_field(self) -> T { + self + } +} + +impl IntoField> for T { + fn into_field(self) -> Auto { + Auto::fixed(self) + } +} + +impl IntoField for &str { + fn into_field(self) -> String { + self.to_string() + } +} + +impl IntoField> for T { + fn into_field(self) -> ForeignKey { + ForeignKey::from(self) + } +} + +impl IntoField> for &T { + fn into_field(self) -> ForeignKey { + ForeignKey::from(self) + } +} + #[cfg(test)] mod tests { use flareon_macros::model; @@ -505,7 +545,7 @@ mod tests { fn test_expr_value() { let expr = Expr::value(30); if let Expr::Value(value) = expr { - assert_eq!(value.to_sea_query_value().to_string(), "30"); + assert_eq!(value.to_string(), "30"); } else { panic!("Expected Expr::Value"); } diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs new file mode 100644 index 0000000..2dd9a67 --- /dev/null +++ b/flareon/src/db/relations.rs @@ -0,0 +1,109 @@ +use flareon::db::DatabaseError; + +use crate::db::{DatabaseBackend, Model, Result}; + +#[derive(Debug, Clone)] +pub enum ForeignKey { + PrimaryKey(T::PrimaryKey), + Model(Box), +} + +impl ForeignKey { + pub fn primary_key(&self) -> &T::PrimaryKey { + match self { + Self::PrimaryKey(pk) => pk, + Self::Model(model) => model.primary_key(), + } + } + + pub fn model(&self) -> Option<&T> { + match self { + Self::Model(model) => Some(model), + _ => None, + } + } + + pub fn unwrap(self) -> T { + match self { + Self::Model(model) => *model, + _ => panic!("object has not been retrieved from the database"), + } + } + + /// Retrieve the model from the database, if needed, and return it. + pub async fn get(&mut self, db: &DB) -> Result<&T> { + match self { + Self::Model(model) => Ok(model), + Self::PrimaryKey(pk) => { + let model = T::get_by_primary_key(db, pk.clone()) + .await? + .ok_or(DatabaseError::ForeignKeyNotFound)?; + *self = Self::Model(Box::new(model)); + Ok(self.model().expect("model was just set")) + } + } + } +} + +impl PartialEq for ForeignKey +where + T::PrimaryKey: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.primary_key() == other.primary_key() + } +} + +impl Eq for ForeignKey where T::PrimaryKey: Eq {} + +impl From for ForeignKey { + fn from(model: T) -> Self { + Self::Model(Box::new(model)) + } +} + +impl From<&T> for ForeignKey { + fn from(model: &T) -> Self { + Self::PrimaryKey(model.primary_key().clone()) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +pub enum ForeignKeyOnDeletePolicy { + NoAction, + #[default] + Restrict, + Cascade, + SetNone, +} + +impl From for sea_query::ForeignKeyAction { + fn from(value: ForeignKeyOnDeletePolicy) -> Self { + match value { + ForeignKeyOnDeletePolicy::NoAction => Self::NoAction, + ForeignKeyOnDeletePolicy::Restrict => Self::Restrict, + ForeignKeyOnDeletePolicy::Cascade => Self::Cascade, + ForeignKeyOnDeletePolicy::SetNone => Self::SetNull, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +pub enum ForeignKeyOnUpdatePolicy { + NoAction, + Restrict, + #[default] + Cascade, + SetNone, +} + +impl From for sea_query::ForeignKeyAction { + fn from(value: ForeignKeyOnUpdatePolicy) -> Self { + match value { + ForeignKeyOnUpdatePolicy::NoAction => Self::NoAction, + ForeignKeyOnUpdatePolicy::Restrict => Self::Restrict, + ForeignKeyOnUpdatePolicy::Cascade => Self::Cascade, + ForeignKeyOnUpdatePolicy::SetNone => Self::SetNull, + } + } +} diff --git a/flareon/src/db/sea_query_db.rs b/flareon/src/db/sea_query_db.rs index fc08016..4f92fe8 100644 --- a/flareon/src/db/sea_query_db.rs +++ b/flareon/src/db/sea_query_db.rs @@ -15,7 +15,9 @@ macro_rules! impl_sea_query_db_backend { pub(super) async fn new(url: &str) -> crate::db::Result { let db_connection = <$pool_ty>::connect(url).await?; - Ok(Self { db_connection }) + let db = Self { db_connection }; + db.init(); + Ok(db) } pub(super) async fn close(&self) -> crate::db::Result<()> { @@ -88,6 +90,7 @@ macro_rules! impl_sea_query_db_backend { let result = sqlx_statement.execute(&self.db_connection).await?; let result = crate::db::StatementResult { rows_affected: crate::db::RowsNum(result.rows_affected()), + last_inserted_row_id: Self::last_inserted_row_id_for(&result), }; log::debug!("Rows affected: {}", result.rows_affected.0); diff --git a/flareon/tests/db.rs b/flareon/tests/db.rs index d365a2a..9489d80 100644 --- a/flareon/tests/db.rs +++ b/flareon/tests/db.rs @@ -6,7 +6,10 @@ use fake::rand::SeedableRng; use fake::{Dummy, Fake, Faker}; use flareon::db::migrations::{Field, Operation}; use flareon::db::query::ExprEq; -use flareon::db::{model, query, Database, DatabaseField, Identifier, LimitedString, Model}; +use flareon::db::{ + model, query, Auto, Database, DatabaseError, DatabaseField, ForeignKey, + ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy, Identifier, LimitedString, Model, +}; use flareon::test::TestDatabase; #[flareon_macros::dbtest] @@ -16,7 +19,7 @@ async fn model_crud(test_db: &mut TestDatabase) { assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]); let mut model = TestModel { - id: 0, + id: Auto::fixed(1), name: "test".to_owned(), }; model.save(&**test_db).await.unwrap(); @@ -40,7 +43,7 @@ async fn model_macro_filtering(test_db: &mut TestDatabase) { assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]); let mut model = TestModel { - id: 0, + id: Auto::auto(), name: "test".to_owned(), }; model.save(&**test_db).await.unwrap(); @@ -61,7 +64,7 @@ async fn model_macro_filtering(test_db: &mut TestDatabase) { #[derive(Debug, PartialEq)] #[model] struct TestModel { - id: i32, + id: Auto, name: String, } @@ -72,7 +75,7 @@ async fn migrate_test_model(db: &Database) { const CREATE_TEST_MODEL: Operation = Operation::create_model() .table_name(Identifier::new("test_model")) .fields(&[ - Field::new(Identifier::new("id"), ::TYPE) + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) .primary_key() .auto(), Field::new(Identifier::new("name"), ::TYPE), @@ -99,8 +102,8 @@ macro_rules! all_fields_migration_field { #[derive(Debug, PartialEq, Dummy)] #[model] struct AllFieldsModel { - #[dummy(expr = "0i32")] - id: i32, + #[dummy(expr = "Auto::auto()")] + id: Auto, field_bool: bool, field_i8: i8, field_i16: i16, @@ -134,7 +137,7 @@ async fn migrate_all_fields_model(db: &Database) { const CREATE_ALL_FIELDS_MODEL: Operation = Operation::create_model() .table_name(Identifier::new("all_fields_model")) .fields(&[ - Field::new(Identifier::new("id"), ::TYPE) + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) .primary_key() .auto(), all_fields_migration_field!(bool), @@ -174,7 +177,6 @@ async fn all_fields_model(db: &mut TestDatabase) { } let mut models_from_db: Vec<_> = AllFieldsModel::objects().all(&**db).await.unwrap(); - models_from_db.iter_mut().for_each(|model| model.id = 0); normalize_datetimes(&mut models); normalize_datetimes(&mut models_from_db); @@ -197,3 +199,161 @@ fn normalize_datetimes(data: &mut Vec) { ); } } + +#[flareon_macros::dbtest] +async fn foreign_keys(db: &mut TestDatabase) { + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Artist { + id: Auto, + name: String, + } + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Track { + id: Auto, + artist: ForeignKey, + name: String, + } + + const CREATE_ARTIST: Operation = Operation::create_model() + .table_name(Identifier::new("artist")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new(Identifier::new("name"), ::TYPE), + ]) + .build(); + const CREATE_TRACK: Operation = Operation::create_model() + .table_name(Identifier::new("track")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("artist"), + as DatabaseField>::TYPE, + ) + .foreign_key( + ::TABLE_NAME, + ::PRIMARY_KEY_NAME, + ForeignKeyOnDeletePolicy::Restrict, + ForeignKeyOnUpdatePolicy::Restrict, + ), + Field::new(Identifier::new("name"), ::TYPE), + ]) + .build(); + + CREATE_ARTIST.forwards(db).await.unwrap(); + CREATE_TRACK.forwards(db).await.unwrap(); + + let mut artist = Artist { + id: Auto::auto(), + name: "artist".to_owned(), + }; + artist.save(&**db).await.unwrap(); + + let mut track = Track { + id: Auto::auto(), + artist: ForeignKey::from(&artist), + name: "track".to_owned(), + }; + track.save(&**db).await.unwrap(); + + let mut track = Track::objects().all(&**db).await.unwrap()[0].clone(); + let artist_from_db = track.artist.get(&**db).await.unwrap(); + assert_eq!(artist_from_db, &artist); + + let error = query!(Artist, $id == artist.id) + .delete(&**db) + .await + .unwrap_err(); + // expected foreign key violation + assert!(matches!(error, DatabaseError::DatabaseEngineError(_))); + + query!(Track, $artist == &artist) + .delete(&**db) + .await + .unwrap(); + query!(Artist, $id == artist.id) + .delete(&**db) + .await + .unwrap(); + // no error should be thrown +} + +#[flareon_macros::dbtest] +async fn foreign_keys_option(db: &mut TestDatabase) { + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Parent { + id: Auto, + } + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Child { + id: Auto, + parent: Option>, + } + + const CREATE_PARENT: Operation = Operation::create_model() + .table_name(Identifier::new("parent")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + ]) + .build(); + const CREATE_CHILD: Operation = Operation::create_model() + .table_name(Identifier::new("child")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("parent"), + > as DatabaseField>::TYPE, + ) + .foreign_key( + ::TABLE_NAME, + ::PRIMARY_KEY_NAME, + ForeignKeyOnDeletePolicy::Restrict, + ForeignKeyOnUpdatePolicy::Restrict, + ) + .set_null(> as DatabaseField>::NULLABLE), + ]) + .build(); + + CREATE_PARENT.forwards(db).await.unwrap(); + CREATE_CHILD.forwards(db).await.unwrap(); + + // no parent + let mut child = Child { + id: Auto::auto(), + parent: None, + }; + child.save(&**db).await.unwrap(); + + let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + assert_eq!(child.parent, None); + + query!(Child, $id == child.id).delete(&**db).await.unwrap(); + + // with parent + let mut parent = Parent { id: Auto::auto() }; + parent.save(&**db).await.unwrap(); + + let mut child = Child { + id: Auto::auto(), + parent: Some(ForeignKey::from(&parent)), + }; + child.save(&**db).await.unwrap(); + + let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + let mut parent_fk = child.parent.unwrap(); + let parent_from_db = parent_fk.get(&**db).await.unwrap(); + assert_eq!(parent_from_db, &parent); +}