From 33e5ae1b6536434ccfd9f5130d671a6ce6748870 Mon Sep 17 00:00:00 2001 From: Brendan Allan Date: Thu, 28 Dec 2023 22:56:11 +0800 Subject: [PATCH] allow arrays in unique where params --- cli/src/generator/models/field.rs | 8 ++++---- cli/src/generator/models/where_params.rs | 8 +++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cli/src/generator/models/field.rs b/cli/src/generator/models/field.rs index 330faac1..a96c08ff 100644 --- a/cli/src/generator/models/field.rs +++ b/cli/src/generator/models/field.rs @@ -174,18 +174,18 @@ pub fn module( let filter_enum = format_ident!("{}Filter", &read_filter.name); // Add equals query functions. Unique/Where enum variants are added in unique/primary key sections earlier on. - let equals = match (model.field_is_primary(field_name), model.field_is_unique(field_name), field.arity.is_required()) { - (true, _, _) | (_, true, true) => quote! { + let equals = match (model.field_is_primary(field_name), model.field_is_unique(field_name), field.arity.is_optional()) { + (true, _, _) | (_, true, false) => quote! { pub fn equals>(value: #field_type) -> T { UniqueWhereParam::#equals_variant(value).into() } }, - (_, true, false) => quote! { + (_, true, true) => quote! { pub fn equals>(value: A) -> T { T::from_arg(value) } }, - (_, _, _) => quote! { + (false, false, _) => quote! { pub fn equals(value: #field_type) -> WhereParam { WhereParam::#field_name_pascal(_prisma::read_filters::#filter_enum::Equals(value)) } diff --git a/cli/src/generator/models/where_params.rs b/cli/src/generator/models/where_params.rs index 19064b51..d5164aab 100644 --- a/cli/src/generator/models/where_params.rs +++ b/cli/src/generator/models/where_params.rs @@ -24,7 +24,13 @@ impl Variant { field_name: field.name().to_string(), field_required_type: field .field_type() - .to_tokens(module_path, &dml::FieldArity::Required) + .to_tokens( + module_path, + &match field.arity() { + dml::FieldArity::Optional => dml::FieldArity::Required, + a => *a, + }, + ) .unwrap(), read_filter_name: read_filter.name.to_string(), optional: matches!(field.arity(), dml::FieldArity::Optional),