From 75ec70c790ffe7a764b44eb4957d36345ee41724 Mon Sep 17 00:00:00 2001 From: Ihor Vovk Date: Sat, 14 Dec 2024 13:47:32 +0100 Subject: [PATCH] Be able to filter incoming headers, remove Connection* headers by default (#62) --- .../org/ivovk/connect_rpc_scala/ConnectHandler.scala | 3 ++- .../connect_rpc_scala/ConnectRouteBuilder.scala | 12 ++++++++++++ .../scala/org/ivovk/connect_rpc_scala/Mappings.scala | 6 ++++-- .../ivovk/connect_rpc_scala/TranscodingHandler.scala | 3 ++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala index 4039782..50df22d 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala @@ -21,6 +21,7 @@ class ConnectHandler[F[_] : Async]( channel: Channel, errorHandler: ErrorHandler[F], treatTrailersAsHeaders: Boolean, + incomingHeadersFilter: String => Boolean, ) { private val logger: Logger = LoggerFactory.getLogger(getClass) @@ -75,7 +76,7 @@ class ConnectHandler[F[_] : Async]( channel, method.descriptor, callOptions, - req.headers.toMetadata, + req.headers.toMetadata(incomingHeadersFilter), message ) } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala index 183f247..c26c8cf 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala @@ -21,6 +21,9 @@ import scala.concurrent.duration.* object ConnectRouteBuilder { + private val DefaultIncomingHeadersFilter: String => Boolean = name => + !name.toLowerCase.startsWith("connection") + def forService[F[_] : Async](service: ServerServiceDefinition): ConnectRouteBuilder[F] = forServices(Seq(service)) @@ -33,6 +36,7 @@ object ConnectRouteBuilder { serverConfigurator = identity, channelConfigurator = identity, customJsonCodec = None, + incomingHeadersFilter = DefaultIncomingHeadersFilter, pathPrefix = Uri.Path.Root, executor = ExecutionContext.global, waitForShutdown = 5.seconds, @@ -46,6 +50,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( serverConfigurator: Endo[ServerBuilder[_]], channelConfigurator: Endo[ManagedChannelBuilder[_]], customJsonCodec: Option[JsonMessageCodec[F]], + incomingHeadersFilter: String => Boolean, pathPrefix: Uri.Path, executor: Executor, waitForShutdown: Duration, @@ -57,6 +62,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( serverConfigurator: Endo[ServerBuilder[_]] = serverConfigurator, channelConfigurator: Endo[ManagedChannelBuilder[_]] = channelConfigurator, customJsonCodec: Option[JsonMessageCodec[F]] = customJsonCodec, + incomingHeadersFilter: String => Boolean = incomingHeadersFilter, pathPrefix: Uri.Path = pathPrefix, executor: Executor = executor, waitForShutdown: Duration = waitForShutdown, @@ -67,6 +73,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( serverConfigurator, channelConfigurator, customJsonCodec, + incomingHeadersFilter, pathPrefix, executor, waitForShutdown, @@ -82,6 +89,9 @@ final class ConnectRouteBuilder[F[_] : Async] private( def withJsonCodecConfigurator(method: Endo[JsonMessageCodecBuilder[F]]): ConnectRouteBuilder[F] = copy(customJsonCodec = Some(method(JsonMessageCodecBuilder[F]()).build)) + def withIncomingHeadersFilter(filter: String => Boolean): ConnectRouteBuilder[F] = + copy(incomingHeadersFilter = filter) + def withPathPrefix(path: Uri.Path): ConnectRouteBuilder[F] = copy(pathPrefix = path) @@ -136,6 +146,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( channel, errorHandler, treatTrailersAsHeaders, + incomingHeadersFilter, ) val connectRoutes = HttpRoutes[F] { @@ -172,6 +183,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( val transcodingHandler = new TranscodingHandler( channel, errorHandler, + incomingHeadersFilter, ) val transcodingRoutes = HttpRoutes[F] { req => diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala index e982c71..2682ed4 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala @@ -12,10 +12,12 @@ object Mappings extends HeaderMappings, StatusCodeMappings, ResponseCodeExtensio trait HeaderMappings { extension (headers: Headers) { - def toMetadata: Metadata = { + def toMetadata(filter: String => Boolean): Metadata = { val metadata = new Metadata() headers.foreach { header => - metadata.put(asciiKey(header.name.toString), header.value) + if (filter(header.name.toString)) { + metadata.put(asciiKey(header.name.toString), header.value) + } } metadata } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala index a4800f3..6eb3dff 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala @@ -20,6 +20,7 @@ import scala.util.chaining.* class TranscodingHandler[F[_] : Async]( channel: Channel, errorHandler: ErrorHandler[F], + incomingHeadersFilter: String => Boolean, ) { private val logger: Logger = LoggerFactory.getLogger(getClass) @@ -55,7 +56,7 @@ class TranscodingHandler[F[_] : Async]( channel, method.descriptor, callOptions, - headers.toMetadata, + headers.toMetadata(incomingHeadersFilter), message ) .map { response =>