diff --git a/Cargo.lock b/Cargo.lock index 60ae7da967..cbcac0edf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5147,6 +5147,7 @@ dependencies = [ "serde_yaml", "stripmargin", "strum_macros 0.26.2", + "tailcall-macros", "tailcall-prettier", "temp-env", "tempfile", @@ -5217,6 +5218,15 @@ dependencies = [ "worker", ] +[[package]] +name = "tailcall-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "tailcall-prettier" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 76211a3034..b7fd57f72c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,6 +144,7 @@ async-graphql = { workspace = true, features = [ dotenvy = "0.15" convert_case = "0.6.0" rand = "0.8.5" +tailcall-macros = { path = "tailcall-macros" } tonic-types = "0.11.0" @@ -197,6 +198,7 @@ members = [ "tailcall-autogen", "tailcall-aws-lambda", "tailcall-cloudflare", + "tailcall-macros", "tailcall-prettier", "tailcall-query-plan", ] diff --git a/src/config/config.rs b/src/config/config.rs index 20af7f2887..d294bac5b5 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -16,11 +16,21 @@ use crate::directive::DirectiveCodec; use crate::http::Method; use crate::is_default; use crate::json::JsonSchema; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; use crate::valid::{Valid, Validator}; #[derive( - Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema, + Serialize, + Deserialize, + Clone, + Debug, + Default, + Setters, + PartialEq, + Eq, + schemars::JsonSchema, + MergeRight, )] #[serde(rename_all = "camelCase")] pub struct Config { @@ -154,28 +164,12 @@ impl Config { } } -impl MergeRight for Config { - fn merge_right(self, other: Self) -> Self { - let server = self.server.merge_right(other.server); - let types = merge_types(self.types, other.types); - let unions = merge_unions(self.unions, other.unions); - let schema = self.schema.merge_right(other.schema); - let upstream = self.upstream.merge_right(other.upstream); - let links = merge_links(self.links, other.links); - let telemetry = self.telemetry.merge_right(other.telemetry); - - Self { server, upstream, types, schema, unions, links, telemetry } - } -} - -fn merge_links(self_links: Vec, other_links: Vec) -> Vec { - self_links.merge_right(other_links) -} - /// /// Represents a GraphQL type. /// A type can be an object, interface, enum or scalar. -#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight, +)] pub struct Type { /// /// A map of field name and its definition. @@ -229,22 +223,9 @@ impl Type { } } -impl MergeRight for Type { - fn merge_right(mut self, other: Self) -> Self { - let fields = self.fields.merge_right(other.fields); - self.implements = self.implements.merge_right(other.implements); - if let Some(ref variants) = self.variants { - if let Some(ref other) = other.variants { - self.variants = Some(variants.union(other).cloned().collect()); - } - } else { - self.variants = other.variants; - } - Self { fields, ..self } - } -} - -#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema)] +#[derive( + Clone, Debug, Default, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema, MergeRight, +)] #[serde(deny_unknown_fields)] /// Used to represent an identifier for a type. Typically used via only by the /// configuration generators to provide additional information about the type. @@ -253,7 +234,7 @@ pub struct Tag { pub id: String, } -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema)] +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema, MergeRight)] /// The @cache operator enables caching for the query, field or type it is /// applied to. #[serde(rename_all = "camelCase")] @@ -264,38 +245,22 @@ pub struct Cache { pub max_age: NonZeroU64, } -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Default, schemars::JsonSchema)] +#[derive( + Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Default, schemars::JsonSchema, MergeRight, +)] pub struct Protected {} -fn merge_types( - mut self_types: BTreeMap, - other_types: BTreeMap, -) -> BTreeMap { - for (name, mut other_type) in other_types { - if let Some(self_type) = self_types.remove(&name) { - other_type = self_type.merge_right(other_type); - } - - self_types.insert(name, other_type); - } - self_types -} - -fn merge_unions( - mut self_unions: BTreeMap, - other_unions: BTreeMap, -) -> BTreeMap { - for (name, mut other_union) in other_unions { - if let Some(self_union) = self_unions.remove(&name) { - other_union = self_union.merge_right(other_union); - } - self_unions.insert(name, other_union); - } - self_unions -} - #[derive( - Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema, + Serialize, + Deserialize, + Clone, + Debug, + Default, + Setters, + PartialEq, + Eq, + schemars::JsonSchema, + MergeRight, )] #[setters(strip_option)] pub struct RootSchema { @@ -306,17 +271,6 @@ pub struct RootSchema { pub subscription: Option, } -impl MergeRight for RootSchema { - // TODO: add unit-tests - fn merge_right(self, other: Self) -> Self { - Self { - query: self.query.merge_right(other.query), - mutation: self.mutation.merge_right(other.mutation), - subscription: self.subscription.merge_right(other.subscription), - } - } -} - #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] #[serde(deny_unknown_fields)] /// Used to omit a field from public consumption. @@ -409,6 +363,13 @@ pub struct Field { pub protected: Option, } +// It's a terminal implementation of MergeRight +impl MergeRight for Field { + fn merge_right(self, other: Self) -> Self { + other + } +} + impl Field { pub fn has_resolver(&self) -> bool { self.http.is_some() @@ -522,19 +483,12 @@ pub struct Arg { pub default_value: Option, } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema, MergeRight)] pub struct Union { pub types: BTreeSet, pub doc: Option, } -impl MergeRight for Union { - fn merge_right(mut self, other: Self) -> Self { - self.types = self.types.merge_right(other.types); - self - } -} - #[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] #[serde(deny_unknown_fields)] /// The @http operator indicates that a field or node is backed by a REST API. diff --git a/src/config/config_module.rs b/src/config/config_module.rs index 9c043a5f02..266c687b2f 100644 --- a/src/config/config_module.rs +++ b/src/config/config_module.rs @@ -9,13 +9,14 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use crate::blueprint::GrpcMethod; use crate::config::Config; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; use crate::rest::{EndpointSet, Unchecked}; use crate::scalar; /// A wrapper on top of Config that contains all the resolved extensions and /// computed values. -#[derive(Clone, Debug, Default, Setters)] +#[derive(Clone, Debug, Default, Setters, MergeRight)] pub struct ConfigModule { pub config: Config, pub extensions: Extensions, @@ -39,7 +40,7 @@ impl Deref for Content { /// Extensions are meta-information required before we can generate the /// blueprint. Typically, this information cannot be inferred without performing /// an IO operation, i.e., reading a file, making an HTTP call, etc. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, MergeRight)] pub struct Extensions { /// Contains the file descriptor sets resolved from the links pub grpc_file_descriptors: Vec>, @@ -79,35 +80,6 @@ impl Extensions { } } -impl MergeRight for Extensions { - fn merge_right(mut self, mut other: Self) -> Self { - self.grpc_file_descriptors = self - .grpc_file_descriptors - .merge_right(other.grpc_file_descriptors); - self.script = self.script.merge_right(other.script.take()); - self.cert = self.cert.merge_right(other.cert); - self.keys = if !other.keys.is_empty() { - other.keys - } else { - self.keys - }; - self.endpoint_set = self.endpoint_set.merge_right(other.endpoint_set); - self.htpasswd = self.htpasswd.merge_right(other.htpasswd); - self.jwks = self.jwks.merge_right(other.jwks); - self - } -} - -impl MergeRight for ConfigModule { - fn merge_right(mut self, other: Self) -> Self { - self.config = self.config.merge_right(other.config); - self.extensions = self.extensions.merge_right(other.extensions); - self.input_types = self.input_types.merge_right(other.input_types); - self.output_types = self.output_types.merge_right(other.output_types); - self - } -} - impl Deref for ConfigModule { type Target = Config; fn deref(&self) -> &Self::Target { diff --git a/src/config/cors.rs b/src/config/cors.rs index 3934a719f6..e9a81af341 100644 --- a/src/config/cors.rs +++ b/src/config/cors.rs @@ -3,9 +3,13 @@ use serde::{Deserialize, Serialize}; use crate::http::Method; use crate::is_default; +use crate::macros::MergeRight; +use crate::merge_right::MergeRight; /// Type to configure Cross-Origin Resource Sharing (CORS) for a server. -#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight, +)] #[serde(rename_all = "camelCase")] pub struct Cors { /// Indicates whether the server allows credentials (e.g., cookies, diff --git a/src/config/headers.rs b/src/config/headers.rs index 0dd82f5cd1..71930dbb9b 100644 --- a/src/config/headers.rs +++ b/src/config/headers.rs @@ -5,8 +5,12 @@ use serde::{Deserialize, Serialize}; use crate::config::cors::Cors; use crate::config::KeyValue; use crate::is_default; +use crate::macros::MergeRight; +use crate::merge_right::MergeRight; -#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight, +)] #[serde(rename_all = "camelCase")] pub struct Headers { #[serde(default, skip_serializing_if = "is_default")] @@ -47,21 +51,3 @@ impl Headers { self.cors.clone() } } - -pub fn merge_headers(current: Option, other: Option) -> Option { - let mut headers = current.clone(); - - if let Some(other_headers) = other { - if let Some(mut self_headers) = current.clone() { - self_headers.cache_control = other_headers.cache_control.or(self_headers.cache_control); - self_headers.custom.extend(other_headers.custom); - self_headers.cors = other_headers.cors.or(self_headers.cors); - - headers = Some(self_headers); - } else { - headers = Some(other_headers); - } - } - - headers -} diff --git a/src/config/server.rs b/src/config/server.rs index 797e169425..f83e88f88a 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -2,13 +2,16 @@ use std::collections::{BTreeMap, BTreeSet}; use serde::{Deserialize, Serialize}; -use super::{merge_headers, merge_key_value_vecs}; +use super::merge_key_value_vecs; use crate::config::headers::Headers; use crate::config::KeyValue; use crate::is_default; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; -#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema, MergeRight, +)] #[serde(deny_unknown_fields)] #[serde(rename_all = "camelCase")] /// The `@server` directive, when applied at the schema level, offers a @@ -82,6 +85,7 @@ pub struct Server { pub showcase: Option, #[serde(default, skip_serializing_if = "is_default")] + #[merge_right(merge_right_fn = "merge_right_vars")] /// This configuration defines local variables for server operations. Useful /// for storing constant configurations, secrets, or shared information. pub vars: Vec, @@ -97,31 +101,35 @@ pub struct Server { pub workers: Option, } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] +fn merge_right_vars(mut left: Vec, right: Vec) -> Vec { + left = right.iter().fold(left.to_vec(), |mut acc, kv| { + let position = acc.iter().position(|x| x.key == kv.key); + if let Some(pos) = position { + acc[pos] = kv.clone(); + } else { + acc.push(kv.clone()); + }; + acc + }); + left = merge_key_value_vecs(&left, &right); + left +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema, MergeRight)] #[serde(rename_all = "camelCase")] pub struct ScriptOptions { pub timeout: Option, } -impl MergeRight for ScriptOptions { - fn merge_right(self, other: Self) -> Self { - ScriptOptions { timeout: self.timeout.merge_right(other.timeout) } - } -} - -#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Default, schemars::JsonSchema)] +#[derive( + Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Default, schemars::JsonSchema, MergeRight, +)] pub enum HttpVersion { #[default] HTTP1, HTTP2, } -impl MergeRight for HttpVersion { - fn merge_right(self, other: Self) -> Self { - other - } -} - impl Server { pub fn enable_apollo_tracing(&self) -> bool { self.apollo_tracing.unwrap_or(false) @@ -208,41 +216,6 @@ impl Server { } } -impl MergeRight for Server { - fn merge_right(mut self, other: Self) -> Self { - self.apollo_tracing = self.apollo_tracing.merge_right(other.apollo_tracing); - self.headers = merge_headers(self.headers, other.headers); - self.graphiql = self.graphiql.merge_right(other.graphiql); - self.introspection = self.introspection.merge_right(other.introspection); - self.query_validation = self.query_validation.merge_right(other.query_validation); - self.response_validation = self - .response_validation - .merge_right(other.response_validation); - self.batch_requests = self.batch_requests.merge_right(other.batch_requests); - self.global_response_timeout = self - .global_response_timeout - .merge_right(other.global_response_timeout); - self.showcase = self.showcase.merge_right(other.showcase); - self.workers = self.workers.merge_right(other.workers); - self.port = self.port.merge_right(other.port); - self.hostname = self.hostname.merge_right(other.hostname); - self.vars = other.vars.iter().fold(self.vars.to_vec(), |mut acc, kv| { - let position = acc.iter().position(|x| x.key == kv.key); - if let Some(pos) = position { - acc[pos] = kv.clone(); - } else { - acc.push(kv.clone()); - }; - acc - }); - self.vars = merge_key_value_vecs(&self.vars, &other.vars); - self.version = self.version.merge_right(other.version); - self.pipeline_flush = self.pipeline_flush.merge_right(other.pipeline_flush); - self.script = self.script.merge_right(other.script); - self - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/config/telemetry.rs b/src/config/telemetry.rs index 8be111bd7e..efaf01f12e 100644 --- a/src/config/telemetry.rs +++ b/src/config/telemetry.rs @@ -5,6 +5,7 @@ use super::KeyValue; use crate::config::{Apollo, ConfigReaderContext}; use crate::helpers::headers::to_mustache_headers; use crate::is_default; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; use crate::mustache::Mustache; use crate::valid::Validator; @@ -18,7 +19,7 @@ mod defaults { } /// Output the opentelemetry data to the stdout. Mostly used for debug purposes -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)] #[serde(rename_all = "camelCase")] pub struct StdoutExporter { /// Output to stdout in pretty human-readable format @@ -26,14 +27,8 @@ pub struct StdoutExporter { pub pretty: bool, } -impl MergeRight for StdoutExporter { - fn merge_right(self, other: Self) -> Self { - Self { pretty: other.pretty || self.pretty } - } -} - /// Output the opentelemetry data to otlp collector -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)] #[serde(rename_all = "camelCase")] pub struct OtlpExporter { pub url: String, @@ -41,15 +36,6 @@ pub struct OtlpExporter { pub headers: Vec, } -impl MergeRight for OtlpExporter { - fn merge_right(self, other: Self) -> Self { - let mut headers = self.headers; - headers.extend(other.headers.iter().cloned()); - - Self { url: other.url, headers } - } -} - /// Output format for prometheus data #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] #[serde(rename_all = "camelCase")] @@ -72,13 +58,7 @@ pub struct PrometheusExporter { pub format: PrometheusFormat, } -impl PrometheusExporter { - fn merge_right(&self, other: Self) -> Self { - other - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema, MergeRight)] #[serde(rename_all = "camelCase")] pub enum TelemetryExporter { Stdout(StdoutExporter), @@ -87,23 +67,6 @@ pub enum TelemetryExporter { Apollo(Apollo), } -impl MergeRight for TelemetryExporter { - fn merge_right(self, other: Self) -> Self { - match (self, other) { - (TelemetryExporter::Stdout(left), TelemetryExporter::Stdout(right)) => { - TelemetryExporter::Stdout(left.merge_right(right)) - } - (TelemetryExporter::Otlp(left), TelemetryExporter::Otlp(right)) => { - TelemetryExporter::Otlp(left.merge_right(right)) - } - (TelemetryExporter::Prometheus(left), TelemetryExporter::Prometheus(right)) => { - TelemetryExporter::Prometheus(left.merge_right(right)) - } - (_, other) => other, - } - } -} - #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] #[serde(deny_unknown_fields)] #[serde(rename_all = "camelCase")] @@ -225,10 +188,7 @@ mod tests { Telemetry { export: Some(TelemetryExporter::Otlp(OtlpExporter { url: "test-url-2".to_owned(), - headers: vec![ - KeyValue { key: "header_a".to_owned(), value: "a".to_owned() }, - KeyValue { key: "header_b".to_owned(), value: "b".to_owned() } - ] + headers: vec![KeyValue { key: "header_b".to_owned(), value: "b".to_owned() }] })), request_headers: vec!["Api-Key-A".to_string(), "Api-Key-B".to_string(),] } diff --git a/src/config/upstream.rs b/src/config/upstream.rs index 47e8d9b367..a4d3b9b58b 100644 --- a/src/config/upstream.rs +++ b/src/config/upstream.rs @@ -4,9 +4,12 @@ use derive_setters::Setters; use serde::{Deserialize, Serialize}; use crate::is_default; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, schemars::JsonSchema, MergeRight, +)] #[serde(rename_all = "camelCase", default)] pub struct Batch { pub delay: usize, @@ -20,19 +23,22 @@ impl Default for Batch { } } -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, schemars::JsonSchema)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, schemars::JsonSchema, MergeRight)] pub struct Proxy { pub url: String, } -impl MergeRight for Proxy { - fn merge_right(self, other: Self) -> Self { - other - } -} - #[derive( - Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, Default, schemars::JsonSchema, + Serialize, + Deserialize, + PartialEq, + Eq, + Clone, + Debug, + Setters, + Default, + schemars::JsonSchema, + MergeRight, )] #[serde(deny_unknown_fields)] #[serde(rename_all = "camelCase", default)] @@ -174,45 +180,6 @@ impl Upstream { } } -impl MergeRight for Upstream { - // TODO: add unit tests for merge - fn merge_right(mut self, other: Self) -> Self { - self.allowed_headers = self.allowed_headers.merge_right(other.allowed_headers); - self.base_url = self.base_url.merge_right(other.base_url); - self.connect_timeout = self.connect_timeout.merge_right(other.connect_timeout); - self.http_cache = self.http_cache.merge_right(other.http_cache); - self.keep_alive_interval = self - .keep_alive_interval - .merge_right(other.keep_alive_interval); - self.keep_alive_timeout = self - .keep_alive_timeout - .merge_right(other.keep_alive_timeout); - self.keep_alive_while_idle = self - .keep_alive_while_idle - .merge_right(other.keep_alive_while_idle); - self.pool_idle_timeout = self.pool_idle_timeout.merge_right(other.pool_idle_timeout); - self.pool_max_idle_per_host = self - .pool_max_idle_per_host - .merge_right(other.pool_max_idle_per_host); - self.proxy = self.proxy.merge_right(other.proxy); - self.tcp_keep_alive = self.tcp_keep_alive.merge_right(other.tcp_keep_alive); - self.timeout = self.timeout.merge_right(other.timeout); - self.user_agent = self.user_agent.merge_right(other.user_agent); - - if let Some(other) = other.batch { - let mut batch = self.batch.unwrap_or_default(); - batch.max_size = other.max_size; - batch.delay = other.delay; - batch.headers = batch.headers.merge_right(other.headers); - - self.batch = Some(batch); - } - - self.http2_only = self.http2_only.merge_right(other.http2_only); - self - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 5242bd7ade..6ab74992b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,12 +38,14 @@ pub mod tracing; pub mod try_fold; pub mod valid; +// Re-export everything from `tailcall_macros` as `macros` use std::borrow::Cow; use std::hash::Hash; use std::num::NonZeroU64; use async_graphql_value::ConstValue; use http::Response; +pub use tailcall_macros as macros; pub trait EnvIO: Send + Sync + 'static { fn get(&self, key: &str) -> Option>; diff --git a/src/merge_right.rs b/src/merge_right.rs index e9e7f9c5ca..43df242828 100644 --- a/src/merge_right.rs +++ b/src/merge_right.rs @@ -1,4 +1,5 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::sync::Arc; pub trait MergeRight { fn merge_right(self, other: Self) -> Self; @@ -15,6 +16,14 @@ impl MergeRight for Option { } } +impl MergeRight for Arc { + fn merge_right(self, other: Self) -> Self { + let l = Arc::into_inner(self); + let r = Arc::into_inner(other); + Arc::new(l.merge_right(r).unwrap_or_default()) + } +} + impl MergeRight for Vec { fn merge_right(mut self, other: Self) -> Self { self.extend(other); @@ -25,10 +34,16 @@ impl MergeRight for Vec { impl MergeRight for BTreeMap where K: Ord, - V: Clone, + V: Clone + MergeRight, { fn merge_right(mut self, other: Self) -> Self { - self.extend(other); + for (other_name, mut other_value) in other { + if let Some(self_value) = self.remove(&other_name) { + other_value = self_value.merge_right(other_value); + } + + self.insert(other_name, other_value); + } self } } diff --git a/src/primitive.rs b/src/primitive.rs index 2c205009f6..8c50cb0e1a 100644 --- a/src/primitive.rs +++ b/src/primitive.rs @@ -1,34 +1,26 @@ +use std::marker::PhantomData; +use std::num::NonZeroU64; + use crate::merge_right::MergeRight; pub trait Primitive {} -impl Primitive for u64 {} - -impl Primitive for u32 {} - -impl Primitive for u16 {} - -impl Primitive for u8 {} - -impl Primitive for usize {} - -impl Primitive for i64 {} - -impl Primitive for i32 {} - -impl Primitive for i16 {} - -impl Primitive for i8 {} - -impl Primitive for f64 {} - -impl Primitive for f32 {} - impl Primitive for bool {} - impl Primitive for char {} - +impl Primitive for f32 {} +impl Primitive for f64 {} +impl Primitive for i16 {} +impl Primitive for i32 {} +impl Primitive for i64 {} +impl Primitive for i8 {} +impl Primitive for NonZeroU64 {} impl Primitive for String {} +impl Primitive for u16 {} +impl Primitive for u32 {} +impl Primitive for u64 {} +impl Primitive for u8 {} +impl Primitive for usize {} +impl Primitive for PhantomData {} impl MergeRight for A { fn merge_right(self, other: Self) -> Self { diff --git a/src/rest/endpoint_set.rs b/src/rest/endpoint_set.rs index 1805fe836e..3cd2597af6 100644 --- a/src/rest/endpoint_set.rs +++ b/src/rest/endpoint_set.rs @@ -5,13 +5,14 @@ use super::partial_request::PartialRequest; use super::Request; use crate::blueprint::Blueprint; use crate::http::RequestContext; +use crate::macros::MergeRight; use crate::merge_right::MergeRight; use crate::rest::operation::OperationQuery; use crate::runtime::TargetRuntime; use crate::valid::Validator; /// Collection of endpoints -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, MergeRight)] pub struct EndpointSet { endpoints: Vec, marker: std::marker::PhantomData, @@ -81,13 +82,6 @@ impl EndpointSet { } } -impl MergeRight for EndpointSet { - fn merge_right(mut self, other: Self) -> Self { - self.extend(other); - self - } -} - impl EndpointSet { pub fn matches(&self, request: &Request) -> Option { self.endpoints.iter().find_map(|e| e.matches(request)) diff --git a/tailcall-macros/Cargo.toml b/tailcall-macros/Cargo.toml new file mode 100644 index 0000000000..8b695509b5 --- /dev/null +++ b/tailcall-macros/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "tailcall-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0.58", features = ["derive", "full"] } +quote = "1.0" +proc-macro2 = "1.0" diff --git a/tailcall-macros/src/lib.rs b/tailcall-macros/src/lib.rs new file mode 100644 index 0000000000..b0fd5ed4d1 --- /dev/null +++ b/tailcall-macros/src/lib.rs @@ -0,0 +1,12 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; + +mod merge_right; + +use crate::merge_right::expand_merge_right_derive; + +#[proc_macro_derive(MergeRight, attributes(merge_right))] +pub fn merge_right_derive(input: TokenStream) -> TokenStream { + expand_merge_right_derive(input) +} diff --git a/tailcall-macros/src/merge_right.rs b/tailcall-macros/src/merge_right.rs new file mode 100644 index 0000000000..dce5c428f5 --- /dev/null +++ b/tailcall-macros/src/merge_right.rs @@ -0,0 +1,151 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::spanned::Spanned; +use syn::{parse_macro_input, Data, DeriveInput, Fields}; + +const MERGE_RIGHT_FN: &str = "merge_right_fn"; +const MERGE_RIGHT: &str = "merge_right"; + +#[derive(Default)] +struct Attrs { + merge_right_fn: Option, +} + +fn get_attrs(attrs: &[syn::Attribute]) -> syn::Result { + let mut attrs_ret = Attrs::default(); + for attr in attrs { + if attr.path().is_ident(MERGE_RIGHT) { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident(MERGE_RIGHT_FN) { + let p: syn::Expr = meta.value()?.parse()?; + let lit = + if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = p { + let suffix = lit.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit.span(), + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit + } else { + return Err(syn::Error::new( + p.span(), + format!( + "expected merge_right {} attribute to be a string.", + MERGE_RIGHT_FN + ), + )); + }; + let expr_path: syn::ExprPath = lit.parse()?; + attrs_ret.merge_right_fn = Some(expr_path); + Ok(()) + } else { + Err(syn::Error::new(attr.span(), "Unknown helper attribute.")) + } + })?; + } + } + Ok(attrs_ret) +} + +pub fn expand_merge_right_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let name = input.ident.clone(); + let generics = input.generics.clone(); + let gen = match input.data { + // Implement for structs + Data::Struct(data) => { + let fields = if let Fields::Named(fields) = data.fields { + fields.named + } else { + // Adjust this match arm to handle other kinds of struct fields (unnamed/tuple + // structs, unit structs) + unimplemented!() + }; + + let merge_logic = fields.iter().map(|f| { + let attrs = get_attrs(&f.attrs); + if let Err(err) = attrs { + panic!("{}", err); + } + let attrs = attrs.unwrap(); + let name = &f.ident; + if let Some(merge_right_fn) = attrs.merge_right_fn { + quote! { + #name: #merge_right_fn(self.#name, other.#name), + } + } else { + quote! { + #name: self.#name.merge_right(other.#name), + } + } + }); + + let generics_lt = generics.lt_token; + let generics_gt = generics.gt_token; + let generics_params = generics.params; + + let generics_del = quote! { + #generics_lt #generics_params #generics_gt + }; + + quote! { + impl #generics_del MergeRight for #name #generics_del { + fn merge_right(self, other: Self) -> Self { + Self { + #(#merge_logic)* + } + } + } + } + } + // Implement for enums + Data::Enum(_) => quote! { + impl MergeRight for #name { + fn merge_right(self, other: Self) -> Self { + other + } + } + }, + // Optionally handle or disallow unions + Data::Union(_) => { + return syn::Error::new_spanned(input, "Union types are not supported by MergeRight") + .to_compile_error() + .into() + } + }; + + gen.into() +} + +#[cfg(test)] +mod tests { + use syn::{parse_quote, Attribute}; + + use super::*; + + #[test] + fn test_get_attrs_invalid_type() { + let attrs: Vec = vec![parse_quote!(#[merge_right(merge_right_fn = 123)])]; + let result = get_attrs(&attrs); + assert!( + result.is_err(), + "Expected error with non-string type for `merge_right_fn`" + ); + } + + #[test] + fn test_get_attrs_unexpected_suffix() { + let attrs: Vec = + vec![parse_quote!(#[merge_right(merge_right_fn = "some_fn()")])]; + let result = get_attrs(&attrs); + assert!( + result.is_err(), + "Expected error with unexpected suffix on string literal" + ); + } +}