diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala index bc7d2ea..365e406 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala @@ -108,7 +108,7 @@ class ConnectHandler[F[_] : Async]( } } - req.as[GeneratedMessage](method.requestMessageCompanion) + req.as[GeneratedMessage](using method.requestMessageCompanion) .flatMap { message => if (logger.isTraceEnabled) { logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}") diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala index a856e0e..2879434 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectRouteBuilder.scala @@ -11,7 +11,8 @@ import org.ivovk.connect_rpc_scala.grpc.* import org.ivovk.connect_rpc_scala.http.* import org.ivovk.connect_rpc_scala.http.QueryParams.* import org.ivovk.connect_rpc_scala.http.codec.* -import scalapb.GeneratedMessage +import org.ivovk.connect_rpc_scala.syntax.all.* +import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion} import java.util.concurrent.Executor import scala.concurrent.ExecutionContext @@ -170,17 +171,21 @@ final class ConnectRouteBuilder[F[_] : Async] private( val transcodingRoutes = HttpRoutes[F] { req => OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req)) - .semiflatMap { case MatchedRequest(method, json) => + .semiflatMap { case MatchedRequest(method, pathJson, queryJson) => given MessageCodec[F] = jsonCodec - given EncodeOptions = EncodeOptions(None) - RequestEntity[F](req.body, req.headers) - .as[GeneratedMessage](method.requestMessageCompanion) - .flatMap { entity => - val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion) - val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray) + given Companion[Message] = method.requestMessageCompanion - transcodingHandler.handleUnary(finalEntity, req.headers, method) + RequestEntity[F](req.body, req.headers).as[Message] + .flatMap { bodyMessage => + val pathMessage = jsonCodec.parser.fromJson[Message](pathJson) + val queryMessage = jsonCodec.parser.fromJson[Message](queryJson) + + transcodingHandler.handleUnary( + bodyMessage.concat(pathMessage, queryMessage), + req.headers, + method + ) } } } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala index 46ef5fe..e93cb8a 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala @@ -5,43 +5,123 @@ import com.google.api.HttpRule import org.http4s.{Method, Request, Uri} import org.ivovk.connect_rpc_scala import org.ivovk.connect_rpc_scala.grpc.MethodRegistry +import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.* import org.json4s.JsonAST.{JField, JObject} import org.json4s.{JString, JValue} -import scala.util.boundary -import scala.util.boundary.break +import scala.jdk.CollectionConverters.* -case class MatchedRequest(method: MethodRegistry.Entry, json: JValue) +case class MatchedRequest( + method: MethodRegistry.Entry, + pathJson: JValue, + queryJson: JValue, +) object TranscodingUrlMatcher { case class Entry( method: MethodRegistry.Entry, - httpMethodMatcher: Method => Boolean, + httpMethod: Option[Method], pattern: Uri.Path, ) + sealed trait RouteTree + + case class RootNode( + children: Vector[RouteTree], + ) extends RouteTree + + case class Node( + isVariable: Boolean, + segment: String, + children: Vector[RouteTree], + ) extends RouteTree + + case class Leaf( + httpMethod: Option[Method], + method: MethodRegistry.Entry, + ) extends RouteTree + + private def mkTree(entries: Seq[Entry]): Vector[RouteTree] = { + entries.groupByOrd(_.pattern.segments.headOption) + .flatMap { (maybeSegment, entries) => + maybeSegment match { + case None => + entries.map { entry => + Leaf(entry.httpMethod, entry.method) + } + case Some(head) => + val variableDef = this.isVariable(head) + val segment = + if variableDef then + head.encoded.substring(1, head.encoded.length - 1) + else head.encoded + + List( + Node( + variableDef, + segment, + mkTree(entries.map(e => e.copy(pattern = e.pattern.splitAt(1)._2)).toVector), + ) + ) + } + } + .toVector + } + + extension [A](it: Iterable[A]) { + // Preserves ordering of elements + def groupByOrd[B](f: A => B): Map[B, Vector[A]] = { + val result = collection.mutable.LinkedHashMap.empty[B, Vector[A]] + + it.foreach { elem => + val key = f(elem) + val vec = result.getOrElse(key, Vector.empty) + result.update(key, vec :+ elem) + } + + result.toMap + } + + // Returns the first element that is Some + def colFirst[B](f: A => Option[B]): Option[B] = { + val iter = it.iterator + while (iter.hasNext) { + val x = f(iter.next()) + if x.isDefined then return x + } + None + } + } + + private def isVariable(segment: Uri.Path.Segment): Boolean = { + val enc = segment.encoded + val length = enc.length + + length > 2 && enc(0) == '{' && enc(length - 1) == '}' + } + def create[F[_]]( methods: Seq[MethodRegistry.Entry], pathPrefix: Uri.Path, ): TranscodingUrlMatcher[F] = { val entries = methods.flatMap { method => - method.httpRule match { - case Some(httpRule) => - val (httpMethod, pattern) = extractMethodAndPattern(httpRule) + method.httpRule.fold(List.empty[Entry]) { httpRule => + val additionalBindings = httpRule.getAdditionalBindingsList.asScala.toList - val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m) + (httpRule :: additionalBindings).map { rule => + val (httpMethod, pattern) = extractMethodAndPattern(rule) Entry( method, - httpMethodMatcher, - pathPrefix.dropEndsWithSlash.concat(pattern.toRelative) - ).some - case None => none + httpMethod, + pathPrefix.concat(pattern), + ) + } } } new TranscodingUrlMatcher( - entries, + RootNode(mkTree(entries)), ) } @@ -62,56 +142,40 @@ object TranscodingUrlMatcher { } class TranscodingUrlMatcher[F[_]]( - entries: Seq[TranscodingUrlMatcher.Entry], + tree: TranscodingUrlMatcher.RootNode, ) { - import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.* + import TranscodingUrlMatcher.* - def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary { - entries.foreach { entry => - if (entry.httpMethodMatcher(req.method)) { - matchExtract(entry.pattern, req.uri.path) match { - case Some(pathParams) => - val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse(""))) + def matchesRequest(req: Request[F]): Option[MatchedRequest] = { + def doMatch(node: RouteTree, path: List[Uri.Path.Segment], pathVars: List[JField]): Option[MatchedRequest] = { + node match { + case Node(isVariable, patternSegment, children) if path.nonEmpty => + val pathSegment = path.head + val pathTail = path.tail - val merged = mergeFields(groupFields(pathParams), groupFields(queryParams)) + if isVariable then + val newPatchVars = (patternSegment -> JString(pathSegment.encoded)) :: pathVars - break(Some(MatchedRequest(entry.method, JObject(merged)))) - case None => // continue - } + children.colFirst(doMatch(_, pathTail, newPatchVars)) + else if pathSegment.encoded == patternSegment then + children.colFirst(doMatch(_, pathTail, pathVars)) + else none + case Leaf(httpMethod, method) if path.isEmpty && httpMethod.forall(_ == req.method) => + val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse(""))) + + MatchedRequest( + method, + JObject(groupFields(pathVars)), + JObject(groupFields(queryParams)) + ).some + case RootNode(children) => + children.colFirst(doMatch(_, path, pathVars)) + case _ => none } } - none + doMatch(tree, req.uri.path.segments.toList, List.empty) } - /** - * Matches path segments with pattern segments and extracts variables from the path. - * Returns None if the path does not match the pattern. - */ - private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary { - if path.segments.length != pattern.segments.length then boundary.break(none) - - path.segments.indices - .foldLeft(List.empty[JField]) { (state, idx) => - val pathSegment = path.segments(idx) - val patternSegment = pattern.segments(idx) - - if isVariable(patternSegment) then - val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1) - - (varName -> JString(pathSegment.encoded)) :: state - else if pathSegment != patternSegment then - boundary.break(none) - else state - } - .some - } - - private def isVariable(segment: Uri.Path.Segment): Boolean = { - val enc = segment.encoded - val length = enc.length - - length > 2 && enc(0) == '{' && enc(length - 1) == '}' - } } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala index 1df5e2a..105ce0d 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala @@ -38,7 +38,6 @@ case class RequestEntity[F[_]]( def timeout: Option[Long] = headers.timeout - def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] = - M.rethrow(codec.decode(this)(using cmp).value) - + def as[A <: Message: Companion](using M: MonadThrow[F], codec: MessageCodec[F]): F[A] = + M.rethrow(codec.decode(this)(using summon[Companion[A]]).value) } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/MessageCodec.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/MessageCodec.scala index 28b8446..ffad397 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/MessageCodec.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/codec/MessageCodec.scala @@ -10,6 +10,10 @@ case class EncodeOptions( encoding: Option[ContentCoding] ) +object EncodeOptions { + given EncodeOptions = EncodeOptions(None) +} + trait MessageCodec[F[_]] { val mediaType: MediaType diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala index d2d40e9..c4b126b 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala @@ -1,8 +1,9 @@ package org.ivovk.connect_rpc_scala.syntax +import com.google.protobuf.ByteString import io.grpc.{StatusException, StatusRuntimeException} import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders -import scalapb.GeneratedMessage +import scalapb.{GeneratedMessage, GeneratedMessageCompanion} object all extends ExceptionSyntax, ProtoMappingsSyntax @@ -33,6 +34,20 @@ trait ExceptionSyntax { trait ProtoMappingsSyntax { extension [T <: GeneratedMessage](t: T) { + def concat(other: T, more: T*): T = { + val cmp = t.companion.asInstanceOf[GeneratedMessageCompanion[T]] + val empty = cmp.defaultInstance + + val els = (t :: other :: more.toList).filter(_ != empty) + + els match + case Nil => empty + case el :: Nil => el + case _ => + val is = els.foldLeft(ByteString.empty)(_ concat _.toByteString).newCodedInput() + cmp.parseFrom(is) + } + def toProtoAny: com.google.protobuf.any.Any = { com.google.protobuf.any.Any( typeUrl = "type.googleapis.com/" + t.companion.scalaDescriptor.fullName, diff --git a/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala b/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala index e0b9bab..f76a848 100644 --- a/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala +++ b/core/src/test/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcherTest.scala @@ -40,7 +40,6 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike { assert(result.isDefined) assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) - assert(result.get.json == JObject()) } test("matches request with POST method") { @@ -48,7 +47,6 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike { assert(result.isDefined) assert(result.get.method.name == MethodName("CountriesService", "CreateCountry")) - assert(result.get.json == JObject()) } test("extracts query parameters") { @@ -56,7 +54,7 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike { assert(result.isDefined) assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) - assert(result.get.json == JObject("limit" -> JString("10"), "offset" -> JString("5"))) + assert(result.get.queryJson == JObject("limit" -> JString("10"), "offset" -> JString("5"))) } test("matches request with path parameter and extracts it") { @@ -64,7 +62,7 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike { assert(result.isDefined) assert(result.get.method.name == MethodName("CountriesService", "GetCountry")) - assert(result.get.json == JObject("country_id" -> JString("Uganda"))) + assert(result.get.pathJson == JObject("country_id" -> JString("Uganda"))) } test("extracts repeating query parameters") { @@ -72,7 +70,7 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike { assert(result.isDefined) assert(result.get.method.name == MethodName("CountriesService", "ListCountries")) - assert(result.get.json == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil))) + assert(result.get.queryJson == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil))) } }