Skip to content

Commit

Permalink
[transcoding] Switch to ScalaPB's HttpRule definitions to fix CNF err…
Browse files Browse the repository at this point in the history
…ors (#60)
  • Loading branch information
igor-vovk authored Dec 14, 2024
1 parent c24cd23 commit ab4c608
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.ivovk.connect_rpc_scala

import cats.implicits.*
import com.google.api.HttpRule
import com.google.api.http.{CustomHttpPattern, HttpRule}
import org.http4s.{Method, Request, Uri}
import org.ivovk.connect_rpc_scala
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry
Expand Down Expand Up @@ -60,7 +60,7 @@ object TranscodingUrlMatcher {
Node(
variableDef,
segment,
mkTree(entries.map(e => e.copy(pattern = e.pattern.splitAt(1)._2)).toVector),
mkTree(entries.map(e => e.copy(pattern = e.pattern.splitAt(1)._2))),
)
)
}
Expand All @@ -74,7 +74,7 @@ object TranscodingUrlMatcher {
val result = collection.mutable.LinkedHashMap.empty[B, Vector[A]]

it.foreach { elem =>
val key = f(elem)
val key = f(elem)
val vec = result.getOrElse(key, Vector.empty)
result.update(key, vec :+ elem)
}
Expand Down Expand Up @@ -106,7 +106,7 @@ object TranscodingUrlMatcher {
): TranscodingUrlMatcher[F] = {
val entries = methods.flatMap { method =>
method.httpRule.fold(List.empty[Entry]) { httpRule =>
val additionalBindings = httpRule.getAdditionalBindingsList.asScala.toList
val additionalBindings = httpRule.additionalBindings.toList

(httpRule :: additionalBindings).map { rule =>
val (httpMethod, pattern) = extractMethodAndPattern(rule)
Expand All @@ -126,13 +126,13 @@ object TranscodingUrlMatcher {
}

private def extractMethodAndPattern(rule: HttpRule): (Option[Method], Uri.Path) = {
val (method, str) = rule.getPatternCase match
case HttpRule.PatternCase.GET => (Method.GET.some, rule.getGet)
case HttpRule.PatternCase.PUT => (Method.PUT.some, rule.getPut)
case HttpRule.PatternCase.POST => (Method.POST.some, rule.getPost)
case HttpRule.PatternCase.DELETE => (Method.DELETE.some, rule.getDelete)
case HttpRule.PatternCase.PATCH => (Method.PATCH.some, rule.getPatch)
case HttpRule.PatternCase.CUSTOM => (none, rule.getCustom.getPath)
val (method, str) = rule.pattern match
case HttpRule.Pattern.Get(value) => (Method.GET.some, value)
case HttpRule.Pattern.Put(value) => (Method.PUT.some, value)
case HttpRule.Pattern.Post(value) => (Method.POST.some, value)
case HttpRule.Pattern.Delete(value) => (Method.DELETE.some, value)
case HttpRule.Pattern.Patch(value) => (Method.PATCH.some, value)
case HttpRule.Pattern.Custom(CustomHttpPattern(kind, value, _)) if kind == "*" => (none, value)
case other => throw new RuntimeException(s"Unsupported pattern case $other (Rule: $rule)")

val path = Uri.Path.unsafeFromString(str).dropEndsWithSlash
Expand Down Expand Up @@ -169,13 +169,13 @@ class TranscodingUrlMatcher[F[_]](
JObject(groupFields(pathVars)),
JObject(groupFields(queryParams))
).some
case RootNode(children) =>
children.colFirst(doMatch(_, path, pathVars))
case _ => none
}
}

doMatch(tree, req.uri.path.segments.toList, List.empty)
val path = req.uri.path.segments.toList

tree.children.colFirst(doMatch(_, path, Nil))
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.ivovk.connect_rpc_scala.grpc

import com.google.api.{AnnotationsProto, HttpRule}
import com.google.api.http.HttpRule
import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition}
import scalapb.grpc.ConcreteProtoMethodDescriptorSupplier
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
Expand Down Expand Up @@ -47,16 +47,19 @@ object MethodRegistry {
new MethodRegistry(entries)
}

private val HttpFieldNumber = 72295728

private def extractHttpRule(methodDescriptor: MethodDescriptor[_, _]): Option[HttpRule] = {
methodDescriptor.getSchemaDescriptor match
case sd: ConcreteProtoMethodDescriptorSupplier =>
val fields = sd.getMethodDescriptor.getOptions.getUnknownFields
val fieldNumber = AnnotationsProto.http.getNumber

if fields.hasField(fieldNumber) then
Some(HttpRule.parseFrom(fields.getField(fieldNumber).getLengthDelimitedList.get(0).toByteArray))
else None
case _ => None
if fields.hasField(HttpFieldNumber) then
Some(HttpRule.parseFrom(fields.getField(HttpFieldNumber).getLengthDelimitedList.get(0).toByteArray))
else
None
case _ =>
None
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.ivovk.connect_rpc_scala

import cats.effect.IO
import com.google.api.HttpRule
import com.google.api.http.HttpRule
import org.http4s.Uri.Path.Root
import org.http4s.implicits.uri
import org.http4s.{Method, Request}
Expand All @@ -11,24 +11,24 @@ import org.scalatest.funsuite.AnyFunSuiteLike

class TranscodingUrlMatcherTest extends AnyFunSuiteLike {

val matcher = TranscodingUrlMatcher.create[IO](
private val matcher = TranscodingUrlMatcher.create[IO](
Seq(
MethodRegistry.Entry(
MethodName("CountriesService", "CreateCountry"),
null,
Some(HttpRule.newBuilder().setPost("/countries").build()),
Some(HttpRule().withPost("/countries")),
null
),
MethodRegistry.Entry(
MethodName("CountriesService", "ListCountries"),
null,
Some(HttpRule.newBuilder().setGet("/countries/list").build()),
Some(HttpRule().withGet("/countries/list")),
null
),
MethodRegistry.Entry(
MethodName("CountriesService", "GetCountry"),
null,
Some(HttpRule.newBuilder().setGet("/countries/{country_id}").build()),
Some(HttpRule().withGet("/countries/{country_id}")),
null
),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MethodRegistryTest extends AnyFunSuite {
val httpRule = entry.get.httpRule.get

assert(httpRule.getPost == "/v1/test/http_annotation_method")
assert(httpRule.getBody == "*")
assert(httpRule.body == "*")
}

}

0 comments on commit ab4c608

Please sign in to comment.