diff --git a/Cargo.lock b/Cargo.lock
index 60ae7da967..cbcac0edf3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5147,6 +5147,7 @@ dependencies = [
"serde_yaml",
"stripmargin",
"strum_macros 0.26.2",
+ "tailcall-macros",
"tailcall-prettier",
"temp-env",
"tempfile",
@@ -5217,6 +5218,15 @@ dependencies = [
"worker",
]
+[[package]]
+name = "tailcall-macros"
+version = "0.1.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.59",
+]
+
[[package]]
name = "tailcall-prettier"
version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 76211a3034..b7fd57f72c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -144,6 +144,7 @@ async-graphql = { workspace = true, features = [
dotenvy = "0.15"
convert_case = "0.6.0"
rand = "0.8.5"
+tailcall-macros = { path = "tailcall-macros" }
tonic-types = "0.11.0"
@@ -197,6 +198,7 @@ members = [
"tailcall-autogen",
"tailcall-aws-lambda",
"tailcall-cloudflare",
+ "tailcall-macros",
"tailcall-prettier",
"tailcall-query-plan",
]
diff --git a/src/config/config.rs b/src/config/config.rs
index 20af7f2887..d294bac5b5 100644
--- a/src/config/config.rs
+++ b/src/config/config.rs
@@ -16,11 +16,21 @@ use crate::directive::DirectiveCodec;
use crate::http::Method;
use crate::is_default;
use crate::json::JsonSchema;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
use crate::valid::{Valid, Validator};
#[derive(
- Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema,
+ Serialize,
+ Deserialize,
+ Clone,
+ Debug,
+ Default,
+ Setters,
+ PartialEq,
+ Eq,
+ schemars::JsonSchema,
+ MergeRight,
)]
#[serde(rename_all = "camelCase")]
pub struct Config {
@@ -154,28 +164,12 @@ impl Config {
}
}
-impl MergeRight for Config {
- fn merge_right(self, other: Self) -> Self {
- let server = self.server.merge_right(other.server);
- let types = merge_types(self.types, other.types);
- let unions = merge_unions(self.unions, other.unions);
- let schema = self.schema.merge_right(other.schema);
- let upstream = self.upstream.merge_right(other.upstream);
- let links = merge_links(self.links, other.links);
- let telemetry = self.telemetry.merge_right(other.telemetry);
-
- Self { server, upstream, types, schema, unions, links, telemetry }
- }
-}
-
-fn merge_links(self_links: Vec, other_links: Vec) -> Vec {
- self_links.merge_right(other_links)
-}
-
///
/// Represents a GraphQL type.
/// A type can be an object, interface, enum or scalar.
-#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(
+ Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight,
+)]
pub struct Type {
///
/// A map of field name and its definition.
@@ -229,22 +223,9 @@ impl Type {
}
}
-impl MergeRight for Type {
- fn merge_right(mut self, other: Self) -> Self {
- let fields = self.fields.merge_right(other.fields);
- self.implements = self.implements.merge_right(other.implements);
- if let Some(ref variants) = self.variants {
- if let Some(ref other) = other.variants {
- self.variants = Some(variants.union(other).cloned().collect());
- }
- } else {
- self.variants = other.variants;
- }
- Self { fields, ..self }
- }
-}
-
-#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema)]
+#[derive(
+ Clone, Debug, Default, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema, MergeRight,
+)]
#[serde(deny_unknown_fields)]
/// Used to represent an identifier for a type. Typically used via only by the
/// configuration generators to provide additional information about the type.
@@ -253,7 +234,7 @@ pub struct Tag {
pub id: String,
}
-#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema)]
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema, MergeRight)]
/// The @cache operator enables caching for the query, field or type it is
/// applied to.
#[serde(rename_all = "camelCase")]
@@ -264,38 +245,22 @@ pub struct Cache {
pub max_age: NonZeroU64,
}
-#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Default, schemars::JsonSchema)]
+#[derive(
+ Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Default, schemars::JsonSchema, MergeRight,
+)]
pub struct Protected {}
-fn merge_types(
- mut self_types: BTreeMap,
- other_types: BTreeMap,
-) -> BTreeMap {
- for (name, mut other_type) in other_types {
- if let Some(self_type) = self_types.remove(&name) {
- other_type = self_type.merge_right(other_type);
- }
-
- self_types.insert(name, other_type);
- }
- self_types
-}
-
-fn merge_unions(
- mut self_unions: BTreeMap,
- other_unions: BTreeMap,
-) -> BTreeMap {
- for (name, mut other_union) in other_unions {
- if let Some(self_union) = self_unions.remove(&name) {
- other_union = self_union.merge_right(other_union);
- }
- self_unions.insert(name, other_union);
- }
- self_unions
-}
-
#[derive(
- Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema,
+ Serialize,
+ Deserialize,
+ Clone,
+ Debug,
+ Default,
+ Setters,
+ PartialEq,
+ Eq,
+ schemars::JsonSchema,
+ MergeRight,
)]
#[setters(strip_option)]
pub struct RootSchema {
@@ -306,17 +271,6 @@ pub struct RootSchema {
pub subscription: Option,
}
-impl MergeRight for RootSchema {
- // TODO: add unit-tests
- fn merge_right(self, other: Self) -> Self {
- Self {
- query: self.query.merge_right(other.query),
- mutation: self.mutation.merge_right(other.mutation),
- subscription: self.subscription.merge_right(other.subscription),
- }
- }
-}
-
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
/// Used to omit a field from public consumption.
@@ -409,6 +363,13 @@ pub struct Field {
pub protected: Option,
}
+// It's a terminal implementation of MergeRight
+impl MergeRight for Field {
+ fn merge_right(self, other: Self) -> Self {
+ other
+ }
+}
+
impl Field {
pub fn has_resolver(&self) -> bool {
self.http.is_some()
@@ -522,19 +483,12 @@ pub struct Arg {
pub default_value: Option,
}
-#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema, MergeRight)]
pub struct Union {
pub types: BTreeSet,
pub doc: Option,
}
-impl MergeRight for Union {
- fn merge_right(mut self, other: Self) -> Self {
- self.types = self.types.merge_right(other.types);
- self
- }
-}
-
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
/// The @http operator indicates that a field or node is backed by a REST API.
diff --git a/src/config/config_module.rs b/src/config/config_module.rs
index 9c043a5f02..266c687b2f 100644
--- a/src/config/config_module.rs
+++ b/src/config/config_module.rs
@@ -9,13 +9,14 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use crate::blueprint::GrpcMethod;
use crate::config::Config;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
use crate::rest::{EndpointSet, Unchecked};
use crate::scalar;
/// A wrapper on top of Config that contains all the resolved extensions and
/// computed values.
-#[derive(Clone, Debug, Default, Setters)]
+#[derive(Clone, Debug, Default, Setters, MergeRight)]
pub struct ConfigModule {
pub config: Config,
pub extensions: Extensions,
@@ -39,7 +40,7 @@ impl Deref for Content {
/// Extensions are meta-information required before we can generate the
/// blueprint. Typically, this information cannot be inferred without performing
/// an IO operation, i.e., reading a file, making an HTTP call, etc.
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug, Default, MergeRight)]
pub struct Extensions {
/// Contains the file descriptor sets resolved from the links
pub grpc_file_descriptors: Vec>,
@@ -79,35 +80,6 @@ impl Extensions {
}
}
-impl MergeRight for Extensions {
- fn merge_right(mut self, mut other: Self) -> Self {
- self.grpc_file_descriptors = self
- .grpc_file_descriptors
- .merge_right(other.grpc_file_descriptors);
- self.script = self.script.merge_right(other.script.take());
- self.cert = self.cert.merge_right(other.cert);
- self.keys = if !other.keys.is_empty() {
- other.keys
- } else {
- self.keys
- };
- self.endpoint_set = self.endpoint_set.merge_right(other.endpoint_set);
- self.htpasswd = self.htpasswd.merge_right(other.htpasswd);
- self.jwks = self.jwks.merge_right(other.jwks);
- self
- }
-}
-
-impl MergeRight for ConfigModule {
- fn merge_right(mut self, other: Self) -> Self {
- self.config = self.config.merge_right(other.config);
- self.extensions = self.extensions.merge_right(other.extensions);
- self.input_types = self.input_types.merge_right(other.input_types);
- self.output_types = self.output_types.merge_right(other.output_types);
- self
- }
-}
-
impl Deref for ConfigModule {
type Target = Config;
fn deref(&self) -> &Self::Target {
diff --git a/src/config/cors.rs b/src/config/cors.rs
index 3934a719f6..e9a81af341 100644
--- a/src/config/cors.rs
+++ b/src/config/cors.rs
@@ -3,9 +3,13 @@ use serde::{Deserialize, Serialize};
use crate::http::Method;
use crate::is_default;
+use crate::macros::MergeRight;
+use crate::merge_right::MergeRight;
/// Type to configure Cross-Origin Resource Sharing (CORS) for a server.
-#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(
+ Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight,
+)]
#[serde(rename_all = "camelCase")]
pub struct Cors {
/// Indicates whether the server allows credentials (e.g., cookies,
diff --git a/src/config/headers.rs b/src/config/headers.rs
index 0dd82f5cd1..71930dbb9b 100644
--- a/src/config/headers.rs
+++ b/src/config/headers.rs
@@ -5,8 +5,12 @@ use serde::{Deserialize, Serialize};
use crate::config::cors::Cors;
use crate::config::KeyValue;
use crate::is_default;
+use crate::macros::MergeRight;
+use crate::merge_right::MergeRight;
-#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(
+ Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight,
+)]
#[serde(rename_all = "camelCase")]
pub struct Headers {
#[serde(default, skip_serializing_if = "is_default")]
@@ -47,21 +51,3 @@ impl Headers {
self.cors.clone()
}
}
-
-pub fn merge_headers(current: Option, other: Option) -> Option {
- let mut headers = current.clone();
-
- if let Some(other_headers) = other {
- if let Some(mut self_headers) = current.clone() {
- self_headers.cache_control = other_headers.cache_control.or(self_headers.cache_control);
- self_headers.custom.extend(other_headers.custom);
- self_headers.cors = other_headers.cors.or(self_headers.cors);
-
- headers = Some(self_headers);
- } else {
- headers = Some(other_headers);
- }
- }
-
- headers
-}
diff --git a/src/config/server.rs b/src/config/server.rs
index 797e169425..f83e88f88a 100644
--- a/src/config/server.rs
+++ b/src/config/server.rs
@@ -2,13 +2,16 @@ use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
-use super::{merge_headers, merge_key_value_vecs};
+use super::merge_key_value_vecs;
use crate::config::headers::Headers;
use crate::config::KeyValue;
use crate::is_default;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
-#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(
+ Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight,
+)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
/// The `@server` directive, when applied at the schema level, offers a
@@ -82,6 +85,7 @@ pub struct Server {
pub showcase: Option,
#[serde(default, skip_serializing_if = "is_default")]
+ #[merge_right(merge_right_fn = "merge_right_vars")]
/// This configuration defines local variables for server operations. Useful
/// for storing constant configurations, secrets, or shared information.
pub vars: Vec,
@@ -97,31 +101,35 @@ pub struct Server {
pub workers: Option,
}
-#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)]
+fn merge_right_vars(mut left: Vec, right: Vec) -> Vec {
+ left = right.iter().fold(left.to_vec(), |mut acc, kv| {
+ let position = acc.iter().position(|x| x.key == kv.key);
+ if let Some(pos) = position {
+ acc[pos] = kv.clone();
+ } else {
+ acc.push(kv.clone());
+ };
+ acc
+ });
+ left = merge_key_value_vecs(&left, &right);
+ left
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema, MergeRight)]
#[serde(rename_all = "camelCase")]
pub struct ScriptOptions {
pub timeout: Option,
}
-impl MergeRight for ScriptOptions {
- fn merge_right(self, other: Self) -> Self {
- ScriptOptions { timeout: self.timeout.merge_right(other.timeout) }
- }
-}
-
-#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Default, schemars::JsonSchema)]
+#[derive(
+ Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Default, schemars::JsonSchema, MergeRight,
+)]
pub enum HttpVersion {
#[default]
HTTP1,
HTTP2,
}
-impl MergeRight for HttpVersion {
- fn merge_right(self, other: Self) -> Self {
- other
- }
-}
-
impl Server {
pub fn enable_apollo_tracing(&self) -> bool {
self.apollo_tracing.unwrap_or(false)
@@ -208,41 +216,6 @@ impl Server {
}
}
-impl MergeRight for Server {
- fn merge_right(mut self, other: Self) -> Self {
- self.apollo_tracing = self.apollo_tracing.merge_right(other.apollo_tracing);
- self.headers = merge_headers(self.headers, other.headers);
- self.graphiql = self.graphiql.merge_right(other.graphiql);
- self.introspection = self.introspection.merge_right(other.introspection);
- self.query_validation = self.query_validation.merge_right(other.query_validation);
- self.response_validation = self
- .response_validation
- .merge_right(other.response_validation);
- self.batch_requests = self.batch_requests.merge_right(other.batch_requests);
- self.global_response_timeout = self
- .global_response_timeout
- .merge_right(other.global_response_timeout);
- self.showcase = self.showcase.merge_right(other.showcase);
- self.workers = self.workers.merge_right(other.workers);
- self.port = self.port.merge_right(other.port);
- self.hostname = self.hostname.merge_right(other.hostname);
- self.vars = other.vars.iter().fold(self.vars.to_vec(), |mut acc, kv| {
- let position = acc.iter().position(|x| x.key == kv.key);
- if let Some(pos) = position {
- acc[pos] = kv.clone();
- } else {
- acc.push(kv.clone());
- };
- acc
- });
- self.vars = merge_key_value_vecs(&self.vars, &other.vars);
- self.version = self.version.merge_right(other.version);
- self.pipeline_flush = self.pipeline_flush.merge_right(other.pipeline_flush);
- self.script = self.script.merge_right(other.script);
- self
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/src/config/telemetry.rs b/src/config/telemetry.rs
index 8be111bd7e..efaf01f12e 100644
--- a/src/config/telemetry.rs
+++ b/src/config/telemetry.rs
@@ -5,6 +5,7 @@ use super::KeyValue;
use crate::config::{Apollo, ConfigReaderContext};
use crate::helpers::headers::to_mustache_headers;
use crate::is_default;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
use crate::mustache::Mustache;
use crate::valid::Validator;
@@ -18,7 +19,7 @@ mod defaults {
}
/// Output the opentelemetry data to the stdout. Mostly used for debug purposes
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)]
#[serde(rename_all = "camelCase")]
pub struct StdoutExporter {
/// Output to stdout in pretty human-readable format
@@ -26,14 +27,8 @@ pub struct StdoutExporter {
pub pretty: bool,
}
-impl MergeRight for StdoutExporter {
- fn merge_right(self, other: Self) -> Self {
- Self { pretty: other.pretty || self.pretty }
- }
-}
-
/// Output the opentelemetry data to otlp collector
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)]
#[serde(rename_all = "camelCase")]
pub struct OtlpExporter {
pub url: String,
@@ -41,15 +36,6 @@ pub struct OtlpExporter {
pub headers: Vec,
}
-impl MergeRight for OtlpExporter {
- fn merge_right(self, other: Self) -> Self {
- let mut headers = self.headers;
- headers.extend(other.headers.iter().cloned());
-
- Self { url: other.url, headers }
- }
-}
-
/// Output format for prometheus data
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)]
#[serde(rename_all = "camelCase")]
@@ -72,13 +58,7 @@ pub struct PrometheusExporter {
pub format: PrometheusFormat,
}
-impl PrometheusExporter {
- fn merge_right(&self, other: Self) -> Self {
- other
- }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)]
#[serde(rename_all = "camelCase")]
pub enum TelemetryExporter {
Stdout(StdoutExporter),
@@ -87,23 +67,6 @@ pub enum TelemetryExporter {
Apollo(Apollo),
}
-impl MergeRight for TelemetryExporter {
- fn merge_right(self, other: Self) -> Self {
- match (self, other) {
- (TelemetryExporter::Stdout(left), TelemetryExporter::Stdout(right)) => {
- TelemetryExporter::Stdout(left.merge_right(right))
- }
- (TelemetryExporter::Otlp(left), TelemetryExporter::Otlp(right)) => {
- TelemetryExporter::Otlp(left.merge_right(right))
- }
- (TelemetryExporter::Prometheus(left), TelemetryExporter::Prometheus(right)) => {
- TelemetryExporter::Prometheus(left.merge_right(right))
- }
- (_, other) => other,
- }
- }
-}
-
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
@@ -225,10 +188,7 @@ mod tests {
Telemetry {
export: Some(TelemetryExporter::Otlp(OtlpExporter {
url: "test-url-2".to_owned(),
- headers: vec![
- KeyValue { key: "header_a".to_owned(), value: "a".to_owned() },
- KeyValue { key: "header_b".to_owned(), value: "b".to_owned() }
- ]
+ headers: vec![KeyValue { key: "header_b".to_owned(), value: "b".to_owned() }]
})),
request_headers: vec!["Api-Key-A".to_string(), "Api-Key-B".to_string(),]
}
diff --git a/src/config/upstream.rs b/src/config/upstream.rs
index 47e8d9b367..a4d3b9b58b 100644
--- a/src/config/upstream.rs
+++ b/src/config/upstream.rs
@@ -4,9 +4,12 @@ use derive_setters::Setters;
use serde::{Deserialize, Serialize};
use crate::is_default;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
-#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, schemars::JsonSchema)]
+#[derive(
+ Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, schemars::JsonSchema, MergeRight,
+)]
#[serde(rename_all = "camelCase", default)]
pub struct Batch {
pub delay: usize,
@@ -20,19 +23,22 @@ impl Default for Batch {
}
}
-#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, schemars::JsonSchema)]
+#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, schemars::JsonSchema, MergeRight)]
pub struct Proxy {
pub url: String,
}
-impl MergeRight for Proxy {
- fn merge_right(self, other: Self) -> Self {
- other
- }
-}
-
#[derive(
- Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, Default, schemars::JsonSchema,
+ Serialize,
+ Deserialize,
+ PartialEq,
+ Eq,
+ Clone,
+ Debug,
+ Setters,
+ Default,
+ schemars::JsonSchema,
+ MergeRight,
)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase", default)]
@@ -174,45 +180,6 @@ impl Upstream {
}
}
-impl MergeRight for Upstream {
- // TODO: add unit tests for merge
- fn merge_right(mut self, other: Self) -> Self {
- self.allowed_headers = self.allowed_headers.merge_right(other.allowed_headers);
- self.base_url = self.base_url.merge_right(other.base_url);
- self.connect_timeout = self.connect_timeout.merge_right(other.connect_timeout);
- self.http_cache = self.http_cache.merge_right(other.http_cache);
- self.keep_alive_interval = self
- .keep_alive_interval
- .merge_right(other.keep_alive_interval);
- self.keep_alive_timeout = self
- .keep_alive_timeout
- .merge_right(other.keep_alive_timeout);
- self.keep_alive_while_idle = self
- .keep_alive_while_idle
- .merge_right(other.keep_alive_while_idle);
- self.pool_idle_timeout = self.pool_idle_timeout.merge_right(other.pool_idle_timeout);
- self.pool_max_idle_per_host = self
- .pool_max_idle_per_host
- .merge_right(other.pool_max_idle_per_host);
- self.proxy = self.proxy.merge_right(other.proxy);
- self.tcp_keep_alive = self.tcp_keep_alive.merge_right(other.tcp_keep_alive);
- self.timeout = self.timeout.merge_right(other.timeout);
- self.user_agent = self.user_agent.merge_right(other.user_agent);
-
- if let Some(other) = other.batch {
- let mut batch = self.batch.unwrap_or_default();
- batch.max_size = other.max_size;
- batch.delay = other.delay;
- batch.headers = batch.headers.merge_right(other.headers);
-
- self.batch = Some(batch);
- }
-
- self.http2_only = self.http2_only.merge_right(other.http2_only);
- self
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/src/lib.rs b/src/lib.rs
index 5242bd7ade..6ab74992b2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -38,12 +38,14 @@ pub mod tracing;
pub mod try_fold;
pub mod valid;
+// Re-export everything from `tailcall_macros` as `macros`
use std::borrow::Cow;
use std::hash::Hash;
use std::num::NonZeroU64;
use async_graphql_value::ConstValue;
use http::Response;
+pub use tailcall_macros as macros;
pub trait EnvIO: Send + Sync + 'static {
fn get(&self, key: &str) -> Option>;
diff --git a/src/merge_right.rs b/src/merge_right.rs
index e9e7f9c5ca..43df242828 100644
--- a/src/merge_right.rs
+++ b/src/merge_right.rs
@@ -1,4 +1,5 @@
use std::collections::{BTreeMap, BTreeSet, HashSet};
+use std::sync::Arc;
pub trait MergeRight {
fn merge_right(self, other: Self) -> Self;
@@ -15,6 +16,14 @@ impl MergeRight for Option {
}
}
+impl MergeRight for Arc {
+ fn merge_right(self, other: Self) -> Self {
+ let l = Arc::into_inner(self);
+ let r = Arc::into_inner(other);
+ Arc::new(l.merge_right(r).unwrap_or_default())
+ }
+}
+
impl MergeRight for Vec {
fn merge_right(mut self, other: Self) -> Self {
self.extend(other);
@@ -25,10 +34,16 @@ impl MergeRight for Vec {
impl MergeRight for BTreeMap
where
K: Ord,
- V: Clone,
+ V: Clone + MergeRight,
{
fn merge_right(mut self, other: Self) -> Self {
- self.extend(other);
+ for (other_name, mut other_value) in other {
+ if let Some(self_value) = self.remove(&other_name) {
+ other_value = self_value.merge_right(other_value);
+ }
+
+ self.insert(other_name, other_value);
+ }
self
}
}
diff --git a/src/primitive.rs b/src/primitive.rs
index 2c205009f6..8c50cb0e1a 100644
--- a/src/primitive.rs
+++ b/src/primitive.rs
@@ -1,34 +1,26 @@
+use std::marker::PhantomData;
+use std::num::NonZeroU64;
+
use crate::merge_right::MergeRight;
pub trait Primitive {}
-impl Primitive for u64 {}
-
-impl Primitive for u32 {}
-
-impl Primitive for u16 {}
-
-impl Primitive for u8 {}
-
-impl Primitive for usize {}
-
-impl Primitive for i64 {}
-
-impl Primitive for i32 {}
-
-impl Primitive for i16 {}
-
-impl Primitive for i8 {}
-
-impl Primitive for f64 {}
-
-impl Primitive for f32 {}
-
impl Primitive for bool {}
-
impl Primitive for char {}
-
+impl Primitive for f32 {}
+impl Primitive for f64 {}
+impl Primitive for i16 {}
+impl Primitive for i32 {}
+impl Primitive for i64 {}
+impl Primitive for i8 {}
+impl Primitive for NonZeroU64 {}
impl Primitive for String {}
+impl Primitive for u16 {}
+impl Primitive for u32 {}
+impl Primitive for u64 {}
+impl Primitive for u8 {}
+impl Primitive for usize {}
+impl Primitive for PhantomData {}
impl MergeRight for A {
fn merge_right(self, other: Self) -> Self {
diff --git a/src/rest/endpoint_set.rs b/src/rest/endpoint_set.rs
index 1805fe836e..3cd2597af6 100644
--- a/src/rest/endpoint_set.rs
+++ b/src/rest/endpoint_set.rs
@@ -5,13 +5,14 @@ use super::partial_request::PartialRequest;
use super::Request;
use crate::blueprint::Blueprint;
use crate::http::RequestContext;
+use crate::macros::MergeRight;
use crate::merge_right::MergeRight;
use crate::rest::operation::OperationQuery;
use crate::runtime::TargetRuntime;
use crate::valid::Validator;
/// Collection of endpoints
-#[derive(Default, Clone, Debug)]
+#[derive(Default, Clone, Debug, MergeRight)]
pub struct EndpointSet {
endpoints: Vec,
marker: std::marker::PhantomData,
@@ -81,13 +82,6 @@ impl EndpointSet {
}
}
-impl MergeRight for EndpointSet {
- fn merge_right(mut self, other: Self) -> Self {
- self.extend(other);
- self
- }
-}
-
impl EndpointSet {
pub fn matches(&self, request: &Request) -> Option {
self.endpoints.iter().find_map(|e| e.matches(request))
diff --git a/tailcall-macros/Cargo.toml b/tailcall-macros/Cargo.toml
new file mode 100644
index 0000000000..8b695509b5
--- /dev/null
+++ b/tailcall-macros/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "tailcall-macros"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = { version = "2.0.58", features = ["derive", "full"] }
+quote = "1.0"
+proc-macro2 = "1.0"
diff --git a/tailcall-macros/src/lib.rs b/tailcall-macros/src/lib.rs
new file mode 100644
index 0000000000..b0fd5ed4d1
--- /dev/null
+++ b/tailcall-macros/src/lib.rs
@@ -0,0 +1,12 @@
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+
+mod merge_right;
+
+use crate::merge_right::expand_merge_right_derive;
+
+#[proc_macro_derive(MergeRight, attributes(merge_right))]
+pub fn merge_right_derive(input: TokenStream) -> TokenStream {
+ expand_merge_right_derive(input)
+}
diff --git a/tailcall-macros/src/merge_right.rs b/tailcall-macros/src/merge_right.rs
new file mode 100644
index 0000000000..dce5c428f5
--- /dev/null
+++ b/tailcall-macros/src/merge_right.rs
@@ -0,0 +1,151 @@
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::spanned::Spanned;
+use syn::{parse_macro_input, Data, DeriveInput, Fields};
+
+const MERGE_RIGHT_FN: &str = "merge_right_fn";
+const MERGE_RIGHT: &str = "merge_right";
+
+#[derive(Default)]
+struct Attrs {
+ merge_right_fn: Option,
+}
+
+fn get_attrs(attrs: &[syn::Attribute]) -> syn::Result {
+ let mut attrs_ret = Attrs::default();
+ for attr in attrs {
+ if attr.path().is_ident(MERGE_RIGHT) {
+ attr.parse_nested_meta(|meta| {
+ if meta.path.is_ident(MERGE_RIGHT_FN) {
+ let p: syn::Expr = meta.value()?.parse()?;
+ let lit =
+ if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = p {
+ let suffix = lit.suffix();
+ if !suffix.is_empty() {
+ return Err(syn::Error::new(
+ lit.span(),
+ format!("unexpected suffix `{}` on string literal", suffix),
+ ));
+ }
+ lit
+ } else {
+ return Err(syn::Error::new(
+ p.span(),
+ format!(
+ "expected merge_right {} attribute to be a string.",
+ MERGE_RIGHT_FN
+ ),
+ ));
+ };
+ let expr_path: syn::ExprPath = lit.parse()?;
+ attrs_ret.merge_right_fn = Some(expr_path);
+ Ok(())
+ } else {
+ Err(syn::Error::new(attr.span(), "Unknown helper attribute."))
+ }
+ })?;
+ }
+ }
+ Ok(attrs_ret)
+}
+
+pub fn expand_merge_right_derive(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input as DeriveInput);
+
+ let name = input.ident.clone();
+ let generics = input.generics.clone();
+ let gen = match input.data {
+ // Implement for structs
+ Data::Struct(data) => {
+ let fields = if let Fields::Named(fields) = data.fields {
+ fields.named
+ } else {
+ // Adjust this match arm to handle other kinds of struct fields (unnamed/tuple
+ // structs, unit structs)
+ unimplemented!()
+ };
+
+ let merge_logic = fields.iter().map(|f| {
+ let attrs = get_attrs(&f.attrs);
+ if let Err(err) = attrs {
+ panic!("{}", err);
+ }
+ let attrs = attrs.unwrap();
+ let name = &f.ident;
+ if let Some(merge_right_fn) = attrs.merge_right_fn {
+ quote! {
+ #name: #merge_right_fn(self.#name, other.#name),
+ }
+ } else {
+ quote! {
+ #name: self.#name.merge_right(other.#name),
+ }
+ }
+ });
+
+ let generics_lt = generics.lt_token;
+ let generics_gt = generics.gt_token;
+ let generics_params = generics.params;
+
+ let generics_del = quote! {
+ #generics_lt #generics_params #generics_gt
+ };
+
+ quote! {
+ impl #generics_del MergeRight for #name #generics_del {
+ fn merge_right(self, other: Self) -> Self {
+ Self {
+ #(#merge_logic)*
+ }
+ }
+ }
+ }
+ }
+ // Implement for enums
+ Data::Enum(_) => quote! {
+ impl MergeRight for #name {
+ fn merge_right(self, other: Self) -> Self {
+ other
+ }
+ }
+ },
+ // Optionally handle or disallow unions
+ Data::Union(_) => {
+ return syn::Error::new_spanned(input, "Union types are not supported by MergeRight")
+ .to_compile_error()
+ .into()
+ }
+ };
+
+ gen.into()
+}
+
+#[cfg(test)]
+mod tests {
+ use syn::{parse_quote, Attribute};
+
+ use super::*;
+
+ #[test]
+ fn test_get_attrs_invalid_type() {
+ let attrs: Vec = vec![parse_quote!(#[merge_right(merge_right_fn = 123)])];
+ let result = get_attrs(&attrs);
+ assert!(
+ result.is_err(),
+ "Expected error with non-string type for `merge_right_fn`"
+ );
+ }
+
+ #[test]
+ fn test_get_attrs_unexpected_suffix() {
+ let attrs: Vec =
+ vec![parse_quote!(#[merge_right(merge_right_fn = "some_fn()")])];
+ let result = get_attrs(&attrs);
+ assert!(
+ result.is_err(),
+ "Expected error with unexpected suffix on string literal"
+ );
+ }
+}