From 185deb3606066773eb23e0df1fbdaaee1ce9cb2f Mon Sep 17 00:00:00 2001 From: u9g Date: Mon, 9 Oct 2023 22:28:20 -0400 Subject: [PATCH] Move to our own Type type for everything that's exposed in trustfall_core (#435) * Move to our own Type type for everything that's exposed in trustfall_core * Switch to our own type implementation * add over max list depth panic * rename InlineModifiers to Modifiers * move from_type to associated function on type * move doc comment to struct instead of impl block * is_nullable() -> nullable() * rename previous is_nullable() usages * merge conflict fix * more is_nullable renames * add docs and doctests to new methods * cargo fmt fix * Update trustfall_core/src/ir/ty.rs Co-authored-by: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com> * improve tests * add max list depth test with non-null on innermost * Move type-related code into `ir/types/` module and fix impl bug. * Add re-export for `Type` and `NamedTypeValue`. * Add early-return check for scalar-only subtyping. * Rename `base_named_type()` to `base_type()`. * Inline helpers and use `Type` reexport by default. * Return `Result` when parsing a string to a `Type`. * Rename `Type::new()` to `Type::parse()`. * Move type intersection onto `Type` itself. * Move `equal_ignoring_nullability()` method to `Type`. * Move value type-checking fn to a `Type` method. * Move orderability check to a method on `Type`. * Move scalar-only subtyping check into `Type` methods. * Rename `operations` module since it no longer contains any operations. * Add static type names for built-in types. * Minor polish. * Create statics for common types. * Add string capacity when printing. * Add more test coverage. --------- Co-authored-by: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com> Co-authored-by: Predrag Gruevski --- trustfall_core/src/frontend/mod.rs | 92 +- .../src/interpreter/helpers/correctness.rs | 10 +- .../src/interpreter/hints/dynamic.rs | 6 +- .../src/interpreter/hints/vertex_info.rs | 16 +- trustfall_core/src/interpreter/mod.rs | 7 +- trustfall_core/src/ir/indexed.rs | 16 +- trustfall_core/src/ir/mod.rs | 134 ++- trustfall_core/src/ir/serialization.rs | 88 -- trustfall_core/src/ir/types.rs | 431 ---------- trustfall_core/src/ir/types/base.rs | 813 ++++++++++++++++++ trustfall_core/src/ir/types/mod.rs | 5 + trustfall_core/src/ir/types/named_typed.rs | 97 +++ trustfall_core/src/schema/adapter/mod.rs | 24 +- trustfall_core/src/schema/mod.rs | 48 +- 14 files changed, 1066 insertions(+), 721 deletions(-) delete mode 100644 trustfall_core/src/ir/serialization.rs delete mode 100644 trustfall_core/src/ir/types.rs create mode 100644 trustfall_core/src/ir/types/base.rs create mode 100644 trustfall_core/src/ir/types/mod.rs create mode 100644 trustfall_core/src/ir/types/named_typed.rs diff --git a/trustfall_core/src/frontend/mod.rs b/trustfall_core/src/frontend/mod.rs index 9a279d2f..fb18f7b0 100644 --- a/trustfall_core/src/frontend/mod.rs +++ b/trustfall_core/src/frontend/mod.rs @@ -5,10 +5,9 @@ use std::{ }; use async_graphql_parser::{ - types::{BaseType, ExecutableDocument, FieldDefinition, Type, TypeDefinition, TypeKind}, + types::{ExecutableDocument, FieldDefinition, TypeDefinition, TypeKind}, Positioned, }; -use async_graphql_value::Name; use smallvec::SmallVec; use crate::{ @@ -17,12 +16,10 @@ use crate::{ query::{parse_document, FieldConnection, FieldNode, Query}, }, ir::{ - types::{intersect_types, is_argument_type_valid, NamedTypedValue}, - Argument, ContextField, EdgeParameters, Eid, FieldRef, FieldValue, FoldSpecificField, - FoldSpecificFieldKind, IREdge, IRFold, IRQuery, IRQueryComponent, IRVertex, IndexedQuery, - LocalField, Operation, Recursive, TransformationKind, VariableRef, Vid, - TYPENAME_META_FIELD, TYPENAME_META_FIELD_ARC, TYPENAME_META_FIELD_NAME, - TYPENAME_META_FIELD_TYPE, + types::NamedTypedValue, Argument, ContextField, EdgeParameters, Eid, FieldRef, FieldValue, + FoldSpecificField, FoldSpecificFieldKind, IREdge, IRFold, IRQuery, IRQueryComponent, + IRVertex, IndexedQuery, LocalField, Operation, Recursive, TransformationKind, Type, + VariableRef, Vid, TYPENAME_META_FIELD, TYPENAME_META_FIELD_ARC, }, schema::{FieldOrigin, Schema, BUILTIN_SCALARS}, util::{BTreeMapTryInsertExt, TryCollectUniqueKey}, @@ -71,13 +68,13 @@ pub fn parse_doc(schema: &Schema, document: &ExecutableDocument) -> Result( defined_fields: &'a [Positioned], field_node: &FieldNode, -) -> (&'a Name, Arc, Arc, &'a Type) { +) -> (&'a str, Arc, Arc, Type) { if field_node.name.as_ref() == TYPENAME_META_FIELD { return ( - &TYPENAME_META_FIELD_NAME, + TYPENAME_META_FIELD, TYPENAME_META_FIELD_ARC.clone(), TYPENAME_META_FIELD_ARC.clone(), - &TYPENAME_META_FIELD_TYPE, + Type::new_named_type("String", false), ); } @@ -92,7 +89,12 @@ fn get_field_name_and_type_from_schema<'a>( } else { pre_coercion_type_name.clone() }; - return (field_name, pre_coercion_type_name, post_coercion_type_name, field_raw_type); + return ( + field_name, + pre_coercion_type_name, + post_coercion_type_name, + Type::from_type(field_raw_type), + ); } } @@ -161,7 +163,7 @@ fn make_edge_parameters( // The default value must be a valid type for the parameter, // otherwise the schema itself is invalid. - assert!(is_argument_type_valid(&arg.node.ty.node, &value)); + assert!(Type::from_type(&arg.node.ty.node).is_valid_value(&value)); value }) @@ -175,7 +177,7 @@ fn make_edge_parameters( } Some(value) => { // Type-check the supplied value against the schema. - if !is_argument_type_valid(&arg.node.ty.node, value) { + if !Type::from_type(&arg.node.ty.node).is_valid_value(value) { errors.push(FrontendError::InvalidEdgeParameterType( arg_name.to_string(), edge_definition.name.node.to_string(), @@ -222,7 +224,7 @@ fn make_edge_parameters( fn infer_variable_type( property_name: &str, - property_type: &Type, + property_type: Type, operation: &Operation<(), OperatorArgument>, ) -> Result> { match operation { @@ -243,31 +245,31 @@ fn infer_variable_type( // Using a "null" valued variable doesn't make sense as a comparison. // However, [[1], [2], null] is a valid value to use in the comparison, since // there are definitely values that it is smaller than or bigger than. - Ok(Type { base: property_type.base.clone(), nullable: false }) + Ok(property_type.with_nullability(false)) } Operation::Contains(..) | Operation::NotContains(..) => { // To be able to check whether the property's value contains the operand, // the property needs to be a list. If it's not a list, this is a bad filter. - let inner_type = match &property_type.base { - BaseType::Named(_) => { - return Err(Box::new(FilterTypeError::ListFilterOperationOnNonListField( - operation.operation_name().to_string(), - property_name.to_string(), - property_type.to_string(), - ))) - } - BaseType::List(inner) => inner.as_ref(), + // let value = property_type.value(); + let inner_type = if let Some(list) = property_type.as_list() { + list + } else { + return Err(Box::new(FilterTypeError::ListFilterOperationOnNonListField( + operation.operation_name().to_string(), + property_name.to_string(), + property_type.to_string(), + ))); }; // We're trying to see if a list of element contains our element, so its type // is whatever is inside the list -- nullable or not. - Ok(inner_type.clone()) + Ok(inner_type) } Operation::OneOf(..) | Operation::NotOneOf(..) => { // Whatever the property's type is, the argument must be a non-nullable list of // the same type, so that the elements of that list may be checked for equality // against that property's value. - Ok(Type { base: BaseType::List(Box::new(property_type.clone())), nullable: false }) + Ok(Type::new_list_type(property_type.clone(), false)) } Operation::HasPrefix(..) | Operation::NotHasPrefix(..) @@ -278,7 +280,7 @@ fn infer_variable_type( | Operation::RegexMatches(..) | Operation::NotRegexMatches(..) => { // Filtering operations involving strings only take non-nullable strings as inputs. - Ok(Type { base: BaseType::Named(Name::new("String")), nullable: false }) + Ok(Type::new_named_type("String", false)) } Operation::IsNull(..) | Operation::IsNotNull(..) => { // These are unary operations, there's no place where a variable can be used. @@ -323,7 +325,7 @@ fn make_filter_expr( variable_name: var_name.clone(), variable_type: infer_variable_type( left_operand.named(), - left_operand.typed(), + left_operand.typed().clone(), &filter_directive.operation, ) .map_err(|e| *e)?, @@ -395,7 +397,7 @@ pub fn make_ir_for_query(schema: &Schema, query: &Query) -> Result Result { *existing_type = intersection; } @@ -606,7 +608,7 @@ where #[allow(clippy::type_complexity)] let mut properties: BTreeMap< (Vid, Arc), - (Arc, &'schema Type, SmallVec<[&'query FieldNode; 1]>), + (Arc, Type, SmallVec<[&'query FieldNode; 1]>), > = Default::default(); output_handler.begin_subcomponent(); @@ -874,13 +876,10 @@ fn get_recurse_implicit_coercion( #[allow(clippy::too_many_arguments)] #[allow(clippy::type_complexity)] -fn make_vertex<'schema, 'query>( - schema: &'schema Schema, +fn make_vertex<'query>( + schema: &Schema, property_names_by_vertex: &BTreeMap>>, - properties: &BTreeMap< - (Vid, Arc), - (Arc, &'schema Type, SmallVec<[&'query FieldNode; 1]>), - >, + properties: &BTreeMap<(Vid, Arc), (Arc, Type, SmallVec<[&'query FieldNode; 1]>)>, tags: &mut TagHandler, component_path: &ComponentPath, vid: Vid, @@ -977,10 +976,7 @@ fn fill_in_vertex_data<'schema, 'query, V, E>( edges: &mut BTreeMap, folds: &mut BTreeMap>, property_names_by_vertex: &mut BTreeMap>>, - properties: &mut BTreeMap< - (Vid, Arc), - (Arc, &'schema Type, SmallVec<[&'query FieldNode; 1]>), - >, + properties: &mut BTreeMap<(Vid, Arc), (Arc, Type, SmallVec<[&'query FieldNode; 1]>)>, component_path: &mut ComponentPath, output_handler: &mut OutputHandler<'query>, tags: &mut TagHandler<'query>, @@ -1098,7 +1094,7 @@ where output_handler.end_nested_scope(next_vid); } else if BUILTIN_SCALARS.contains(subfield_post_coercion_type.as_ref()) || schema.scalars.contains_key(subfield_post_coercion_type.as_ref()) - || subfield_name.as_ref() == TYPENAME_META_FIELD + || subfield_name == TYPENAME_META_FIELD { // Processing a property. @@ -1126,7 +1122,7 @@ where )); } - let subfield_name: Arc = subfield_name.as_ref().to_owned().into(); + let subfield_name: Arc = subfield_name.into(); let key = (current_vid, subfield_name.clone()); properties .entry(key) @@ -1141,7 +1137,7 @@ where .or_default() .push(subfield_name.clone()); - (subfield_name, subfield_raw_type, SmallVec::from([subfield])) + (subfield_name, subfield_raw_type.clone(), SmallVec::from([subfield])) }); for output_directive in &subfield.output { @@ -1188,7 +1184,7 @@ where let tag_field = ContextField { vertex_id: current_vid, field_name: subfield.name.clone(), - field_type: subfield_raw_type.to_owned(), + field_type: subfield_raw_type.clone(), }; // TODO: handle tags on non-fold-related transformed fields here @@ -1199,7 +1195,7 @@ where } } } else { - unreachable!("field name: {}", subfield_name.as_ref()); + unreachable!("field name: {}", subfield_name); } } diff --git a/trustfall_core/src/interpreter/helpers/correctness.rs b/trustfall_core/src/interpreter/helpers/correctness.rs index 7319b2fd..4c5d01ac 100644 --- a/trustfall_core/src/interpreter/helpers/correctness.rs +++ b/trustfall_core/src/interpreter/helpers/correctness.rs @@ -1,12 +1,10 @@ use std::{collections::BTreeMap, fmt::Debug, num::NonZeroUsize, sync::Arc}; -use async_graphql_parser::types::Type; - use crate::{ interpreter::{Adapter, DataContext, InterpretedQuery, ResolveEdgeInfo, ResolveInfo}, ir::{ ContextField, EdgeParameters, Eid, FieldValue, IREdge, IRQuery, IRQueryComponent, IRVertex, - TransparentValue, Vid, + TransparentValue, Type, Vid, }, schema::{Schema, SchemaAdapter}, TryIntoStruct, @@ -171,7 +169,7 @@ fn make_resolve_info_for_property_check( property_name.clone() => ContextField { vertex_id: vid, field_name: property_name.clone(), - field_type: Type::new(property_type).expect("not a valid type"), + field_type: Type::parse(property_type).expect("not a valid type"), } }, }), @@ -328,7 +326,7 @@ fn make_resolve_edge_info_for_edge_check( property_name.clone() => ContextField { vertex_id: vid, field_name: property_name, - field_type: Type::new("String!").expect("not a valid type"), + field_type: Type::parse("String!").expect("not a valid type"), } }, }), @@ -496,7 +494,7 @@ fn make_resolve_info_for_type_coercion( typename_property.clone() => ContextField { vertex_id: vid, field_name: typename_property.clone(), - field_type: Type::new("String!").expect("not a valid type"), + field_type: Type::parse("String!").expect("not a valid type"), } }, }), diff --git a/trustfall_core/src/interpreter/hints/dynamic.rs b/trustfall_core/src/interpreter/hints/dynamic.rs index 9ee2d6d2..c85280c8 100644 --- a/trustfall_core/src/interpreter/hints/dynamic.rs +++ b/trustfall_core/src/interpreter/hints/dynamic.rs @@ -1,7 +1,5 @@ use std::{fmt::Debug, ops::Bound, sync::Arc}; -use async_graphql_parser::types::Type; - use crate::{ interpreter::{ execution::{ @@ -12,7 +10,9 @@ use crate::{ Adapter, ContextIterator, ContextOutcomeIterator, InterpretedQuery, TaggedValue, VertexIterator, }, - ir::{ContextField, FieldRef, FieldValue, FoldSpecificField, IRQueryComponent, Operation}, + ir::{ + ContextField, FieldRef, FieldValue, FoldSpecificField, IRQueryComponent, Operation, Type, + }, }; use super::CandidateValue; diff --git a/trustfall_core/src/interpreter/hints/vertex_info.rs b/trustfall_core/src/interpreter/hints/vertex_info.rs index 9b6ae876..b83723bf 100644 --- a/trustfall_core/src/interpreter/hints/vertex_info.rs +++ b/trustfall_core/src/interpreter/hints/vertex_info.rs @@ -285,7 +285,7 @@ impl VertexInfo for T { let first_filter = relevant_filters.first()?; let initial_candidate = self.statically_required_property(property).unwrap_or_else(|| { - if first_filter.left().field_type.nullable { + if first_filter.left().field_type.nullable() { CandidateValue::All } else { CandidateValue::Range(Range::full_non_null()) @@ -395,7 +395,7 @@ fn compute_statically_known_candidate<'a, 'b>( relevant_filters: impl Iterator>, query_variables: &'b BTreeMap, FieldValue>, ) -> Option> { - let is_subject_field_nullable = field.field_type.nullable; + let is_subject_field_nullable = field.field_type.nullable(); super::filters::candidate_from_statically_evaluated_filters( relevant_filters, query_variables, @@ -407,13 +407,11 @@ fn compute_statically_known_candidate<'a, 'b>( mod tests { use std::{ops::Bound, sync::Arc}; - use async_graphql_parser::types::Type; - use crate::{ interpreter::hints::{ vertex_info::compute_statically_known_candidate, CandidateValue, Range, }, - ir::{Argument, FieldValue, LocalField, Operation, VariableRef}, + ir::{Argument, FieldValue, LocalField, Operation, Type, VariableRef}, }; #[test] @@ -424,9 +422,9 @@ mod tests { let null: Arc = Arc::from("null"); let list: Arc = Arc::from("my_list"); let longer_list: Arc = Arc::from("longer_list"); - let nullable_int_type = Type::new("Int").unwrap(); - let int_type = Type::new("Int!").unwrap(); - let list_int_type = Type::new("[Int!]!").unwrap(); + let nullable_int_type = Type::parse("Int").unwrap(); + let int_type = Type::parse("Int!").unwrap(); + let list_int_type = Type::parse("[Int!]!").unwrap(); let first_var = Argument::Variable(VariableRef { variable_name: first.clone(), @@ -603,7 +601,7 @@ mod tests { #[test] fn use_schema_to_exclude_null_from_range() { let first: Arc = Arc::from("first"); - let int_type = Type::new("Int!").unwrap(); + let int_type = Type::parse("Int!").unwrap(); let first_var = Argument::Variable(VariableRef { variable_name: first.clone(), diff --git a/trustfall_core/src/interpreter/mod.rs b/trustfall_core/src/interpreter/mod.rs index 8a57e2c5..3d970aad 100644 --- a/trustfall_core/src/interpreter/mod.rs +++ b/trustfall_core/src/interpreter/mod.rs @@ -1,13 +1,10 @@ use std::{collections::BTreeMap, fmt::Debug, sync::Arc}; -use async_graphql_parser::types::Type; use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::{ - ir::{ - types::is_argument_type_valid, EdgeParameters, Eid, FieldRef, FieldValue, IndexedQuery, Vid, - }, + ir::{EdgeParameters, Eid, FieldRef, FieldValue, IndexedQuery, Type, Vid}, util::BTreeMapTryInsertExt, }; @@ -377,7 +374,7 @@ fn validate_argument_type( variable_type: &Type, argument_value: &FieldValue, ) -> Result<(), QueryArgumentsError> { - if is_argument_type_valid(variable_type, argument_value) { + if variable_type.is_valid_value(argument_value) { Ok(()) } else { Err(QueryArgumentsError::ArgumentTypeError( diff --git a/trustfall_core/src/ir/indexed.rs b/trustfall_core/src/ir/indexed.rs index 2b5ffe99..9c030624 100644 --- a/trustfall_core/src/ir/indexed.rs +++ b/trustfall_core/src/ir/indexed.rs @@ -5,14 +5,11 @@ use std::{ sync::Arc, }; -use async_graphql_parser::types::{BaseType, Type}; use serde::{Deserialize, Serialize}; use crate::util::BTreeMapTryInsertExt; -use super::{ - types::is_scalar_only_subtype, Argument, Eid, IREdge, IRFold, IRQuery, IRQueryComponent, Vid, -}; +use super::{Argument, Eid, IREdge, IRFold, IRQuery, IRQueryComponent, Type, Vid}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct IndexedQuery { @@ -29,8 +26,6 @@ pub struct IndexedQuery { pub struct Output { pub name: Arc, - #[serde(serialize_with = "crate::ir::serialization::serde_type_serializer")] - #[serde(deserialize_with = "crate::ir::serialization::serde_type_deserializer")] pub value_type: Type, pub vid: Vid, @@ -101,13 +96,10 @@ fn get_output_type( ) -> Type { let mut wrapped_output_type = field_type.clone(); if component_optional_vertices.contains(&output_at) { - wrapped_output_type.nullable = true; + wrapped_output_type = wrapped_output_type.with_nullability(true); } for is_fold_optional in are_folds_optional.iter().rev() { - wrapped_output_type = Type { - base: BaseType::List(Box::new(wrapped_output_type)), - nullable: *is_fold_optional, - }; + wrapped_output_type = Type::new_list_type(wrapped_output_type, *is_fold_optional); } wrapped_output_type } @@ -146,7 +138,7 @@ fn add_data_from_component( // // If the variable type at top level is not a subtype of the type here, // this query is not valid. - if !is_scalar_only_subtype(&vref.variable_type, var_type) { + if !vref.variable_type.is_scalar_only_subtype(var_type) { return Err(InvalidIRQueryError::GetBetterVariant(-2)); } } diff --git a/trustfall_core/src/ir/mod.rs b/trustfall_core/src/ir/mod.rs index ad8b1f30..757481b8 100644 --- a/trustfall_core/src/ir/mod.rs +++ b/trustfall_core/src/ir/mod.rs @@ -1,7 +1,6 @@ //! Trustfall intermediate representation (IR) mod indexed; -pub mod serialization; pub mod types; pub mod value; @@ -9,27 +8,17 @@ use std::{ cmp::Ordering, collections::BTreeMap, fmt::Debug, num::NonZeroUsize, ops::Index, sync::Arc, }; -use async_graphql_parser::types::{BaseType, Type}; -use async_graphql_value::Name; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use crate::frontend::error::FilterTypeError; pub use self::indexed::{EdgeKind, IndexedQuery, InvalidIRQueryError, Output}; -use self::types::{ - are_base_types_equal_ignoring_nullability, is_base_type_orderable, NamedTypedValue, -}; +pub use self::types::{NamedTypedValue, Type}; pub use self::value::{FieldValue, TransparentValue}; pub(crate) const TYPENAME_META_FIELD: &str = "__typename"; -pub(crate) static TYPENAME_META_FIELD_NAME: Lazy = - Lazy::new(|| Name::new(TYPENAME_META_FIELD)); - -pub(crate) static TYPENAME_META_FIELD_TYPE: Lazy = - Lazy::new(|| Type::new("String!").unwrap()); - pub(crate) static TYPENAME_META_FIELD_ARC: Lazy> = Lazy::new(|| Arc::from(TYPENAME_META_FIELD)); @@ -135,12 +124,7 @@ pub struct IRQuery { pub root_component: Arc, - #[serde( - default, - skip_serializing_if = "BTreeMap::is_empty", - serialize_with = "crate::ir::serialization::serde_variables_serializer", - deserialize_with = "crate::ir::serialization::serde_variables_deserializer" - )] + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] pub variables: BTreeMap, Type>, } @@ -232,8 +216,7 @@ pub enum FoldSpecificFieldKind { Count, // Represents the number of elements in an IRFold's component. } -static NON_NULL_INT_TYPE: Lazy = - Lazy::new(|| Type { base: BaseType::Named(Name::new("Int")), nullable: false }); +static NON_NULL_INT_TYPE: Lazy = Lazy::new(|| Type::new_named_type("Int", false)); impl FoldSpecificFieldKind { pub fn field_type(&self) -> &Type { @@ -619,7 +602,7 @@ impl Operation { match self { Operation::IsNull(_) | Operation::IsNotNull(_) => { // Checking non-nullable types for null or non-null is pointless. - if left_type.nullable { + if left_type.nullable() { Ok(()) } else { Err(vec![FilterTypeError::NonNullableTypeFilteredForNullability( @@ -636,7 +619,7 @@ impl Operation { // For the operands relative to each other, nullability doesn't matter, // but the rest of the type must be the same. let right_type = right_type.unwrap(); - if are_base_types_equal_ignoring_nullability(&left_type.base, &right_type.base) { + if left_type.equal_ignoring_nullability(right_type) { Ok(()) } else { // The right argument must be a tag at this point. If it is not a tag @@ -663,7 +646,7 @@ impl Operation { let right_type = right_type.unwrap(); let mut errors = vec![]; - if !is_base_type_orderable(&left_type.base) { + if !left_type.is_orderable() { errors.push(FilterTypeError::OrderingFilterOperationOnNonOrderableField( self.operation_name().to_string(), left.named().to_string(), @@ -671,7 +654,7 @@ impl Operation { )); } - if !is_base_type_orderable(&right_type.base) { + if !right_type.is_orderable() { // The right argument must be a tag at this point. If it is not a tag // and the second .unwrap() below panics, then our type inference // has inferred an incorrect type for the variable in the argument. @@ -687,7 +670,7 @@ impl Operation { // For the operands relative to each other, nullability doesn't matter, // but the types must be equal to each other. - if !are_base_types_equal_ignoring_nullability(&left_type.base, &right_type.base) { + if !left_type.equal_ignoring_nullability(right_type) { // The right argument must be a tag at this point. If it is not a tag // and the second .unwrap() below panics, then our type inference // has inferred an incorrect type for the variable in the argument. @@ -712,22 +695,19 @@ impl Operation { Operation::Contains(_, _) | Operation::NotContains(_, _) => { // The left-hand operand needs to be a list, ignoring nullability. // The right-hand operand may be anything, if considered individually. - let inner_type = match &left_type.base { - BaseType::List(ty) => Ok(ty), - BaseType::Named(_) => { - Err(vec![FilterTypeError::ListFilterOperationOnNonListField( - self.operation_name().to_string(), - left.named().to_string(), - left_type.to_string(), - )]) - } - }?; + let inner_type = left_type.as_list().ok_or_else(|| { + vec![FilterTypeError::ListFilterOperationOnNonListField( + self.operation_name().to_string(), + left.named().to_string(), + left_type.to_string(), + )] + })?; let right_type = right_type.unwrap(); // However, the type inside the left-hand list must be equal, // ignoring nullability, to the type of the right-hand operand. - if are_base_types_equal_ignoring_nullability(&inner_type.base, &right_type.base) { + if inner_type.equal_ignoring_nullability(right_type) { Ok(()) } else { // The right argument must be a tag at this point. If it is not a tag @@ -749,26 +729,25 @@ impl Operation { // The right-hand operand needs to be a list, ignoring nullability. // The left-hand operand may be anything, if considered individually. let right_type = right_type.unwrap(); - let inner_type = match &right_type.base { - BaseType::List(ty) => Ok(ty), - BaseType::Named(_) => { - // The right argument must be a tag at this point. If it is not a tag - // and the second .unwrap() below panics, then our type inference - // has inferred an incorrect type for the variable in the argument. - let tag = right.unwrap().as_tag().unwrap(); - - Err(vec![FilterTypeError::ListFilterOperationOnNonListTag( - self.operation_name().to_string(), - tag_name.unwrap().to_string(), - tag.field_name().to_string(), - tag.field_type().to_string(), - )]) - } + let inner_type = if let Some(list) = right_type.as_list() { + Ok(list) + } else { + // The right argument must be a tag at this point. If it is not a tag + // and the second .unwrap() below panics, then our type inference + // has inferred an incorrect type for the variable in the argument. + let tag = right.unwrap().as_tag().unwrap(); + + Err(vec![FilterTypeError::ListFilterOperationOnNonListTag( + self.operation_name().to_string(), + tag_name.unwrap().to_string(), + tag.field_name().to_string(), + tag.field_type().to_string(), + )]) }?; // However, the type inside the right-hand list must be equal, // ignoring nullability, to the type of the left-hand operand. - if are_base_types_equal_ignoring_nullability(&left_type.base, &inner_type.base) { + if left_type.equal_ignoring_nullability(&inner_type) { Ok(()) } else { // The right argument must be a tag at this point. If it is not a tag @@ -797,31 +776,26 @@ impl Operation { let mut errors = vec![]; // Both operands need to be strings, ignoring nullability. - match &left_type.base { - BaseType::Named(ty) if ty == "String" => {} - _ => { - errors.push(FilterTypeError::StringFilterOperationOnNonStringField( - self.operation_name().to_string(), - left.named().to_string(), - left_type.to_string(), - )); - } - }; - - match &right_type.unwrap().base { - BaseType::Named(ty) if ty == "String" => {} - _ => { - // The right argument must be a tag at this point. If it is not a tag - // and the second .unwrap() below panics, then our type inference - // has inferred an incorrect type for the variable in the argument. - let tag = right.unwrap().as_tag().unwrap(); - errors.push(FilterTypeError::StringFilterOperationOnNonStringTag( - self.operation_name().to_string(), - tag_name.unwrap().to_string(), - tag.field_name().to_string(), - tag.field_type().to_string(), - )); - } + if left_type.is_list() || left_type.base_type() != "String" { + errors.push(FilterTypeError::StringFilterOperationOnNonStringField( + self.operation_name().to_string(), + left.named().to_string(), + left_type.to_string(), + )); + } + + // The right argument must be a tag at this point. If it is not a tag + // and the second .unwrap() below panics, then our type inference + // has inferred an incorrect type for the variable in the argument. + let right_type = right_type.unwrap(); + if right_type.is_list() || right_type.base_type() != "String" { + let tag = right.unwrap().as_tag().unwrap(); + errors.push(FilterTypeError::StringFilterOperationOnNonStringTag( + self.operation_name().to_string(), + tag_name.unwrap().to_string(), + tag.field_name().to_string(), + tag.field_type().to_string(), + )); } if errors.is_empty() { @@ -840,8 +814,6 @@ pub struct ContextField { pub field_name: Arc, - #[serde(serialize_with = "crate::ir::serialization::serde_type_serializer")] - #[serde(deserialize_with = "crate::ir::serialization::serde_type_deserializer")] pub field_type: Type, } @@ -849,8 +821,6 @@ pub struct ContextField { pub struct LocalField { pub field_name: Arc, - #[serde(serialize_with = "crate::ir::serialization::serde_type_serializer")] - #[serde(deserialize_with = "crate::ir::serialization::serde_type_deserializer")] pub field_type: Type, } @@ -858,8 +828,6 @@ pub struct LocalField { pub struct VariableRef { pub variable_name: Arc, - #[serde(serialize_with = "crate::ir::serialization::serde_type_serializer")] - #[serde(deserialize_with = "crate::ir::serialization::serde_type_deserializer")] pub variable_type: Type, } diff --git a/trustfall_core/src/ir/serialization.rs b/trustfall_core/src/ir/serialization.rs deleted file mode 100644 index 0fc384e6..00000000 --- a/trustfall_core/src/ir/serialization.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::{collections::BTreeMap, fmt, sync::Arc}; - -use async_graphql_parser::types::Type; -use serde::{self, de::Visitor, ser::SerializeMap, Deserializer, Serialize, Serializer}; - -pub fn serde_type_serializer(value: &Type, serializer: S) -> Result -where - S: Serializer, -{ - value.to_string().serialize(serializer) -} - -pub fn serde_type_deserializer<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - struct TypeDeserializer; - - impl<'de> Visitor<'de> for TypeDeserializer { - type Value = Type; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("GraphQL type") - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - let ty = - Type::new(s).ok_or_else(|| serde::de::Error::custom("not a valid GraphQL type"))?; - Ok(ty) - } - } - - deserializer.deserialize_str(TypeDeserializer) -} - -pub fn serde_variables_serializer( - value: &BTreeMap, Type>, - serializer: S, -) -> Result -where - S: Serializer, -{ - let mut serializer = serializer.serialize_map(Some(value.len()))?; - - value.iter().try_for_each(|(k, v)| serializer.serialize_entry(k, &v.to_string()))?; - - serializer.end() -} - -pub fn serde_variables_deserializer<'de, D>( - deserializer: D, -) -> Result, Type>, D::Error> -where - D: Deserializer<'de>, -{ - struct TypeDeserializer; - - impl<'de> Visitor<'de> for TypeDeserializer { - type Value = BTreeMap, &'de str>; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("map of variable names -> types") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut result: BTreeMap, &'de str> = BTreeMap::new(); - while let Some((key, value)) = map.next_entry()? { - result.insert(key, value); - } - Ok(result) - } - } - - deserializer.deserialize_map(TypeDeserializer).map(|value| { - let mut result: BTreeMap, Type> = Default::default(); - for (k, v) in value { - let ty = Type::new(v).unwrap(); - result.insert(k, ty); - } - result - }) -} diff --git a/trustfall_core/src/ir/types.rs b/trustfall_core/src/ir/types.rs deleted file mode 100644 index 69f3dd6b..00000000 --- a/trustfall_core/src/ir/types.rs +++ /dev/null @@ -1,431 +0,0 @@ -use std::fmt::Debug; - -use async_graphql_parser::types::{BaseType, Type}; - -use super::{ - Argument, ContextField, FieldRef, FieldValue, FoldSpecificField, FoldSpecificFieldKind, - LocalField, VariableRef, -}; - -pub trait NamedTypedValue: Debug + Clone + PartialEq + Eq { - fn typed(&self) -> &Type; - - fn named(&self) -> &str; -} - -impl NamedTypedValue for LocalField { - fn typed(&self) -> &Type { - &self.field_type - } - - fn named(&self) -> &str { - self.field_name.as_ref() - } -} - -impl NamedTypedValue for ContextField { - fn typed(&self) -> &Type { - &self.field_type - } - - fn named(&self) -> &str { - self.field_name.as_ref() - } -} - -impl NamedTypedValue for FoldSpecificField { - fn typed(&self) -> &Type { - self.kind.field_type() - } - - fn named(&self) -> &str { - self.kind.field_name() - } -} - -impl NamedTypedValue for FoldSpecificFieldKind { - fn typed(&self) -> &Type { - self.field_type() - } - - fn named(&self) -> &str { - self.field_name() - } -} - -impl NamedTypedValue for VariableRef { - fn typed(&self) -> &Type { - &self.variable_type - } - - fn named(&self) -> &str { - &self.variable_name - } -} - -impl NamedTypedValue for FieldRef { - fn typed(&self) -> &Type { - match self { - FieldRef::ContextField(c) => c.typed(), - FieldRef::FoldSpecificField(f) => f.kind.typed(), - } - } - - fn named(&self) -> &str { - match self { - FieldRef::ContextField(c) => c.named(), - FieldRef::FoldSpecificField(f) => f.kind.named(), - } - } -} - -impl NamedTypedValue for Argument { - fn typed(&self) -> &Type { - match self { - Argument::Tag(t) => t.typed(), - Argument::Variable(v) => v.typed(), - } - } - - fn named(&self) -> &str { - match self { - Argument::Tag(t) => t.named(), - Argument::Variable(v) => v.named(), - } - } -} - -pub(crate) fn are_base_types_equal_ignoring_nullability(left: &BaseType, right: &BaseType) -> bool { - match (left, right) { - (BaseType::Named(l), BaseType::Named(r)) => l == r, - (BaseType::List(l), BaseType::List(r)) => { - are_base_types_equal_ignoring_nullability(&l.base, &r.base) - } - (BaseType::Named(_), BaseType::List(_)) | (BaseType::List(_), BaseType::Named(_)) => false, - } -} - -pub(crate) fn is_base_type_orderable(operand_type: &BaseType) -> bool { - match operand_type { - BaseType::Named(name) => { - name == "Int" || name == "Float" || name == "String" || name == "DateTime" - } - BaseType::List(l) => is_base_type_orderable(&l.base), - } -} - -pub(crate) fn get_base_named_type(ty: &Type) -> &str { - match &ty.base { - BaseType::Named(n) => n.as_ref(), - BaseType::List(l) => get_base_named_type(l.as_ref()), - } -} - -/// Check for scalar-only subtyping. -/// -/// Scalars don't have an inheritance structure, so they are able to be compared without a schema. -/// Callers of this function must guarantee that the passed types are either scalars or -/// (potentially multiply-nested) lists of scalars. -/// -/// This function considers types of different names to always be non-equal and unrelated: -/// neither is a subtype of the other. So given `interface Base` and `type Derived implements Base`, -/// that means `is_scalar_only_subtype(Base, Derived) == false`, since this function never sees -/// the definitions of `Base` and `Derived` as those are part of a schema which this function -/// never gets. -pub(crate) fn is_scalar_only_subtype(parent_type: &Type, maybe_subtype: &Type) -> bool { - // If the parent type is non-nullable, all its subtypes must be non-nullable as well. - // If the parent type is nullable, it can have both nullable and non-nullable subtypes. - if !parent_type.nullable && maybe_subtype.nullable { - return false; - } - - match (&parent_type.base, &maybe_subtype.base) { - (BaseType::Named(parent), BaseType::Named(subtype)) => parent == subtype, - (BaseType::List(parent_type), BaseType::List(maybe_subtype)) => { - is_scalar_only_subtype(parent_type, maybe_subtype) - } - (BaseType::Named(..), BaseType::List(..)) | (BaseType::List(..), BaseType::Named(..)) => { - false - } - } -} - -/// For two types, return a type that is a subtype of both, or None if no such type exists. -/// For example: -/// ```rust -/// use async_graphql_parser::types::Type; -/// use trustfall_core::ir::types::intersect_types; -/// -/// let left = Type::new("[String]!").unwrap(); -/// let right = Type::new("[String!]").unwrap(); -/// let result = intersect_types(&left, &right); -/// assert_eq!(Some(Type::new("[String!]!").unwrap()), result); -/// -/// let incompatible = Type::new("[Int]").unwrap(); -/// let result = intersect_types(&left, &incompatible); -/// assert_eq!(None, result); -/// ``` -pub fn intersect_types(left: &Type, right: &Type) -> Option { - let nullable = left.nullable && right.nullable; - - match (&left.base, &right.base) { - (BaseType::Named(l), BaseType::Named(r)) => { - if l == r { - Some(Type { base: left.base.clone(), nullable }) - } else { - None - } - } - (BaseType::List(left), BaseType::List(right)) => intersect_types(left, right) - .map(|inner| Type { base: BaseType::List(Box::new(inner)), nullable }), - (BaseType::Named(_), BaseType::List(_)) | (BaseType::List(_), BaseType::Named(_)) => None, - } -} - -/// Check if the given argument value is valid for the specified variable type. -/// -/// In particular, mixed integer types in a list are considered valid for types like `[Int]`. -/// ```rust -/// use async_graphql_parser::types::Type; -/// use trustfall_core::ir::{FieldValue, types::is_argument_type_valid}; -/// -/// let variable_type = Type::new("[Int]").unwrap(); -/// let argument_value = FieldValue::List([ -/// FieldValue::Int64(-1), -/// FieldValue::Uint64(1), -/// FieldValue::Null, -/// ].as_slice().into()); -/// assert!(is_argument_type_valid(&variable_type, &argument_value)); -/// ``` -pub fn is_argument_type_valid(variable_type: &Type, argument_value: &FieldValue) -> bool { - match argument_value { - FieldValue::Null => { - // This is a valid value only if this layer is nullable. - variable_type.nullable - } - FieldValue::Int64(_) | FieldValue::Uint64(_) => { - // This is a valid value only if the type is Int, ignoring nullability. - matches!(&variable_type.base, BaseType::Named(n) if n == "Int") - } - FieldValue::Float64(_) => { - // This is a valid value only if the type is Float, ignoring nullability. - matches!(&variable_type.base, BaseType::Named(n) if n == "Float") - } - FieldValue::String(_) => { - // This is a valid value only if the type is String, ignoring nullability. - matches!(&variable_type.base, BaseType::Named(n) if n == "String") - } - FieldValue::Boolean(_) => { - // This is a valid value only if the type is Boolean, ignoring nullability. - matches!(&variable_type.base, BaseType::Named(n) if n == "Boolean") - } - FieldValue::List(nested_values) => { - // This is a valid value only if the type is a list, and all the inner elements - // are valid instances of the type inside the list. - match &variable_type.base { - BaseType::List(inner) => { - nested_values.iter().all(|value| is_argument_type_valid(inner.as_ref(), value)) - } - BaseType::Named(_) => false, - } - } - FieldValue::Enum(_) => todo!(), - } -} - -#[cfg(test)] -mod tests { - use async_graphql_parser::types::Type; - use itertools::Itertools; - - use crate::ir::{types::is_argument_type_valid, FieldValue}; - - #[test] - fn null_values_are_only_valid_for_nullable_types() { - let nullable_types = [ - Type::new("Int").unwrap(), - Type::new("String").unwrap(), - Type::new("Boolean").unwrap(), - Type::new("[Int!]").unwrap(), - Type::new("[[Int!]!]").unwrap(), - ]; - let non_nullable_types = nullable_types - .iter() - .map(|t| Type { base: t.base.clone(), nullable: false }) - .collect_vec(); - - for nullable_type in &nullable_types { - assert!(is_argument_type_valid(nullable_type, &FieldValue::Null), "{}", nullable_type); - } - for non_nullable_type in &non_nullable_types { - assert!( - !is_argument_type_valid(non_nullable_type, &FieldValue::Null), - "{}", - non_nullable_type - ); - } - } - - #[test] - fn int_values_are_valid_only_for_int_type_regardless_of_nullability() { - let matching_types = [Type::new("Int").unwrap(), Type::new("Int!").unwrap()]; - let non_matching_types = [ - Type::new("String").unwrap(), - Type::new("[Int!]").unwrap(), - Type::new("[Int!]!").unwrap(), - Type::new("[[Int!]!]").unwrap(), - ]; - let values = [ - FieldValue::Int64(-42), - FieldValue::Int64(0), - FieldValue::Uint64(0), - FieldValue::Uint64((i64::MAX as u64) + 1), - ]; - - for value in &values { - for matching_type in &matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - for non_matching_type in &non_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - } - } - - #[test] - fn string_values_are_valid_only_for_string_type_regardless_of_nullability() { - let matching_types = [Type::new("String").unwrap(), Type::new("String!").unwrap()]; - let non_matching_types = [ - Type::new("Int").unwrap(), - Type::new("[String!]").unwrap(), - Type::new("[String!]!").unwrap(), - Type::new("[[String!]!]").unwrap(), - ]; - let values = [ - FieldValue::String("".into()), // empty string is not the same value as null - FieldValue::String("test string".into()), - ]; - - for value in &values { - for matching_type in &matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - for non_matching_type in &non_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - } - } - - #[test] - fn boolean_values_are_valid_only_for_boolean_type_regardless_of_nullability() { - let matching_types = [Type::new("Boolean").unwrap(), Type::new("Boolean!").unwrap()]; - let non_matching_types = [ - Type::new("Int").unwrap(), - Type::new("[Boolean!]").unwrap(), - Type::new("[Boolean!]!").unwrap(), - Type::new("[[Boolean!]!]").unwrap(), - ]; - let values = [FieldValue::Boolean(false), FieldValue::Boolean(true)]; - - for value in &values { - for matching_type in &matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - for non_matching_type in &non_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - } - } - - #[test] - fn list_types_correctly_check_contents_of_list() { - let non_nullable_contents_matching_types = - vec![Type::new("[Int!]").unwrap(), Type::new("[Int!]!").unwrap()]; - let nullable_contents_matching_types = - vec![Type::new("[Int]").unwrap(), Type::new("[Int]!").unwrap()]; - let non_matching_types = [ - Type::new("Int").unwrap(), - Type::new("Int!").unwrap(), - Type::new("[String!]").unwrap(), - Type::new("[String!]!").unwrap(), - Type::new("[[String!]!]").unwrap(), - ]; - let non_nullable_values = [ - FieldValue::List((1..3).map(FieldValue::Int64).collect_vec().into()), - FieldValue::List((1..3).map(FieldValue::Uint64).collect_vec().into()), - FieldValue::List( - vec![ - // Integer-typed but non-homogeneous FieldValue entries are okay. - FieldValue::Int64(-42), - FieldValue::Uint64(64), - ] - .into(), - ), - ]; - let nullable_values = [ - FieldValue::List( - vec![FieldValue::Int64(1), FieldValue::Null, FieldValue::Int64(2)].into(), - ), - FieldValue::List(vec![FieldValue::Null, FieldValue::Uint64(42)].into()), - FieldValue::List( - vec![ - // Integer-typed but non-homogeneous FieldValue entries are okay. - FieldValue::Int64(-1), - FieldValue::Uint64(1), - FieldValue::Null, - ] - .into(), - ), - ]; - - for value in &non_nullable_values { - // Values without nulls match both the nullable and the non-nullable types. - for matching_type in &nullable_contents_matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - for matching_type in &non_nullable_contents_matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - - // Regardless of nulls, these types don't match. - for non_matching_type in &non_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - } - - for value in &nullable_values { - // Nullable values match only the nullable types. - for matching_type in &nullable_contents_matching_types { - assert!(is_argument_type_valid(matching_type, value), "{matching_type} {value:?}",); - } - - // The nullable values don't match the non-nullable types. - for non_matching_type in &non_nullable_contents_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - - // Regardless of nulls, these types don't match. - for non_matching_type in &non_matching_types { - assert!( - !is_argument_type_valid(non_matching_type, value), - "{non_matching_type} {value:?}", - ); - } - } - } -} diff --git a/trustfall_core/src/ir/types/base.rs b/trustfall_core/src/ir/types/base.rs new file mode 100644 index 00000000..2b9e39a4 --- /dev/null +++ b/trustfall_core/src/ir/types/base.rs @@ -0,0 +1,813 @@ +use std::{ + fmt::{Display, Formatter}, + sync::{Arc, OnceLock}, +}; + +use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; + +use crate::ir::FieldValue; + +/// A representation of a Trustfall type, independent of which parser or query syntax we're using. +/// Equivalent in expressiveness to GraphQL types, but not explicitly tied to a GraphQL library. +#[derive(Clone, PartialEq, Eq)] +pub struct Type { + base: Arc, + modifiers: Modifiers, +} + +impl std::fmt::Debug for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(&format!("Type (represents {self})")) + .field("base", &self.base) + .field("modifiers", &self.modifiers) + .finish() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Modifiers { + mask: u64, // space for ~30 levels of list nesting +} + +impl Modifiers { + const NON_NULLABLE_MASK: u64 = 1; + const LIST_MASK: u64 = 2; + const MAX_LIST_DEPTH: u64 = 30; + + /// Represents the leftmost list bit that can be set before adding a new list will overflow. + /// `(Self::MAX_LIST_DEPTH - 1)` because we start shifted over once. + const MAX_LIST_DEPTH_MASK: u64 = Self::LIST_MASK << ((Self::MAX_LIST_DEPTH - 1) * 2); + + /// Returns an optionally-nullable, non-list modifiers value. + fn new(nullable: bool) -> Self { + Self { mask: if nullable { 0 } else { Self::NON_NULLABLE_MASK } } + } + + #[inline] + fn nullable(&self) -> bool { + (self.mask & Self::NON_NULLABLE_MASK) == 0 + } + + #[inline] + fn is_list(&self) -> bool { + (self.mask & Self::LIST_MASK) != 0 + } + + #[inline] + fn as_list(&self) -> Option { + self.is_list().then_some(Modifiers { mask: self.mask >> 2 }) + } + + #[inline] + fn at_max_list_depth(&self) -> bool { + (self.mask & Self::MAX_LIST_DEPTH_MASK) == Self::MAX_LIST_DEPTH_MASK + } +} + +#[derive(Debug, Clone)] +pub struct TypeParseError { + invalid_type: String, +} + +impl Display for TypeParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{} is not a valid Type representation", &self.invalid_type) + } +} + +impl std::error::Error for TypeParseError {} + +static STRING_TYPE_NAME: &str = "String"; +static INT_TYPE_NAME: &str = "Int"; +static FLOAT_TYPE_NAME: &str = "Float"; +static BOOLEAN_TYPE_NAME: &str = "Boolean"; + +static STRING_TYPE_NAME_ARC: OnceLock> = OnceLock::new(); +static INT_TYPE_NAME_ARC: OnceLock> = OnceLock::new(); + +#[inline] +fn get_string_type_name_arc() -> &'static Arc { + STRING_TYPE_NAME_ARC.get_or_init(|| Arc::from(STRING_TYPE_NAME)) +} + +#[inline] +fn get_int_type_name_arc() -> &'static Arc { + INT_TYPE_NAME_ARC.get_or_init(|| Arc::from(INT_TYPE_NAME)) +} + +impl Type { + /// Parses a string type representation into a new [`Type`]. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let ty = Type::parse("[String!]!").unwrap(); + /// assert_eq!(ty.to_string(), "[String!]!"); + /// + /// assert_eq!(Type::parse("[String!]").unwrap().to_string(), "[String!]"); + /// ``` + pub fn parse(ty: &str) -> Result { + async_graphql_parser::types::Type::new(ty) + .ok_or_else(|| TypeParseError { invalid_type: ty.to_string() }) + .map(|ty| Self::from_type(&ty)) + } + + fn from_name_and_modifiers(base_type: &str, modifiers: Modifiers) -> Self { + match base_type { + "String" => Self { base: Arc::clone(get_string_type_name_arc()), modifiers }, + "Int" => Self { base: Arc::clone(get_int_type_name_arc()), modifiers }, + "Float" => Self { base: Arc::from(FLOAT_TYPE_NAME), modifiers }, + "Boolean" => Self { base: Arc::from(BOOLEAN_TYPE_NAME), modifiers }, + _ => Self { base: base_type.to_string().into(), modifiers }, + } + } + + /// Creates an individual [`Type`], not a list. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let nullable = false; + /// let ty = Type::new_named_type("String", nullable); + /// + /// assert_eq!(ty.to_string(), "String!"); + /// assert_eq!(ty, Type::parse("String!").unwrap()); + /// ``` + pub fn new_named_type(base_type: &str, nullable: bool) -> Self { + let modifiers = Modifiers::new(nullable); + Self::from_name_and_modifiers(base_type, modifiers) + } + + /// Creates a new list layer on a [`Type`]. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let inner_nullable = false; + /// let inner_ty = Type::new_named_type("String", inner_nullable); + /// + /// let outer_nullable = true; + /// let ty = Type::new_list_type(inner_ty, outer_nullable); + /// + /// assert_eq!(ty.to_string(), "[String!]"); + /// assert_eq!(ty, Type::parse("[String!]").unwrap()); + /// ``` + pub fn new_list_type(inner_type: Self, nullable: bool) -> Self { + if inner_type.modifiers.at_max_list_depth() { + panic!("too many nested lists: {inner_type}"); + } + + let mut new_mask = (inner_type.modifiers.mask << 2) | Modifiers::LIST_MASK; + + if !nullable { + new_mask |= Modifiers::NON_NULLABLE_MASK; + } + + Self { base: inner_type.base, modifiers: Modifiers { mask: new_mask } } + } + + /// Returns a new type that is the same as this one, but with the passed nullability. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let nullable_ty = Type::parse("Int").unwrap(); + /// assert_eq!(nullable_ty.nullable(), true); + /// let non_nullable_ty = nullable_ty.with_nullability(false); + /// assert_eq!(non_nullable_ty.nullable(), false); + /// + /// // The original type is unchanged. + /// assert_eq!(nullable_ty.nullable(), true); + /// ``` + pub fn with_nullability(&self, nullable: bool) -> Self { + let mut new = self.clone(); + if nullable { + new.modifiers.mask &= !Modifiers::NON_NULLABLE_MASK; + } else { + new.modifiers.mask |= Modifiers::NON_NULLABLE_MASK; + } + new + } + + /// Returns whether this type is nullable, at the top level, see example. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let nullable_ty = Type::parse("[Int!]").unwrap(); + /// assert_eq!(nullable_ty.nullable(), true); // the list is nullable + /// + /// let nullable_ty = Type::parse("Int!").unwrap(); + /// assert_eq!(nullable_ty.nullable(), false); // the `Int` is nonnullable + /// ``` + pub fn nullable(&self) -> bool { + self.modifiers.nullable() + } + + /// Returns whether the type is a list or not. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let non_null_int_arr = Type::parse("[Int!]").unwrap(); + /// assert_eq!(non_null_int_arr.is_list(), true); + /// + /// let non_null_int = Type::parse("Int!").unwrap(); + /// assert_eq!(non_null_int.is_list(), false); + /// ``` + pub fn is_list(&self) -> bool { + self.modifiers.is_list() + } + + /// Returns the type inside the outermost list of this type if it is a list, otherwise returns `None`. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let non_null_int_arr = Type::parse("[Int!]").unwrap(); + /// let non_null_int = Type::parse("Int!").unwrap(); + /// assert_eq!(non_null_int_arr.as_list(), Some(non_null_int.clone())); + /// assert_eq!(non_null_int.as_list(), None); + /// ``` + pub fn as_list(&self) -> Option { + Some(Self { base: Arc::clone(&self.base), modifiers: self.modifiers.as_list()? }) + } + + /// Returns the type of the elements of the first individual type found inside this type. + /// + /// # Example + /// ``` + /// use trustfall_core::ir::Type; + /// + /// let int_list_ty = Type::parse("[Int!]").unwrap(); + /// assert_eq!(int_list_ty.base_type(), "Int"); + /// + /// let string_ty = Type::parse("String!").unwrap(); + /// assert_eq!(string_ty.base_type(), "String"); + pub fn base_type(&self) -> &str { + &self.base + } + + /// Convert a [`async_graphql_parser::types::Type`] to a [`Type`]. + pub(crate) fn from_type(ty: &async_graphql_parser::types::Type) -> Type { + let mut base = &ty.base; + + let mut mask = if ty.nullable { 0 } else { Modifiers::NON_NULLABLE_MASK }; + + let mut i = 0; + + while let async_graphql_parser::types::BaseType::List(ty_inside_list) = base { + mask |= Modifiers::LIST_MASK << i; + i += 2; + if i > Modifiers::MAX_LIST_DEPTH * 2 { + panic!("too many nested lists: {ty:?}"); + } + if !ty_inside_list.nullable { + mask |= Modifiers::NON_NULLABLE_MASK << i; + } + base = &ty_inside_list.base; + } + + let async_graphql_parser::types::BaseType::Named(name) = base else { + unreachable!( + "should be impossible to get a non-named type after \ +looping through all list types: {ty:?} {base:?}" + ) + }; + + Self::from_name_and_modifiers(name.as_str(), Modifiers { mask }) + } + + /// For two types, return a type that is a subtype of both, or None if no such type exists. + /// For example: + /// ```rust + /// use trustfall_core::ir::types::Type; + /// + /// let left = Type::parse("[String]!").unwrap(); + /// let right = Type::parse("[String!]").unwrap(); + /// let result = left.intersect(&right); + /// assert_eq!(Some(Type::parse("[String!]!").unwrap()), result); + /// + /// let incompatible = Type::parse("[Int]").unwrap(); + /// let result = left.intersect(&incompatible); + /// assert_eq!(None, result); + /// ``` + pub fn intersect(&self, other: &Self) -> Option { + if self.base_type() != other.base_type() { + return None; + } + + self.intersect_impl(other) + } + + fn intersect_impl(&self, other: &Self) -> Option { + let nullable = self.nullable() && other.nullable(); + + match (self.as_list(), other.as_list()) { + (None, None) => Some(Type::new_named_type(self.base_type(), nullable)), + (Some(left), Some(right)) => { + left.intersect_impl(&right).map(|inner| Type::new_list_type(inner, nullable)) + } + _ => None, + } + } + + pub(crate) fn equal_ignoring_nullability(&self, other: &Self) -> bool { + if self.base_type() != other.base_type() { + return false; + } + + match (self.as_list(), other.as_list()) { + (None, None) => true, + (Some(left), Some(right)) => left.equal_ignoring_nullability(&right), + _ => false, + } + } + + /// Check if the given value is allowed by the specified type. + /// + /// In particular, mixed integer types in a list are considered valid for types like `[Int]`. + /// ```rust + /// use trustfall_core::ir::{FieldValue, Type}; + /// + /// let ty = Type::parse("[Int]").unwrap(); + /// let value = FieldValue::List([ + /// FieldValue::Int64(-1), + /// FieldValue::Uint64(1), + /// FieldValue::Null, + /// ].as_slice().into()); + /// assert!(ty.is_valid_value(&value)); + /// ``` + pub fn is_valid_value(&self, value: &FieldValue) -> bool { + match value { + FieldValue::Null => { + // This is a valid value only if this layer is nullable. + self.nullable() + } + FieldValue::Int64(_) | FieldValue::Uint64(_) => { + // This is a valid value only if the type is Int, ignoring nullability. + !self.is_list() && self.base_type() == "Int" + } + FieldValue::Float64(_) => { + // This is a valid value only if the type is Float, ignoring nullability. + !self.is_list() && self.base_type() == "Float" + } + FieldValue::String(_) => { + // This is a valid value only if the type is String, ignoring nullability. + !self.is_list() && self.base_type() == "String" + } + FieldValue::Boolean(_) => { + // This is a valid value only if the type is Boolean, ignoring nullability. + !self.is_list() && self.base_type() == "Boolean" + } + FieldValue::List(contents) => { + // This is a valid value only if the type is a list, and all the inner elements + // are valid instances of the type inside the list. + if let Some(content_type) = self.as_list() { + contents.iter().all(|inner| content_type.is_valid_value(inner)) + } else { + false + } + } + FieldValue::Enum(_) => { + unimplemented!("enum values are not currently supported: {self} {value:?}") + } + } + } + + /// Returns `true` if values of this type can be compared using operators like `<`. + /// + /// In Rust terms, this checks for `PartialOrd` on this `Type`. + /// + /// Lists (including nested lists) are orderable if the type they contain is orderable. + /// Lists use lexicographic ordering, i.e. `[1, 2, 3] < [3]`. + pub(crate) fn is_orderable(&self) -> bool { + matches!(self.base_type(), "Int" | "Float" | "String") + } + + /// Check for scalar-only subtyping. + /// + /// Scalars don't have an inheritance structure, so they are able to be compared without a schema. + /// Callers of this function must guarantee that the passed types are either scalars or + /// (potentially multiply-nested) lists of scalars. + /// + /// This function considers types of different names to always be non-equal and unrelated: + /// neither is a subtype of the other. So given `interface Base` and `type Derived implements Base`, + /// that means `is_scalar_only_subtype(Base, Derived) == false`, since this function never sees + /// the definitions of `Base` and `Derived` as those are part of a schema which this function + /// never gets. + pub(crate) fn is_scalar_only_subtype(&self, maybe_subtype: &Self) -> bool { + // If the parent type is non-nullable, all its subtypes must be non-nullable as well. + // If the parent type is nullable, it can have both nullable and non-nullable subtypes. + if !self.nullable() && maybe_subtype.nullable() { + return false; + } + + // If the base types don't match, there can't be a subtyping relationship here. + // Recall that callers are required to make sure only scalar / nested-lists-of-scalar types + // are passed into this function. + if self.base_type() != maybe_subtype.base_type() { + return false; + } + + match (self.as_list(), maybe_subtype.as_list()) { + (None, None) => true, + (Some(parent), Some(maybe_subtype)) => parent.is_scalar_only_subtype(&maybe_subtype), + _ => false, + } + } +} + +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // left + { + let mut current = Some(self.modifiers.clone()); + while let Some(mods) = current { + if mods.is_list() { + write!(f, "[")?; + } + current = mods.as_list(); + } + } + + write!(f, "{}", self.base)?; + + let mut builder = String::with_capacity(self.modifiers.mask.count_ones().max(4) as usize); + + // right + { + let mut current = Some(self.modifiers.clone()); + while let Some(mods) = current { + if !mods.nullable() { + builder.push('!'); + } + if mods.is_list() { + builder.push(']'); + } + current = mods.as_list(); + } + write!(f, "{}", builder.chars().rev().collect::())?; + } + + Ok(()) + } +} + +impl Serialize for Type { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Type { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct TypeDeserializer; + + impl<'de> Visitor<'de> for TypeDeserializer { + type Value = Type; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("GraphQL type") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + Type::parse(s).map_err(|err| serde::de::Error::custom(err)) + } + } + + deserializer.deserialize_str(TypeDeserializer) + } +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + + use crate::ir::{FieldValue, Type}; + + use super::Modifiers; + + #[test] + fn max_allowed_nested_lists_mask_representation() { + let type_str = format!( + "{}String{}", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize), + "]".repeat(Modifiers::MAX_LIST_DEPTH as usize) + ); + let type_modifiers = Type::parse(&type_str).unwrap().modifiers; + assert_eq!( + format!("{:b}", type_modifiers.mask), + "101010101010101010101010101010101010101010101010101010101010" + ); + } + + #[test] + fn max_allowed_nested_lists_is_at_max_list_depth() { + let type_str = format!( + "{}String{}", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize), + "]".repeat(Modifiers::MAX_LIST_DEPTH as usize) + ); + let type_modifiers = Type::parse(&type_str).unwrap().modifiers; + assert!(type_modifiers.at_max_list_depth()); + } + + #[test] + #[should_panic(expected = "too many nested lists")] + fn too_many_nested_lists_via_type_new() { + let type_str = format!( + "{}String!{}", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize + 1), + "]".repeat(Modifiers::MAX_LIST_DEPTH as usize + 1) + ); + let _ = Type::parse(&type_str); // will panic during modifier mask creation + } + + #[test] + #[should_panic(expected = "too many nested lists")] + fn too_many_nested_lists_via_new_list_type() { + let mut constructed_type = Type::new_named_type("String", false); + for _ in 0..=Modifiers::MAX_LIST_DEPTH { + constructed_type = Type::new_list_type(constructed_type, false); + } + + // will panic during new modifier mask creation for new list type + Type::new_list_type(constructed_type, false); + } + + #[test] + fn max_allowed_nested_lists_with_nonnull_on_last_list_mask_representation() { + let type_str = format!( + "{}String{}!", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize), + "]".repeat(Modifiers::MAX_LIST_DEPTH as usize) + ); + let type_modifiers = Type::parse(&type_str).unwrap().modifiers; + assert_eq!( + format!("{:b}", type_modifiers.mask), + "101010101010101010101010101010101010101010101010101010101011" + ); + } + + #[test] + fn max_allowed_nested_lists_with_nonnull_on_innermost_type_mask_representation() { + let type_str = format!( + "{}String!{}", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize), + "]".repeat(Modifiers::MAX_LIST_DEPTH as usize) + ); + let type_modifiers = Type::parse(&type_str).unwrap().modifiers; + assert_eq!( + format!("{:b}", type_modifiers.mask), + "1101010101010101010101010101010101010101010101010101010101010" + ); + } + + #[test] + fn max_allowed_nested_lists_with_non_null_on_every_list_and_inner_type() { + let type_str = format!( + "{}String!{}", + "[".repeat(Modifiers::MAX_LIST_DEPTH as usize), + "]!".repeat(Modifiers::MAX_LIST_DEPTH as usize) + ); + let type_modifiers = Type::parse(&type_str).unwrap().modifiers; + assert!(type_modifiers.at_max_list_depth()); + } + + #[test] + fn base_types_equal_ignoring_nullability() { + let test_data = [ + (Type::parse("String"), Type::parse("String"), true), + (Type::parse("String!"), Type::parse("String!"), true), + (Type::parse("Int"), Type::parse("Int!"), true), + (Type::parse("[String!]"), Type::parse("[String]!"), true), + (Type::parse("[String]"), Type::parse("[String!]!"), true), + (Type::parse("String"), Type::parse("Int"), false), + (Type::parse("String!"), Type::parse("Int!"), false), + (Type::parse("[String]"), Type::parse("String"), false), + (Type::parse("[String]!"), Type::parse("String!"), false), + (Type::parse("[String!]"), Type::parse("String!"), false), + ]; + + for (left, right, expected) in test_data { + let left = left.expect("not a valid type"); + let right = right.expect("not a valid type"); + assert_eq!(left.equal_ignoring_nullability(&right), expected, "{left} {right}"); + assert_eq!( + right.equal_ignoring_nullability(&left), + expected, + "commutativity violation in: {right} {left}" + ); + } + } + + #[test] + fn null_values_are_only_valid_for_nullable_types() { + let nullable_types = [ + Type::parse("Int").unwrap(), + Type::parse("String").unwrap(), + Type::parse("Boolean").unwrap(), + Type::parse("[Int!]").unwrap(), + Type::parse("[[Int!]!]").unwrap(), + ]; + let non_nullable_types = + nullable_types.iter().map(|t| t.with_nullability(false)).collect_vec(); + + for nullable_type in &nullable_types { + assert!(nullable_type.is_valid_value(&FieldValue::Null), "{}", nullable_type); + } + for non_nullable_type in &non_nullable_types { + assert!(!non_nullable_type.is_valid_value(&FieldValue::Null), "{}", non_nullable_type); + } + } + + #[test] + fn int_values_are_valid_only_for_int_type_regardless_of_nullability() { + let matching_types = [Type::parse("Int").unwrap(), Type::parse("Int!").unwrap()]; + let non_matching_types = [ + Type::parse("String").unwrap(), + Type::parse("[Int!]").unwrap(), + Type::parse("[Int!]!").unwrap(), + Type::parse("[[Int!]!]").unwrap(), + ]; + let values = [ + FieldValue::Int64(-42), + FieldValue::Int64(0), + FieldValue::Uint64(0), + FieldValue::Uint64((i64::MAX as u64) + 1), + ]; + + for value in &values { + for matching_type in &matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + for non_matching_type in &non_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + } + } + + #[test] + fn string_values_are_valid_only_for_string_type_regardless_of_nullability() { + let matching_types = [Type::parse("String").unwrap(), Type::parse("String!").unwrap()]; + let non_matching_types = [ + Type::parse("Int").unwrap(), + Type::parse("[String!]").unwrap(), + Type::parse("[String!]!").unwrap(), + Type::parse("[[String!]!]").unwrap(), + ]; + let values = [ + FieldValue::String("".into()), // empty string is not the same value as null + FieldValue::String("test string".into()), + ]; + + for value in &values { + for matching_type in &matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + for non_matching_type in &non_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + } + } + + #[test] + fn boolean_values_are_valid_only_for_boolean_type_regardless_of_nullability() { + let matching_types = [Type::parse("Boolean").unwrap(), Type::parse("Boolean!").unwrap()]; + let non_matching_types = [ + Type::parse("Int").unwrap(), + Type::parse("[Boolean!]").unwrap(), + Type::parse("[Boolean!]!").unwrap(), + Type::parse("[[Boolean!]!]").unwrap(), + ]; + let values = [FieldValue::Boolean(false), FieldValue::Boolean(true)]; + + for value in &values { + for matching_type in &matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + for non_matching_type in &non_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + } + } + + #[test] + fn list_types_correctly_check_contents_of_list() { + let non_nullable_contents_matching_types = + vec![Type::parse("[Int!]").unwrap(), Type::parse("[Int!]!").unwrap()]; + let nullable_contents_matching_types = + vec![Type::parse("[Int]").unwrap(), Type::parse("[Int]!").unwrap()]; + let non_matching_types = [ + Type::parse("Int").unwrap(), + Type::parse("Int!").unwrap(), + Type::parse("[String!]").unwrap(), + Type::parse("[String!]!").unwrap(), + Type::parse("[[String!]!]").unwrap(), + ]; + let non_nullable_values = [ + FieldValue::List((1..3).map(FieldValue::Int64).collect_vec().into()), + FieldValue::List((1..3).map(FieldValue::Uint64).collect_vec().into()), + FieldValue::List( + vec![ + // Integer-typed but non-homogeneous FieldValue entries are okay. + FieldValue::Int64(-42), + FieldValue::Uint64(64), + ] + .into(), + ), + ]; + let nullable_values = [ + FieldValue::List( + vec![FieldValue::Int64(1), FieldValue::Null, FieldValue::Int64(2)].into(), + ), + FieldValue::List(vec![FieldValue::Null, FieldValue::Uint64(42)].into()), + FieldValue::List( + vec![ + // Integer-typed but non-homogeneous FieldValue entries are okay. + FieldValue::Int64(-1), + FieldValue::Uint64(1), + FieldValue::Null, + ] + .into(), + ), + ]; + + for value in &non_nullable_values { + // Values without nulls match both the nullable and the non-nullable types. + for matching_type in &nullable_contents_matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + for matching_type in &non_nullable_contents_matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + + // Regardless of nulls, these types don't match. + for non_matching_type in &non_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + } + + for value in &nullable_values { + // Nullable values match only the nullable types. + for matching_type in &nullable_contents_matching_types { + assert!(matching_type.is_valid_value(value), "{matching_type} {value:?}",); + } + + // The nullable values don't match the non-nullable types. + for non_matching_type in &non_nullable_contents_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + + // Regardless of nulls, these types don't match. + for non_matching_type in &non_matching_types { + assert!(!non_matching_type.is_valid_value(value), "{non_matching_type} {value:?}",); + } + } + } + + #[test] + fn round_trip_serialization_and_creation() { + let test_data = [ + "String!", + "[Int]", + "[Float!]", + "[Boolean]!", + "[[[[String]!]]!]", + "[[[[[[[[String]!]]!]!]!]]]", + ]; + + for item in test_data { + let ty = Type::parse(item).expect("valid type"); + assert_eq!(item, &format!("{ty}")); + + let nullable_list = Type::new_list_type(ty.clone(), true); + let non_nullable_list = Type::new_list_type(ty.clone(), false); + + assert_eq!(format!("[{item}]"), format!("{}", &nullable_list)); + assert_eq!(format!("[{item}]!"), format!("{}", &non_nullable_list)); + + assert_eq!(nullable_list, non_nullable_list.with_nullability(true)); + assert_eq!(nullable_list, nullable_list.with_nullability(true)); + assert_eq!(non_nullable_list, nullable_list.with_nullability(false)); + assert_eq!(non_nullable_list, non_nullable_list.with_nullability(false)); + } + } +} diff --git a/trustfall_core/src/ir/types/mod.rs b/trustfall_core/src/ir/types/mod.rs new file mode 100644 index 00000000..e6cbc830 --- /dev/null +++ b/trustfall_core/src/ir/types/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod named_typed; + +pub use base::Type; +pub use named_typed::NamedTypedValue; diff --git a/trustfall_core/src/ir/types/named_typed.rs b/trustfall_core/src/ir/types/named_typed.rs new file mode 100644 index 00000000..a56d2b0d --- /dev/null +++ b/trustfall_core/src/ir/types/named_typed.rs @@ -0,0 +1,97 @@ +use std::fmt::Debug; + +use super::{ + super::{ + Argument, ContextField, FieldRef, FoldSpecificField, FoldSpecificFieldKind, LocalField, + VariableRef, + }, + Type, +}; + +pub trait NamedTypedValue: Debug + Clone + PartialEq + Eq { + fn typed(&self) -> &Type; + + fn named(&self) -> &str; +} + +impl NamedTypedValue for LocalField { + fn typed(&self) -> &Type { + &self.field_type + } + + fn named(&self) -> &str { + self.field_name.as_ref() + } +} + +impl NamedTypedValue for ContextField { + fn typed(&self) -> &Type { + &self.field_type + } + + fn named(&self) -> &str { + self.field_name.as_ref() + } +} + +impl NamedTypedValue for FoldSpecificField { + fn typed(&self) -> &Type { + self.kind.field_type() + } + + fn named(&self) -> &str { + self.kind.field_name() + } +} + +impl NamedTypedValue for FoldSpecificFieldKind { + fn typed(&self) -> &Type { + self.field_type() + } + + fn named(&self) -> &str { + self.field_name() + } +} + +impl NamedTypedValue for VariableRef { + fn typed(&self) -> &Type { + &self.variable_type + } + + fn named(&self) -> &str { + &self.variable_name + } +} + +impl NamedTypedValue for FieldRef { + fn typed(&self) -> &Type { + match self { + FieldRef::ContextField(c) => c.typed(), + FieldRef::FoldSpecificField(f) => f.kind.typed(), + } + } + + fn named(&self) -> &str { + match self { + FieldRef::ContextField(c) => c.named(), + FieldRef::FoldSpecificField(f) => f.kind.named(), + } + } +} + +impl NamedTypedValue for Argument { + fn typed(&self) -> &Type { + match self { + Argument::Tag(t) => t.typed(), + Argument::Variable(v) => v.typed(), + } + } + + fn named(&self) -> &str { + match self { + Argument::Tag(t) => t.named(), + Argument::Variable(v) => v.named(), + } + } +} diff --git a/trustfall_core/src/schema/adapter/mod.rs b/trustfall_core/src/schema/adapter/mod.rs index c8a6d2a9..7fb7177c 100644 --- a/trustfall_core/src/schema/adapter/mod.rs +++ b/trustfall_core/src/schema/adapter/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use async_graphql_parser::types::{ - BaseType, FieldDefinition, InputValueDefinition, Type, TypeDefinition, TypeKind, + BaseType, FieldDefinition, InputValueDefinition, TypeDefinition, TypeKind, }; use crate::{ @@ -11,7 +11,7 @@ use crate::{ CandidateValue, ContextIterator, ContextOutcomeIterator, ResolveEdgeInfo, ResolveInfo, Typename, VertexInfo, VertexIterator, }, - ir::{types::get_base_named_type, EdgeParameters, FieldValue, TransparentValue}, + ir::{EdgeParameters, FieldValue, TransparentValue, Type}, }; use super::Schema; @@ -206,17 +206,12 @@ pub struct Property<'a> { parent: &'a TypeDefinition, name: &'a str, docs: Option<&'a str>, - type_: &'a Type, + type_: Type, } impl<'a> Property<'a> { #[inline(always)] - fn new( - parent: &'a TypeDefinition, - name: &'a str, - docs: Option<&'a str>, - type_: &'a Type, - ) -> Self { + fn new(parent: &'a TypeDefinition, name: &'a str, docs: Option<&'a str>, type_: Type) -> Self { Self { parent, name, docs, type_ } } } @@ -404,7 +399,8 @@ impl<'a> crate::interpreter::Adapter<'a> for SchemaAdapter<'a> { "Edge" => match edge_name.as_ref() { "target" => resolve_neighbors_with(contexts, move |vertex| { let vertex = vertex.as_edge().expect("not an Edge"); - let target_type = get_base_named_type(&vertex.defn.ty.node); + let edge_type = Type::from_type(&vertex.defn.ty.node); + let target_type = edge_type.base_type(); Box::new( schema .vertex_types @@ -507,8 +503,8 @@ fn resolve_vertex_type_property_edge<'a>( let parent_defn = vertex.defn; Box::new(fields.iter().filter_map(move |p| { let field = &p.node; - let field_ty = &field.ty.node; - let base_ty = get_base_named_type(field_ty); + let field_ty = Type::from_type(&field.ty.node); + let base_ty = field_ty.base_type(); if !schema.vertex_types.contains_key(base_ty) { Some(SchemaVertex::Property(Property::new( @@ -533,8 +529,8 @@ fn resolve_vertex_type_edge_edge<'a>( Box::new(fields.iter().filter_map(move |p| { let field = &p.node; - let field_ty = &field.ty.node; - let base_ty = get_base_named_type(field_ty); + let field_ty = Type::from_type(&field.ty.node); + let base_ty = field_ty.base_type(); if schema.vertex_types.contains_key(base_ty) { Some(SchemaVertex::Edge(Edge::new(field))) diff --git a/trustfall_core/src/schema/mod.rs b/trustfall_core/src/schema/mod.rs index 8b4bfd76..5a57d272 100644 --- a/trustfall_core/src/schema/mod.rs +++ b/trustfall_core/src/schema/mod.rs @@ -9,7 +9,7 @@ use async_graphql_parser::{ parse_schema, types::{ BaseType, DirectiveDefinition, FieldDefinition, ObjectType, SchemaDefinition, - ServiceDocument, Type, TypeDefinition, TypeKind, TypeSystemDefinition, + ServiceDocument, TypeDefinition, TypeKind, TypeSystemDefinition, }, Positioned, }; @@ -20,7 +20,7 @@ use itertools::Itertools; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -use crate::ir::types::{get_base_named_type, is_argument_type_valid, is_scalar_only_subtype}; +use crate::ir::Type; use crate::util::{BTreeMapTryInsertExt, HashMapTryInsertExt}; use self::error::InvalidSchemaError; @@ -280,7 +280,11 @@ directive @transform(op: String!) on FIELD get_vertex_type_implements(&self.vertex_types[vertex_type]) } - pub(crate) fn is_subtype(&self, parent_type: &Type, maybe_subtype: &Type) -> bool { + pub(crate) fn is_subtype( + &self, + parent_type: &async_graphql_parser::types::Type, + maybe_subtype: &async_graphql_parser::types::Type, + ) -> bool { is_subtype(&self.vertex_types, parent_type, maybe_subtype) } @@ -297,8 +301,8 @@ fn check_root_query_type_invariants( let mut errors: Vec = vec![]; for field_defn in &query_type.fields { - let field_type = &field_defn.node.ty.node; - let base_named_type = get_base_named_type(field_type); + let field_type = Type::from_type(&field_defn.node.ty.node); + let base_named_type = field_type.base_type(); if BUILTIN_SCALARS.contains(base_named_type) { errors.push(InvalidSchemaError::PropertyFieldOnRootQueryType( query_type_definition.name.node.to_string(), @@ -343,7 +347,9 @@ fn check_type_and_property_and_edge_invariants( )); } - let base_named_type = get_base_named_type(field_type); + let field_type = Type::from_type(field_type); + + let base_named_type = field_type.base_type(); if BUILTIN_SCALARS.contains(base_named_type) { // We're looking at a property field. if !field_defn.arguments.is_empty() { @@ -370,7 +376,7 @@ fn check_type_and_property_and_edge_invariants( let param_type = ¶m_defn.node.ty.node; match value.node.clone().try_into() { Ok(value) => { - if !is_argument_type_valid(param_type, &value) { + if !Type::from_type(param_type).is_valid_value(&value) { errors.push(InvalidSchemaError::InvalidDefaultValueForFieldParameter( type_name.to_string(), field_defn.name.node.to_string(), @@ -397,18 +403,14 @@ fn check_type_and_property_and_edge_invariants( // Check that the edge field doesn't have // a list-of-list or more nested list type. - match &field_type.base { - BaseType::Named(_) => {} - BaseType::List(inner) => match &inner.base { - BaseType::Named(_) => {} - BaseType::List(_) => { - errors.push(InvalidSchemaError::InvalidEdgeType( - type_name.to_string(), - field_defn.name.node.to_string(), - field_type.to_string(), - )); - } - }, + if let Some(inner_list) = field_type.as_list() { + if inner_list.is_list() { + errors.push(InvalidSchemaError::InvalidEdgeType( + type_name.to_string(), + field_defn.name.node.to_string(), + field_type.to_string(), + )); + } } } } else { @@ -463,8 +465,8 @@ fn is_named_type_subtype( fn is_subtype( vertex_types: &HashMap, TypeDefinition>, - parent_type: &Type, - maybe_subtype: &Type, + parent_type: &async_graphql_parser::types::Type, + maybe_subtype: &async_graphql_parser::types::Type, ) -> bool { // If the parent type is non-nullable, all its subtypes must be non-nullable as well. // If the parent type is nullable, it can have both nullable and non-nullable subtypes. @@ -697,7 +699,9 @@ fn check_field_type_narrowing( if let Some(&parent_field_type) = parent_field_parameters.get(field_parameter) { - if !is_scalar_only_subtype(field_type, parent_field_type) { + if !Type::from_type(field_type) + .is_scalar_only_subtype(&Type::from_type(parent_field_type)) + { errors.push(InvalidSchemaError::InvalidTypeNarrowingOfInheritedFieldParameter( field_name.to_owned(), type_name.to_string(),