Skip to content

Commit

Permalink
[transcoding] Work with a tree of routes instead of a list (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 13, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 73a632c commit 1fbc2f0
Showing 7 changed files with 159 additions and 74 deletions.
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ class ConnectHandler[F[_] : Async](
}
}

req.as[GeneratedMessage](method.requestMessageCompanion)
req.as[GeneratedMessage](using method.requestMessageCompanion)
.flatMap { message =>
if (logger.isTraceEnabled) {
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}")
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@ import org.ivovk.connect_rpc_scala.grpc.*
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.QueryParams.*
import org.ivovk.connect_rpc_scala.http.codec.*
import scalapb.GeneratedMessage
import org.ivovk.connect_rpc_scala.syntax.all.*
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}

import java.util.concurrent.Executor
import scala.concurrent.ExecutionContext
@@ -170,17 +171,21 @@ final class ConnectRouteBuilder[F[_] : Async] private(

val transcodingRoutes = HttpRoutes[F] { req =>
OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req))
.semiflatMap { case MatchedRequest(method, json) =>
.semiflatMap { case MatchedRequest(method, pathJson, queryJson) =>
given MessageCodec[F] = jsonCodec
given EncodeOptions = EncodeOptions(None)

RequestEntity[F](req.body, req.headers)
.as[GeneratedMessage](method.requestMessageCompanion)
.flatMap { entity =>
val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion)
val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray)
given Companion[Message] = method.requestMessageCompanion

transcodingHandler.handleUnary(finalEntity, req.headers, method)
RequestEntity[F](req.body, req.headers).as[Message]
.flatMap { bodyMessage =>
val pathMessage = jsonCodec.parser.fromJson[Message](pathJson)
val queryMessage = jsonCodec.parser.fromJson[Message](queryJson)

transcodingHandler.handleUnary(
bodyMessage.concat(pathMessage, queryMessage),
req.headers,
method
)
}
}
}
Original file line number Diff line number Diff line change
@@ -5,43 +5,123 @@ import com.google.api.HttpRule
import org.http4s.{Method, Request, Uri}
import org.ivovk.connect_rpc_scala
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry
import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*
import org.json4s.JsonAST.{JField, JObject}
import org.json4s.{JString, JValue}

import scala.util.boundary
import scala.util.boundary.break
import scala.jdk.CollectionConverters.*

case class MatchedRequest(method: MethodRegistry.Entry, json: JValue)
case class MatchedRequest(
method: MethodRegistry.Entry,
pathJson: JValue,
queryJson: JValue,
)

object TranscodingUrlMatcher {
case class Entry(
method: MethodRegistry.Entry,
httpMethodMatcher: Method => Boolean,
httpMethod: Option[Method],
pattern: Uri.Path,
)

sealed trait RouteTree

case class RootNode(
children: Vector[RouteTree],
) extends RouteTree

case class Node(
isVariable: Boolean,
segment: String,
children: Vector[RouteTree],
) extends RouteTree

case class Leaf(
httpMethod: Option[Method],
method: MethodRegistry.Entry,
) extends RouteTree

private def mkTree(entries: Seq[Entry]): Vector[RouteTree] = {
entries.groupByOrd(_.pattern.segments.headOption)
.flatMap { (maybeSegment, entries) =>
maybeSegment match {
case None =>
entries.map { entry =>
Leaf(entry.httpMethod, entry.method)
}
case Some(head) =>
val variableDef = this.isVariable(head)
val segment =
if variableDef then
head.encoded.substring(1, head.encoded.length - 1)
else head.encoded

List(
Node(
variableDef,
segment,
mkTree(entries.map(e => e.copy(pattern = e.pattern.splitAt(1)._2)).toVector),
)
)
}
}
.toVector
}

extension [A](it: Iterable[A]) {
// Preserves ordering of elements
def groupByOrd[B](f: A => B): Map[B, Vector[A]] = {
val result = collection.mutable.LinkedHashMap.empty[B, Vector[A]]

it.foreach { elem =>
val key = f(elem)
val vec = result.getOrElse(key, Vector.empty)
result.update(key, vec :+ elem)
}

result.toMap
}

// Returns the first element that is Some
def colFirst[B](f: A => Option[B]): Option[B] = {
val iter = it.iterator
while (iter.hasNext) {
val x = f(iter.next())
if x.isDefined then return x
}
None
}
}

private def isVariable(segment: Uri.Path.Segment): Boolean = {
val enc = segment.encoded
val length = enc.length

length > 2 && enc(0) == '{' && enc(length - 1) == '}'
}

def create[F[_]](
methods: Seq[MethodRegistry.Entry],
pathPrefix: Uri.Path,
): TranscodingUrlMatcher[F] = {
val entries = methods.flatMap { method =>
method.httpRule match {
case Some(httpRule) =>
val (httpMethod, pattern) = extractMethodAndPattern(httpRule)
method.httpRule.fold(List.empty[Entry]) { httpRule =>
val additionalBindings = httpRule.getAdditionalBindingsList.asScala.toList

val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m)
(httpRule :: additionalBindings).map { rule =>
val (httpMethod, pattern) = extractMethodAndPattern(rule)

Entry(
method,
httpMethodMatcher,
pathPrefix.dropEndsWithSlash.concat(pattern.toRelative)
).some
case None => none
httpMethod,
pathPrefix.concat(pattern),
)
}
}
}

new TranscodingUrlMatcher(
entries,
RootNode(mkTree(entries)),
)
}

@@ -62,56 +142,40 @@ object TranscodingUrlMatcher {
}

class TranscodingUrlMatcher[F[_]](
entries: Seq[TranscodingUrlMatcher.Entry],
tree: TranscodingUrlMatcher.RootNode,
) {

import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*
import TranscodingUrlMatcher.*

def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary {
entries.foreach { entry =>
if (entry.httpMethodMatcher(req.method)) {
matchExtract(entry.pattern, req.uri.path) match {
case Some(pathParams) =>
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))
def matchesRequest(req: Request[F]): Option[MatchedRequest] = {
def doMatch(node: RouteTree, path: List[Uri.Path.Segment], pathVars: List[JField]): Option[MatchedRequest] = {
node match {
case Node(isVariable, patternSegment, children) if path.nonEmpty =>
val pathSegment = path.head
val pathTail = path.tail

val merged = mergeFields(groupFields(pathParams), groupFields(queryParams))
if isVariable then
val newPatchVars = (patternSegment -> JString(pathSegment.encoded)) :: pathVars

break(Some(MatchedRequest(entry.method, JObject(merged))))
case None => // continue
}
children.colFirst(doMatch(_, pathTail, newPatchVars))
else if pathSegment.encoded == patternSegment then
children.colFirst(doMatch(_, pathTail, pathVars))
else none
case Leaf(httpMethod, method) if path.isEmpty && httpMethod.forall(_ == req.method) =>
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))

MatchedRequest(
method,
JObject(groupFields(pathVars)),
JObject(groupFields(queryParams))
).some
case RootNode(children) =>
children.colFirst(doMatch(_, path, pathVars))
case _ => none
}
}

none
doMatch(tree, req.uri.path.segments.toList, List.empty)
}

/**
* Matches path segments with pattern segments and extracts variables from the path.
* Returns None if the path does not match the pattern.
*/
private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary {
if path.segments.length != pattern.segments.length then boundary.break(none)

path.segments.indices
.foldLeft(List.empty[JField]) { (state, idx) =>
val pathSegment = path.segments(idx)
val patternSegment = pattern.segments(idx)

if isVariable(patternSegment) then
val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1)

(varName -> JString(pathSegment.encoded)) :: state
else if pathSegment != patternSegment then
boundary.break(none)
else state
}
.some
}

private def isVariable(segment: Uri.Path.Segment): Boolean = {
val enc = segment.encoded
val length = enc.length

length > 2 && enc(0) == '{' && enc(length - 1) == '}'
}
}
Original file line number Diff line number Diff line change
@@ -38,7 +38,6 @@ case class RequestEntity[F[_]](

def timeout: Option[Long] = headers.timeout

def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
M.rethrow(codec.decode(this)(using cmp).value)

def as[A <: Message: Companion](using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
M.rethrow(codec.decode(this)(using summon[Companion[A]]).value)
}
Original file line number Diff line number Diff line change
@@ -10,6 +10,10 @@ case class EncodeOptions(
encoding: Option[ContentCoding]
)

object EncodeOptions {
given EncodeOptions = EncodeOptions(None)
}

trait MessageCodec[F[_]] {

val mediaType: MediaType
17 changes: 16 additions & 1 deletion core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package org.ivovk.connect_rpc_scala.syntax

import com.google.protobuf.ByteString
import io.grpc.{StatusException, StatusRuntimeException}
import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders
import scalapb.GeneratedMessage
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}

object all extends ExceptionSyntax, ProtoMappingsSyntax

@@ -33,6 +34,20 @@ trait ExceptionSyntax {
trait ProtoMappingsSyntax {

extension [T <: GeneratedMessage](t: T) {
def concat(other: T, more: T*): T = {
val cmp = t.companion.asInstanceOf[GeneratedMessageCompanion[T]]
val empty = cmp.defaultInstance

val els = (t :: other :: more.toList).filter(_ != empty)

els match
case Nil => empty
case el :: Nil => el
case _ =>
val is = els.foldLeft(ByteString.empty)(_ concat _.toByteString).newCodedInput()
cmp.parseFrom(is)
}

def toProtoAny: com.google.protobuf.any.Any = {
com.google.protobuf.any.Any(
typeUrl = "type.googleapis.com/" + t.companion.scalaDescriptor.fullName,
Original file line number Diff line number Diff line change
@@ -40,39 +40,37 @@ class TranscodingUrlMatcherTest extends AnyFunSuiteLike {

assert(result.isDefined)
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
assert(result.get.json == JObject())
}

test("matches request with POST method") {
val result = matcher.matchesRequest(Request[IO](Method.POST, uri"/api/countries"))

assert(result.isDefined)
assert(result.get.method.name == MethodName("CountriesService", "CreateCountry"))
assert(result.get.json == JObject())
}

test("extracts query parameters") {
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&offset=5"))

assert(result.isDefined)
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
assert(result.get.json == JObject("limit" -> JString("10"), "offset" -> JString("5")))
assert(result.get.queryJson == JObject("limit" -> JString("10"), "offset" -> JString("5")))
}

test("matches request with path parameter and extracts it") {
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/Uganda"))

assert(result.isDefined)
assert(result.get.method.name == MethodName("CountriesService", "GetCountry"))
assert(result.get.json == JObject("country_id" -> JString("Uganda")))
assert(result.get.pathJson == JObject("country_id" -> JString("Uganda")))
}

test("extracts repeating query parameters") {
val result = matcher.matchesRequest(Request[IO](Method.GET, uri"/api/countries/list?limit=10&limit=20"))

assert(result.isDefined)
assert(result.get.method.name == MethodName("CountriesService", "ListCountries"))
assert(result.get.json == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil)))
assert(result.get.queryJson == JObject("limit" -> JArray(JString("10") :: JString("20") :: Nil)))
}

}

0 comments on commit 1fbc2f0

Please sign in to comment.