From 73a632c62ea5a84a16e033909eeae9dd9ecfb7e9 Mon Sep 17 00:00:00 2001 From: Ihor Vovk Date: Thu, 12 Dec 2024 22:26:32 +0100 Subject: [PATCH] Initial work on the GRPC Transcoding (#54) --- .../ConnectRouteBuilder.scala | 41 +++++- .../TranscodingHandler.scala | 126 ++++++++++++++++++ .../TranscodingUrlMatcher.scala | 117 ++++++++++++++++ .../grpc/MethodRegistry.scala | 15 ++- .../http/RequestEntity.scala | 11 +- .../http/codec/Compressor.scala | 8 +- .../http/codec/JsonMessageCodec.scala | 9 +- .../http/json/JsonProcessing.scala | 59 ++++++++ ...Test.proto => HttpCommunicationTest.proto} | 4 + ...Test.scala => HttpCommunicationTest.scala} | 27 +++- .../TranscodingUrlMatcherTest.scala | 78 +++++++++++ .../http/grpc/MethodRegistryTest.scala | 2 +- 12 files changed, 474 insertions(+), 23 deletions(-) create mode 100644 core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala create mode 100644 core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala create mode 100644 core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/JsonProcessing.scala rename core/src/test/protobuf/test/{ConnectCommunicationTest.proto => HttpCommunicationTest.proto} (82%) rename core/src/test/scala/org/ivovk/connect_rpc_scala/{ConnectCommunicationTest.scala => HttpCommunicationTest.scala} (84%) create mode 100644 core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala index 1bdd0d9..a856e0e 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala @@ -11,6 +11,7 @@ import org.ivovk.connect_rpc_scala.grpc.* import org.ivovk.connect_rpc_scala.http.* import org.ivovk.connect_rpc_scala.http.QueryParams.* import org.ivovk.connect_rpc_scala.http.codec.* +import scalapb.GeneratedMessage import java.util.concurrent.Executor import scala.concurrent.ExecutionContext @@ -108,8 +109,9 @@ final class ConnectRouteBuilder[F[_] : Async] private( val httpDsl = Http4sDsl[F] import httpDsl.* + val jsonCodec = customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build) val codecRegistry = MessageCodecRegistry[F]( - customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build), + jsonCodec, ProtoMessageCodec[F](), ) @@ -124,13 +126,13 @@ final class ConnectRouteBuilder[F[_] : Async] private( waitForShutdown, ) yield - val handler = new ConnectHandler( + val connectHandler = new ConnectHandler( channel, httpDsl, treatTrailersAsHeaders, ) - HttpRoutes[F] { + val connectRoutes = HttpRoutes[F] { case req@Method.GET -> `pathPrefix` / service / method :? EncodingQP(mediaType) +& MessageQP(message) => OptionT.fromOption[F](methodRegistry.get(service, method)) // Temporary support GET-requests for all methods, @@ -140,7 +142,7 @@ final class ConnectRouteBuilder[F[_] : Async] private( withCodec(httpDsl, codecRegistry, mediaType.some) { codec => val entity = RequestEntity[F](message, req.headers) - handler.handle(entity, methodEntry)(using codec) + connectHandler.handle(entity, methodEntry)(using codec) } } case req@Method.POST -> `pathPrefix` / service / method => @@ -149,12 +151,41 @@ final class ConnectRouteBuilder[F[_] : Async] private( withCodec(httpDsl, codecRegistry, req.contentType.map(_.mediaType)) { codec => val entity = RequestEntity[F](req.body, req.headers) - handler.handle(entity, methodEntry)(using codec) + connectHandler.handle(entity, methodEntry)(using codec) } } case _ => OptionT.none } + + val transcodingUrlMatcher = TranscodingUrlMatcher.create[F]( + methodRegistry.all, + pathPrefix, + ) + val transcodingHandler = new TranscodingHandler( + channel, + httpDsl, + treatTrailersAsHeaders, + ) + + val transcodingRoutes = HttpRoutes[F] { req => + OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req)) + .semiflatMap { case MatchedRequest(method, json) => + given MessageCodec[F] = jsonCodec + given EncodeOptions = EncodeOptions(None) + + RequestEntity[F](req.body, req.headers) + .as[GeneratedMessage](method.requestMessageCompanion) + .flatMap { entity => + val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion) + val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray) + + transcodingHandler.handleUnary(finalEntity, req.headers, method) + } + } + } + + connectRoutes <+> transcodingRoutes } def build: Resource[F, HttpApp[F]] = diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala new file mode 100644 index 0000000..38dbbad --- /dev/null +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala @@ -0,0 +1,126 @@ +package org.ivovk.connect_rpc_scala + +import cats.effect.Async +import cats.implicits.* +import io.grpc.* +import org.http4s.dsl.Http4sDsl +import org.http4s.{Header, Headers, MessageFailure, Response} +import org.ivovk.connect_rpc_scala.Mappings.* +import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry} +import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name` +import org.ivovk.connect_rpc_scala.http.RequestEntity +import org.ivovk.connect_rpc_scala.http.RequestEntity.* +import org.ivovk.connect_rpc_scala.http.codec.{EncodeOptions, MessageCodec} +import org.slf4j.{Logger, LoggerFactory} +import scalapb.GeneratedMessage + +import scala.concurrent.duration.* +import scala.jdk.CollectionConverters.* +import scala.util.chaining.* + +object TranscodingHandler { + + extension [F[_]](response: Response[F]) { + def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] = + codec.encode(entity, options).applyTo(response) + } + +} + +class TranscodingHandler[F[_] : Async]( + channel: Channel, + httpDsl: Http4sDsl[F], + treatTrailersAsHeaders: Boolean, +) { + + import TranscodingHandler.* + import httpDsl.* + + private val logger: Logger = LoggerFactory.getLogger(getClass) + + def handleUnary( + message: GeneratedMessage, + headers: Headers, + method: MethodRegistry.Entry, + )(using MessageCodec[F], EncodeOptions): F[Response[F]] = { + if (logger.isTraceEnabled) { + // Used in conformance tests + headers.get[`X-Test-Case-Name`] match { + case Some(header) => + logger.trace(s">>> Test Case: ${header.value}") + case None => // ignore + } + } + + if (logger.isTraceEnabled) { + logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}") + } + + val callOptions = CallOptions.DEFAULT + .pipe( + headers.timeout match { + case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS) + case None => identity + } + ) + + ClientCalls + .asyncUnaryCall( + channel, + method.descriptor, + callOptions, + headers.toMetadata, + message + ) + .map { response => + val headers = response.headers.toHeaders() ++ + response.trailers.toHeaders(trailing = !treatTrailersAsHeaders) + + if (logger.isTraceEnabled) { + logger.trace(s"<<< Headers: ${headers.redactSensitive()}") + } + + Response(Ok, headers = headers).withMessage(response.value) + } + .handleError { e => + val grpcStatus = e match { + case e: StatusException => + e.getStatus.getDescription match { + case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED + case _ => e.getStatus + } + case e: StatusRuntimeException => e.getStatus + case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT + case _ => io.grpc.Status.INTERNAL + } + + val (message, metadata) = e match { + case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers) + case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers) + case e => (Option(e.getMessage), new Metadata()) + } + + val httpStatus = grpcStatus.toHttpStatus + val connectCode = grpcStatus.toConnectCode + + // Should be called before converting metadata to headers + val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey)) + .fold(Seq.empty)(_.asScala.toSeq) + + val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders) + + if (logger.isTraceEnabled) { + logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode") + logger.trace(s"<<< Headers: ${headers.redactSensitive()}") + logger.trace(s"<<< Error processing request", e) + } + + Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error( + code = connectCode, + message = message, + details = details + )) + } + } + +} diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala new file mode 100644 index 0000000..46ef5fe --- /dev/null +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala @@ -0,0 +1,117 @@ +package org.ivovk.connect_rpc_scala + +import cats.implicits.* +import com.google.api.HttpRule +import org.http4s.{Method, Request, Uri} +import org.ivovk.connect_rpc_scala +import org.ivovk.connect_rpc_scala.grpc.MethodRegistry +import org.json4s.JsonAST.{JField, JObject} +import org.json4s.{JString, JValue} + +import scala.util.boundary +import scala.util.boundary.break + +case class MatchedRequest(method: MethodRegistry.Entry, json: JValue) + +object TranscodingUrlMatcher { + case class Entry( + method: MethodRegistry.Entry, + httpMethodMatcher: Method => Boolean, + pattern: Uri.Path, + ) + + def create[F[_]]( + methods: Seq[MethodRegistry.Entry], + pathPrefix: Uri.Path, + ): TranscodingUrlMatcher[F] = { + val entries = methods.flatMap { method => + method.httpRule match { + case Some(httpRule) => + val (httpMethod, pattern) = extractMethodAndPattern(httpRule) + + val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m) + + Entry( + method, + httpMethodMatcher, + pathPrefix.dropEndsWithSlash.concat(pattern.toRelative) + ).some + case None => none + } + } + + new TranscodingUrlMatcher( + entries, + ) + } + + private def extractMethodAndPattern(rule: HttpRule): (Option[Method], Uri.Path) = { + val (method, str) = rule.getPatternCase match + case HttpRule.PatternCase.GET => (Method.GET.some, rule.getGet) + case HttpRule.PatternCase.PUT => (Method.PUT.some, rule.getPut) + case HttpRule.PatternCase.POST => (Method.POST.some, rule.getPost) + case HttpRule.PatternCase.DELETE => (Method.DELETE.some, rule.getDelete) + case HttpRule.PatternCase.PATCH => (Method.PATCH.some, rule.getPatch) + case HttpRule.PatternCase.CUSTOM => (none, rule.getCustom.getPath) + case other => throw new RuntimeException(s"Unsupported pattern case $other (Rule: $rule)") + + val path = Uri.Path.unsafeFromString(str).dropEndsWithSlash + + (method, path) + } +} + +class TranscodingUrlMatcher[F[_]]( + entries: Seq[TranscodingUrlMatcher.Entry], +) { + + import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.* + + def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary { + entries.foreach { entry => + if (entry.httpMethodMatcher(req.method)) { + matchExtract(entry.pattern, req.uri.path) match { + case Some(pathParams) => + val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse(""))) + + val merged = mergeFields(groupFields(pathParams), groupFields(queryParams)) + + break(Some(MatchedRequest(entry.method, JObject(merged)))) + case None => // continue + } + } + } + + none + } + + /** + * Matches path segments with pattern segments and extracts variables from the path. + * Returns None if the path does not match the pattern. + */ + private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary { + if path.segments.length != pattern.segments.length then boundary.break(none) + + path.segments.indices + .foldLeft(List.empty[JField]) { (state, idx) => + val pathSegment = path.segments(idx) + val patternSegment = pattern.segments(idx) + + if isVariable(patternSegment) then + val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1) + + (varName -> JString(pathSegment.encoded)) :: state + else if pathSegment != patternSegment then + boundary.break(none) + else state + } + .some + } + + private def isVariable(segment: Uri.Path.Segment): Boolean = { + val enc = segment.encoded + val length = enc.length + + length > 2 && enc(0) == '{' && enc(length - 1) == '}' + } +} diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala index f629400..ee80124 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/MethodRegistry.scala @@ -1,7 +1,6 @@ package org.ivovk.connect_rpc_scala.grpc -import com.google.api.AnnotationsProto -import com.google.api.http.HttpRule +import com.google.api.{AnnotationsProto, HttpRule} import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition} import scalapb.grpc.ConcreteProtoMethodDescriptorSupplier import scalapb.{GeneratedMessage, GeneratedMessageCompanion} @@ -44,7 +43,6 @@ object MethodRegistry { descriptor = methodDescriptor, ) } - .groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _) new MethodRegistry(entries) } @@ -63,9 +61,16 @@ object MethodRegistry { } -class MethodRegistry private(entries: Map[Service, Map[Method, MethodRegistry.Entry]]) { +class MethodRegistry private(entries: Seq[MethodRegistry.Entry]) { + + private val serviceMethodEntries: Map[Service, Map[Method, MethodRegistry.Entry]] = entries + .groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _) + + def all: Seq[MethodRegistry.Entry] = entries + + def get(name: MethodName): Option[MethodRegistry.Entry] = get(name.service, name.method) def get(service: Service, method: Method): Option[MethodRegistry.Entry] = - entries.getOrElse(service, Map.empty).get(method) + serviceMethodEntries.getOrElse(service, Map.empty).get(method) } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala index 3558b1f..1df5e2a 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala @@ -9,6 +9,13 @@ import org.ivovk.connect_rpc_scala.http.codec.MessageCodec import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion} +object RequestEntity { + extension (h: Headers) { + def timeout: Option[Long] = + h.get[`Connect-Timeout-Ms`].map(_.value) + } +} + /** * Encoded message and headers with the knowledge how this message can be decoded. * Similar to [[org.http4s.Media]], but extends the message with `String` type representing message that is @@ -18,6 +25,7 @@ case class RequestEntity[F[_]]( message: String | Stream[F, Byte], headers: Headers, ) { + import RequestEntity.* private def contentType: Option[`Content-Type`] = headers.get[`Content-Type`] @@ -28,8 +36,7 @@ case class RequestEntity[F[_]]( def encoding: Option[ContentCoding] = headers.get[`Content-Encoding`].map(_.contentCoding) - def timeout: Option[Long] = - headers.get[`Connect-Timeout-Ms`].map(_.value) + def timeout: Option[Long] = headers.timeout def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] = M.rethrow(codec.decode(this)(using cmp).value) diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/Compressor.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/Compressor.scala index 3583b68..65cb31a 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/Compressor.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/Compressor.scala @@ -17,14 +17,14 @@ class Compressor[F[_] : Sync] { given Compression[F] = Compression.forSync[F] def decompressed(encoding: Option[ContentCoding], body: Stream[F, Byte]): Stream[F, Byte] = - body.through(encoding match { + encoding match { case Some(ContentCoding.gzip) => - Compression[F].gunzip().andThen(_.flatMap(_.content)) + body.through(Compression[F].gunzip().andThen(_.flatMap(_.content))) case Some(other) => throw new StatusException(Status.INVALID_ARGUMENT.withDescription(s"Unsupported encoding: $other")) case None => - identity - }) + body + } def compressed(encoding: Option[ContentCoding], entity: ResponseEntity[F]): ResponseEntity[F] = encoding match { diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/JsonMessageCodec.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/JsonMessageCodec.scala index 04aaf61..e5b3fb8 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/JsonMessageCodec.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/JsonMessageCodec.scala @@ -14,8 +14,8 @@ import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Compan import java.net.URLDecoder class JsonMessageCodec[F[_] : Sync]( - parser: Parser, - printer: Printer, + val parser: Parser, + val printer: Printer, ) extends MessageCodec[F] { private val logger = LoggerFactory.getLogger(getClass) @@ -41,7 +41,10 @@ class JsonMessageCodec[F[_] : Sync]( logger.trace(s">>> JSON: $str") } - Sync[F].delay(parser.fromJsonString(str)) + if str.nonEmpty then + Sync[F].delay(parser.fromJsonString(str)) + else + cmp.defaultInstance.pure[F] } .attemptT .leftMap(e => InvalidMessageBodyFailure(e.getMessage, e.some)) diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/JsonProcessing.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/JsonProcessing.scala new file mode 100644 index 0000000..e3de53a --- /dev/null +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/JsonProcessing.scala @@ -0,0 +1,59 @@ +package org.ivovk.connect_rpc_scala.http.json + +import org.json4s.JsonAST.{JField, JObject} +import org.json4s.{JArray, JNothing, JString, JValue} + +object JsonProcessing { + + def mergeFields(a: List[JField], b: List[JField]): List[JField] = { + if a.isEmpty then b + else if b.isEmpty then a + else a.foldLeft(b) { case (acc, (k, v)) => + acc.find(_._1 == k) match { + case Some((_, v2)) => acc.updated(acc.indexOf((k, v2)), (k, merge(v, v2))) + case None => acc :+ (k, v) + } + } + } + + private def merge(a: JValue, b: JValue): JValue = { + (a, b) match + case (JObject(xs), JObject(ys)) => JObject(mergeFields(xs, ys)) + case (JArray(xs), JArray(ys)) => JArray(xs ++ ys) + case (JArray(xs), y) => JArray(xs :+ y) + case (JNothing, x) => x + case (x, JNothing) => x + case (JString(x), JString(y)) => JArray(List(JString(x), JString(y))) + case (_, y) => y + } + + def groupFields(fields: List[JField]): List[JField] = { + groupFields2(fields.map { (k, v) => + if k.contains('.') then k.split('.').toList -> v else List(k) -> v + }) + } + + private def groupFields2(fields: List[(List[String], JValue)]): List[JField] = { + fields + .groupMapReduce((keyParts, _) => keyParts.head) { + case (_ :: Nil, v) => List(v) + case (_ :: tail, v) => List(tail -> v) + case (Nil, _) => ??? + }(_ ++ _) + .view.mapValues { fields => + if (fields.forall { + case (list: List[String], v: JValue) => true + case _ => false + }) { + JObject(groupFields2(fields.asInstanceOf[List[(List[String], JValue)]])) + } else { + val jvalues = fields.asInstanceOf[List[JValue]] + + if jvalues.length == 1 then jvalues.head + else JArray(jvalues) + } + } + .toList + } + +} diff --git a/core/src/test/protobuf/test/ConnectCommunicationTest.proto b/core/src/test/protobuf/test/HttpCommunicationTest.proto similarity index 82% rename from core/src/test/protobuf/test/ConnectCommunicationTest.proto rename to core/src/test/protobuf/test/HttpCommunicationTest.proto index 4b896b9..8011558 100644 --- a/core/src/test/protobuf/test/ConnectCommunicationTest.proto +++ b/core/src/test/protobuf/test/HttpCommunicationTest.proto @@ -2,12 +2,16 @@ syntax = "proto3"; package test; +import "google/api/annotations.proto"; + service TestService { rpc Add(AddRequest) returns (AddResponse) {} // This method can be called using GET request rpc Get(GetRequest) returns (GetResponse) { option idempotency_level = NO_SIDE_EFFECTS; + + option (google.api.http) = {get: "/get/{key}"}; } } diff --git a/core/src/test/scala/org/ivovk/connect_rpc_scala/ConnectCommunicationTest.scala b/core/src/test/scala/org/ivovk/connect_rpc_scala/HttpCommunicationTest.scala similarity index 84% rename from core/src/test/scala/org/ivovk/connect_rpc_scala/ConnectCommunicationTest.scala rename to core/src/test/scala/org/ivovk/connect_rpc_scala/HttpCommunicationTest.scala index df0b5d5..0dac151 100644 --- a/core/src/test/scala/org/ivovk/connect_rpc_scala/ConnectCommunicationTest.scala +++ b/core/src/test/scala/org/ivovk/connect_rpc_scala/HttpCommunicationTest.scala @@ -11,14 +11,14 @@ import org.http4s.{Method, *} import org.ivovk.connect_rpc_scala.http.MediaTypes import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -import test.ConnectCommunicationTest.* -import test.ConnectCommunicationTest.TestServiceGrpc.TestService +import test.HttpCommunicationTest.TestServiceGrpc.TestService +import test.HttpCommunicationTest.{AddRequest, AddResponse, GetRequest, GetResponse} import java.net.URLEncoder import scala.concurrent.{ExecutionContext, Future} import scala.jdk.CollectionConverters.* -class ConnectCommunicationTest extends AnyFunSuite, Matchers { +class HttpCommunicationTest extends AnyFunSuite, Matchers { object TestServiceImpl extends TestService { override def add(request: AddRequest): Future[AddResponse] = @@ -83,6 +83,27 @@ class ConnectCommunicationTest extends AnyFunSuite, Matchers { .unsafeRunSync() } + test("Http-annotated GET request") { + val service = TestService.bindService(TestServiceImpl, ExecutionContext.global) + + ConnectRouteBuilder.forService[IO](service).build + .flatMap { app => + Client.fromHttpApp(app).run( + Request[IO](Method.GET, uri"/get/123") + ) + } + .use { response => + for + body <- response.as[String] + yield { + assert(body == """{"value":"Key is: 123"}""") + assert(response.status == Status.Ok) + assert(response.headers.get[`Content-Type`].map(_.mediaType).contains(MediaTypes.`application/json`)) + } + } + .unsafeRunSync() + } + test("support path prefixes") { val service = TestService.bindService(TestServiceImpl, ExecutionContext.global) diff --git a/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala b/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala new file mode 100644 index 0000000..e0b9bab --- /dev/null +++ b/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala @@ -0,0 +1,78 @@ +package org.ivovk.connect_rpc_scala + +import cats.effect.IO +import com.google.api.HttpRule +import org.http4s.Uri.Path.Root +import org.http4s.implicits.uri +import org.http4s.{Method, Request} +import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry} +import org.json4s.{JArray, JObject, JString} +import org.scalatest.funsuite.AnyFunSuiteLike + +class TranscodingUrlMatcherTest extends AnyFunSuiteLike { + + val matcher = TranscodingUrlMatcher.create[IO]( + Seq( + MethodRegistry.Entry( + MethodName("CountriesService", "CreateCountry"), + null, + Some(HttpRule.newBuilder().setPost("/countries").build()), + null + ), + MethodRegistry.Entry( + MethodName("CountriesService", "ListCountries"), + null, + Some(HttpRule.newBuilder().setGet("/countries/list").build()), + null + ), + MethodRegistry.Entry( + MethodName("CountriesService", "GetCountry"), + null, + Some(HttpRule.newBuilder().setGet("/countries/{country_id}").build()), + null + ), + ), + Root / "api" + ) + + test("matches request with GET method") { + val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list")) + + assert(result.isDefined) + assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) + assert(result.get.json == JObject()) + } + + test("matches request with POST method") { + val result = matcher.matchesRequest(Request[IO](Method.POST, uri"/api/countries")) + + assert(result.isDefined) + assert(result.get.method.name == MethodName("CountriesService", "CreateCountry")) + assert(result.get.json == JObject()) + } + + test("extracts query parameters") { + val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&offset=5")) + + assert(result.isDefined) + assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) + assert(result.get.json == JObject("limit" -> JString("10"), "offset" -> JString("5"))) + } + + test("matches request with path parameter and extracts it") { + val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/Uganda")) + + assert(result.isDefined) + assert(result.get.method.name == MethodName("CountriesService", "GetCountry")) + assert(result.get.json == JObject("country_id" -> JString("Uganda"))) + } + + test("extracts repeating query parameters") { + val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&limit=20")) + + assert(result.isDefined) + assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) + assert(result.get.json == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil))) + } + +} diff --git a/core/src/test/scala/org/ivovk/connect_rpc_scala/http/grpc/MethodRegistryTest.scala b/core/src/test/scala/org/ivovk/connect_rpc_scala/http/grpc/MethodRegistryTest.scala index 61b8ccc..e6afb5b 100644 --- a/core/src/test/scala/org/ivovk/connect_rpc_scala/http/grpc/MethodRegistryTest.scala +++ b/core/src/test/scala/org/ivovk/connect_rpc_scala/http/grpc/MethodRegistryTest.scala @@ -36,7 +36,7 @@ class MethodRegistryTest extends AnyFunSuite { val httpRule = entry.get.httpRule.get assert(httpRule.getPost == "/v1/test/http_annotation_method") - assert(httpRule.body == "*") + assert(httpRule.getBody == "*") } }