-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial work on the GRPC Transcoding (#54)
- Loading branch information
Showing
12 changed files
with
474 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingHandler.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package org.ivovk.connect_rpc_scala | ||
|
||
import cats.effect.Async | ||
import cats.implicits.* | ||
import io.grpc.* | ||
import org.http4s.dsl.Http4sDsl | ||
import org.http4s.{Header, Headers, MessageFailure, Response} | ||
import org.ivovk.connect_rpc_scala.Mappings.* | ||
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry} | ||
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name` | ||
import org.ivovk.connect_rpc_scala.http.RequestEntity | ||
import org.ivovk.connect_rpc_scala.http.RequestEntity.* | ||
import org.ivovk.connect_rpc_scala.http.codec.{EncodeOptions, MessageCodec} | ||
import org.slf4j.{Logger, LoggerFactory} | ||
import scalapb.GeneratedMessage | ||
|
||
import scala.concurrent.duration.* | ||
import scala.jdk.CollectionConverters.* | ||
import scala.util.chaining.* | ||
|
||
object TranscodingHandler { | ||
|
||
extension [F[_]](response: Response[F]) { | ||
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] = | ||
codec.encode(entity, options).applyTo(response) | ||
} | ||
|
||
} | ||
|
||
class TranscodingHandler[F[_] : Async]( | ||
channel: Channel, | ||
httpDsl: Http4sDsl[F], | ||
treatTrailersAsHeaders: Boolean, | ||
) { | ||
|
||
import TranscodingHandler.* | ||
import httpDsl.* | ||
|
||
private val logger: Logger = LoggerFactory.getLogger(getClass) | ||
|
||
def handleUnary( | ||
message: GeneratedMessage, | ||
headers: Headers, | ||
method: MethodRegistry.Entry, | ||
)(using MessageCodec[F], EncodeOptions): F[Response[F]] = { | ||
if (logger.isTraceEnabled) { | ||
// Used in conformance tests | ||
headers.get[`X-Test-Case-Name`] match { | ||
case Some(header) => | ||
logger.trace(s">>> Test Case: ${header.value}") | ||
case None => // ignore | ||
} | ||
} | ||
|
||
if (logger.isTraceEnabled) { | ||
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}") | ||
} | ||
|
||
val callOptions = CallOptions.DEFAULT | ||
.pipe( | ||
headers.timeout match { | ||
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS) | ||
case None => identity | ||
} | ||
) | ||
|
||
ClientCalls | ||
.asyncUnaryCall( | ||
channel, | ||
method.descriptor, | ||
callOptions, | ||
headers.toMetadata, | ||
message | ||
) | ||
.map { response => | ||
val headers = response.headers.toHeaders() ++ | ||
response.trailers.toHeaders(trailing = !treatTrailersAsHeaders) | ||
|
||
if (logger.isTraceEnabled) { | ||
logger.trace(s"<<< Headers: ${headers.redactSensitive()}") | ||
} | ||
|
||
Response(Ok, headers = headers).withMessage(response.value) | ||
} | ||
.handleError { e => | ||
val grpcStatus = e match { | ||
case e: StatusException => | ||
e.getStatus.getDescription match { | ||
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED | ||
case _ => e.getStatus | ||
} | ||
case e: StatusRuntimeException => e.getStatus | ||
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT | ||
case _ => io.grpc.Status.INTERNAL | ||
} | ||
|
||
val (message, metadata) = e match { | ||
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers) | ||
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers) | ||
case e => (Option(e.getMessage), new Metadata()) | ||
} | ||
|
||
val httpStatus = grpcStatus.toHttpStatus | ||
val connectCode = grpcStatus.toConnectCode | ||
|
||
// Should be called before converting metadata to headers | ||
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey)) | ||
.fold(Seq.empty)(_.asScala.toSeq) | ||
|
||
val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders) | ||
|
||
if (logger.isTraceEnabled) { | ||
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode") | ||
logger.trace(s"<<< Headers: ${headers.redactSensitive()}") | ||
logger.trace(s"<<< Error processing request", e) | ||
} | ||
|
||
Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error( | ||
code = connectCode, | ||
message = message, | ||
details = details | ||
)) | ||
} | ||
} | ||
|
||
} |
117 changes: 117 additions & 0 deletions
117
core/src/main/scala/org/ivovk/connect_rpc_scala/TranscodingUrlMatcher.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package org.ivovk.connect_rpc_scala | ||
|
||
import cats.implicits.* | ||
import com.google.api.HttpRule | ||
import org.http4s.{Method, Request, Uri} | ||
import org.ivovk.connect_rpc_scala | ||
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry | ||
import org.json4s.JsonAST.{JField, JObject} | ||
import org.json4s.{JString, JValue} | ||
|
||
import scala.util.boundary | ||
import scala.util.boundary.break | ||
|
||
case class MatchedRequest(method: MethodRegistry.Entry, json: JValue) | ||
|
||
object TranscodingUrlMatcher { | ||
case class Entry( | ||
method: MethodRegistry.Entry, | ||
httpMethodMatcher: Method => Boolean, | ||
pattern: Uri.Path, | ||
) | ||
|
||
def create[F[_]]( | ||
methods: Seq[MethodRegistry.Entry], | ||
pathPrefix: Uri.Path, | ||
): TranscodingUrlMatcher[F] = { | ||
val entries = methods.flatMap { method => | ||
method.httpRule match { | ||
case Some(httpRule) => | ||
val (httpMethod, pattern) = extractMethodAndPattern(httpRule) | ||
|
||
val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m) | ||
|
||
Entry( | ||
method, | ||
httpMethodMatcher, | ||
pathPrefix.dropEndsWithSlash.concat(pattern.toRelative) | ||
).some | ||
case None => none | ||
} | ||
} | ||
|
||
new TranscodingUrlMatcher( | ||
entries, | ||
) | ||
} | ||
|
||
private def extractMethodAndPattern(rule: HttpRule): (Option[Method], Uri.Path) = { | ||
val (method, str) = rule.getPatternCase match | ||
case HttpRule.PatternCase.GET => (Method.GET.some, rule.getGet) | ||
case HttpRule.PatternCase.PUT => (Method.PUT.some, rule.getPut) | ||
case HttpRule.PatternCase.POST => (Method.POST.some, rule.getPost) | ||
case HttpRule.PatternCase.DELETE => (Method.DELETE.some, rule.getDelete) | ||
case HttpRule.PatternCase.PATCH => (Method.PATCH.some, rule.getPatch) | ||
case HttpRule.PatternCase.CUSTOM => (none, rule.getCustom.getPath) | ||
case other => throw new RuntimeException(s"Unsupported pattern case $other (Rule: $rule)") | ||
|
||
val path = Uri.Path.unsafeFromString(str).dropEndsWithSlash | ||
|
||
(method, path) | ||
} | ||
} | ||
|
||
class TranscodingUrlMatcher[F[_]]( | ||
entries: Seq[TranscodingUrlMatcher.Entry], | ||
) { | ||
|
||
import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.* | ||
|
||
def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary { | ||
entries.foreach { entry => | ||
if (entry.httpMethodMatcher(req.method)) { | ||
matchExtract(entry.pattern, req.uri.path) match { | ||
case Some(pathParams) => | ||
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse(""))) | ||
|
||
val merged = mergeFields(groupFields(pathParams), groupFields(queryParams)) | ||
|
||
break(Some(MatchedRequest(entry.method, JObject(merged)))) | ||
case None => // continue | ||
} | ||
} | ||
} | ||
|
||
none | ||
} | ||
|
||
/** | ||
* Matches path segments with pattern segments and extracts variables from the path. | ||
* Returns None if the path does not match the pattern. | ||
*/ | ||
private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary { | ||
if path.segments.length != pattern.segments.length then boundary.break(none) | ||
|
||
path.segments.indices | ||
.foldLeft(List.empty[JField]) { (state, idx) => | ||
val pathSegment = path.segments(idx) | ||
val patternSegment = pattern.segments(idx) | ||
|
||
if isVariable(patternSegment) then | ||
val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1) | ||
|
||
(varName -> JString(pathSegment.encoded)) :: state | ||
else if pathSegment != patternSegment then | ||
boundary.break(none) | ||
else state | ||
} | ||
.some | ||
} | ||
|
||
private def isVariable(segment: Uri.Path.Segment): Boolean = { | ||
val enc = segment.encoded | ||
val length = enc.length | ||
|
||
length > 2 && enc(0) == '{' && enc(length - 1) == '}' | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.