Skip to content

Commit

Permalink
feat(discriminator): Add Discriminator support for Interfaces (#2966)
Browse files Browse the repository at this point in the history
Co-authored-by: Kiryl Mialeshka <[email protected]>
  • Loading branch information
karatakis and meskill authored Oct 11, 2024
1 parent 56d536b commit a48cf84
Show file tree
Hide file tree
Showing 28 changed files with 879 additions and 180 deletions.
4 changes: 3 additions & 1 deletion src/core/blueprint/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashSet;

use async_graphql_value::ConstValue;
use directive::Directive;
use interface_resolver::update_interface_resolver;
use regex::Regex;
use union_resolver::update_union_resolver;

Expand Down Expand Up @@ -510,6 +511,7 @@ pub fn to_field_definition(
.and(update_protected(object_name).trace(Protected::trace_name().as_str()))
.and(update_enum_alias())
.and(update_union_resolver())
.and(update_interface_resolver())
.try_fold(
&(config_module, field, type_of, name),
FieldDefinition::default(),
Expand All @@ -528,7 +530,7 @@ pub fn to_definitions<'a>() -> TryFold<'a, ConfigModule, Vec<Definition>, String
Definition::Object(object_type_definition) => {
if config_module.input_types().contains(name) {
to_input_object_type_definition(object_type_definition).trace(name)
} else if config_module.interface_types().contains(name) {
} else if config_module.interfaces_types_map().contains_key(name) {
to_interface_type_definition(object_type_definition).trace(name)
} else {
Valid::succeed(definition)
Expand Down
57 changes: 57 additions & 0 deletions src/core/blueprint/interface_resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::collections::HashSet;

use crate::core::blueprint::FieldDefinition;
use crate::core::config;
use crate::core::config::{ConfigModule, Field};
use crate::core::ir::model::IR;
use crate::core::ir::Discriminator;
use crate::core::try_fold::TryFold;
use crate::core::valid::{Valid, Validator};

fn compile_interface_resolver(
config: &ConfigModule,
interface_name: &str,
interface_types: HashSet<String>,
) -> Valid<Discriminator, String> {
Valid::from_iter(&interface_types, |type_name| {
Valid::from_option(
config
.find_type(type_name)
.map(|type_| (type_name.as_str(), type_)),
"Can't find a type that is member of interface type".to_string(),
)
})
.and_then(|types| {
let types: Vec<_> = types.into_iter().collect();

Discriminator::new(interface_name, &types)
})
}

pub fn update_interface_resolver<'a>(
) -> TryFold<'a, (&'a ConfigModule, &'a Field, &'a config::Type, &'a str), FieldDefinition, String>
{
TryFold::<(&ConfigModule, &Field, &config::Type, &str), FieldDefinition, String>::new(
|(config, field, _, _), mut b_field| {
let Some(interface_types) = config.interfaces_types_map().get(field.type_of.name())
else {
return Valid::succeed(b_field);
};

compile_interface_resolver(
config,
field.type_of.name(),
interface_types.iter().cloned().collect(),
)
.map(|discriminator| {
b_field.resolver = Some(
b_field
.resolver
.unwrap_or(IR::ContextPath(vec![b_field.name.clone()])),
);
b_field.map_expr(move |expr| IR::Discriminate(discriminator, expr.into()));
b_field
})
},
)
}
1 change: 1 addition & 0 deletions src/core/blueprint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod directive;
mod dynamic_value;
mod from_config;
mod index;
mod interface_resolver;
mod into_document;
mod into_schema;
mod links;
Expand Down
110 changes: 103 additions & 7 deletions src/core/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,17 +511,63 @@ impl Config {
types
}

/// Returns a list of all the types that are used as interface
pub fn interface_types(&self) -> HashSet<String> {
let mut types = HashSet::new();
pub fn interfaces_types_map(&self) -> BTreeMap<String, BTreeSet<String>> {
let mut interfaces_types: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();

for (type_name, type_definition) in self.types.iter() {
for implement_name in type_definition.implements.clone() {
interfaces_types
.entry(implement_name)
.or_default()
.insert(type_name.clone());
}
}

for ty in self.types.values() {
for interface in ty.implements.iter() {
types.insert(interface.clone());
fn recursive_interface_type_merging(
types_set: &BTreeSet<String>,
interfaces_types: &BTreeMap<String, BTreeSet<String>>,
temp_interface_types: &mut BTreeMap<String, BTreeSet<String>>,
) -> BTreeSet<String> {
let mut types_set_local = BTreeSet::new();

for type_name in types_set.iter() {
match (
interfaces_types.get(type_name),
temp_interface_types.get(type_name),
) {
(Some(types_set_inner), None) => {
let types_set_inner = recursive_interface_type_merging(
types_set_inner,
interfaces_types,
temp_interface_types,
);
temp_interface_types.insert(type_name.to_string(), types_set_inner.clone());
types_set_local = types_set_local.merge_right(types_set_inner);
}
(Some(_), Some(types_set_inner)) => {
types_set_local = types_set_local.merge_right(types_set_inner.clone());
}
_ => {
types_set_local.insert(type_name.to_string());
}
}
}

types_set_local
}

types
let mut interfaces_types_map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
let mut temp_interface_types: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
for (interface_name, types_set) in interfaces_types.iter() {
let types_set = recursive_interface_type_merging(
types_set,
&interfaces_types,
&mut temp_interface_types,
);
interfaces_types_map.insert(interface_name.clone(), types_set);
}

interfaces_types_map
}

/// Returns a list of all the arguments in the configuration
Expand Down Expand Up @@ -764,4 +810,54 @@ mod tests {
.collect();
assert_eq!(union_types, expected_union_types);
}

#[test]
fn test_interfaces_types_map() {
let sdl = std::fs::read_to_string(tailcall_fixtures::configs::INTERFACE_CONFIG).unwrap();
let config = Config::from_sdl(&sdl).to_result().unwrap();
let interfaces_types_map = config.interfaces_types_map();

let mut expected_union_types = BTreeMap::new();

{
let mut set = BTreeSet::new();
set.insert("E".to_string());
set.insert("F".to_string());
expected_union_types.insert("T0".to_string(), set);
}

{
let mut set = BTreeSet::new();
set.insert("A".to_string());
set.insert("E".to_string());
set.insert("B".to_string());
set.insert("C".to_string());
set.insert("D".to_string());
expected_union_types.insert("T1".to_string(), set);
}

{
let mut set = BTreeSet::new();
set.insert("B".to_string());
set.insert("E".to_string());
set.insert("D".to_string());
expected_union_types.insert("T2".to_string(), set);
}

{
let mut set = BTreeSet::new();
set.insert("C".to_string());
set.insert("E".to_string());
set.insert("D".to_string());
expected_union_types.insert("T3".to_string(), set);
}

{
let mut set = BTreeSet::new();
set.insert("D".to_string());
expected_union_types.insert("T4".to_string(), set);
}

assert_eq!(interfaces_types_map, expected_union_types);
}
}
16 changes: 8 additions & 8 deletions src/core/config/config_module.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::ops::Deref;

use jsonwebtoken::jwk::JwkSet;
Expand Down Expand Up @@ -30,20 +30,20 @@ struct Cache {
config: Config,
input_types: HashSet<String>,
output_types: HashSet<String>,
interface_types: HashSet<String>,
interfaces_types_map: BTreeMap<String, BTreeSet<String>>,
}

impl From<Config> for Cache {
fn from(value: Config) -> Self {
let input_types = value.input_types();
let output_types = value.output_types();
let interface_types = value.interface_types();
let interfaces_types_map = value.interfaces_types_map();

Cache {
config: value,
input_types: input_types.clone(),
output_types: output_types.clone(),
interface_types: interface_types.clone(),
input_types,
output_types,
interfaces_types_map,
}
}
}
Expand Down Expand Up @@ -79,8 +79,8 @@ impl ConfigModule {
&self.cache.output_types
}

pub fn interface_types(&self) -> &HashSet<String> {
&self.cache.interface_types
pub fn interfaces_types_map(&self) -> &BTreeMap<String, BTreeSet<String>> {
&self.cache.interfaces_types_map
}

pub fn transform<T: Transform<Value = Config>>(self, transformer: T) -> Valid<Self, T::Error> {
Expand Down
6 changes: 3 additions & 3 deletions src/core/config/config_module/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ impl Invariant for Cache {
let is_self_input = self.input_types.contains(&type_name);
let is_other_input = other.input_types.contains(&type_name);
let is_self_output = self.output_types.contains(&type_name)
|| self.interface_types.contains(&type_name);
|| self.interfaces_types_map.contains_key(&type_name);
let is_other_output = other.output_types.contains(&type_name)
|| other.interface_types.contains(&type_name);
|| other.interfaces_types_map.contains_key(&type_name);

match (
is_self_input,
Expand Down Expand Up @@ -279,7 +279,7 @@ impl Invariant for Cache {
config,
input_types: self.input_types.merge_right(other.input_types),
output_types: self.output_types.merge_right(other.output_types),
interface_types: self.interface_types.merge_right(other.interface_types),
interfaces_types_map: self.interfaces_types_map.merge_right(other.interfaces_types_map),
}
})
}
Expand Down
1 change: 1 addition & 0 deletions src/core/config/from_document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ fn to_types(
fn to_scalar_type() -> config::Type {
config::Type { ..Default::default() }
}

fn to_union_types(
type_definitions: &[&Positioned<TypeDefinition>],
) -> Valid<BTreeMap<String, Union>, String> {
Expand Down
4 changes: 2 additions & 2 deletions src/core/config/into_document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ fn config_document(config: &Config) -> ServiceDocument {
.map(|name| pos(Name::new(name))),
};
definitions.push(TypeSystemDefinition::Schema(pos(schema_definition)));
let interface_types = config.interface_types();
let interface_types = config.interfaces_types_map();
let input_types = config.input_types();
for (type_name, type_def) in config.types.iter() {
let kind = if interface_types.contains(type_name) {
let kind = if interface_types.contains_key(type_name) {
TypeKind::Interface(InterfaceType {
implements: type_def
.implements
Expand Down
2 changes: 1 addition & 1 deletion src/core/config/transformer/merge_types/mergeable_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl MergeableTypes {
input_types: config.input_types(),
union_types: config.union_types(),
output_types: config.output_types(),
interface_types: config.interface_types(),
interface_types: config.interfaces_types_map().keys().cloned().collect(),
threshold,
}
}
Expand Down
36 changes: 25 additions & 11 deletions src/core/ir/eval.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::future::Future;
use std::ops::Deref;

Expand Down Expand Up @@ -69,18 +70,32 @@ impl IR {
}
}
IR::Map(Map { input, map }) => {
let value = input.eval(ctx).await?;
if let ConstValue::String(key) = value {
if let Some(value) = map.get(&key) {
Ok(ConstValue::String(value.to_owned()))
} else {
Err(Error::ExprEval(format!("Can't find mapped key: {}.", key)))
fn recursive_map_enum(
val: Result<ConstValue, Error>,
map: &HashMap<String, String>,
) -> Result<ConstValue, Error> {
match val? {
ConstValue::Null => Ok(ConstValue::Null),
ConstValue::String(key) => {
if let Some(value) = map.get(&key) {
Ok(ConstValue::String(value.to_owned()))
} else {
Err(Error::ExprEval(format!("Can't find mapped key: {}.", key)))
}
}
ConstValue::List(vec) => {
let vec = vec
.into_iter()
.map(|value| recursive_map_enum(Ok(value), map))
.collect::<Result<Vec<_>, _>>()?;
Ok(ConstValue::List(vec))
}
_ => Err(Error::ExprEval(
"Mapped key must be either string or array value.".to_owned(),
)),
}
} else {
Err(Error::ExprEval(
"Mapped key must be string value.".to_owned(),
))
}
recursive_map_enum(input.eval(ctx).await, map)
}
IR::Pipe(first, second) => {
let args = first.eval(&mut ctx.clone()).await?;
Expand Down Expand Up @@ -115,7 +130,6 @@ impl IR {
))?;

let mut tasks = Vec::with_capacity(representations.len());

for repr in representations {
// TODO: combine errors, instead of fail fast?
let type_name = repr.get_type_name().ok_or(Error::Entity(
Expand Down
Loading

1 comment on commit a48cf84

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running 30s test @ http://localhost:8000/graphql

4 threads and 100 connections

Thread Stats Avg Stdev Max +/- Stdev
Latency 7.36ms 3.21ms 41.73ms 72.80%
Req/Sec 3.46k 393.29 3.89k 94.25%

412554 requests in 30.01s, 797.27MB read

Requests/sec: 13746.07

Transfer/sec: 26.56MB

Please sign in to comment.