Skip to content

Commit

Permalink
codegen: add streaming support for application/octet-stream contents (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Aug 6, 2024
1 parent 64c13ec commit ef22dd9
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 77 deletions.
1 change: 1 addition & 0 deletions generated-doc/out/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ openapiJsonSerdeLib circe The j
openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated.
openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits).
openapiAdditionalPackages Nil Additional packageName/swaggerFile pairs for generating from multiple schemas
openapiStreamingImplementation fs2 Backend capability to assume for streaming content. Supports akka, fs2, pekko and zio.
===================================== ==================================== ==================================================================================================
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ object GenScala {
// CLI option selecting the JSON serde library (--jsonLib / -j); None when omitted,
// in which case the caller falls back to "circe" (see the .getOrElse("circe") downstream).
private val jsonLibOpt: Opts[Option[String]] =
Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone

// CLI option selecting the streaming capability used for binary (octet-stream) bodies
// (--streamingImplementation / -s); None when omitted, in which case the caller
// falls back to "fs2" (see the .getOrElse("fs2") downstream).
private val streamingImplementationOpt: Opts[Option[String]] =
Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone

private val destDirOpt: Opts[File] =
Opts
.option[String]("destdir", "Destination directory", "d")
Expand All @@ -84,7 +87,8 @@ object GenScala {
headTagForNamesOpt,
jsonLibOpt,
validateNonDiscriminatedOneOfsOpt,
maxSchemasPerFileOpt
maxSchemasPerFileOpt,
streamingImplementationOpt
)
.mapN {
case (
Expand All @@ -96,7 +100,8 @@ object GenScala {
headTagForNames,
jsonLib,
validateNonDiscriminatedOneOfs,
maxSchemasPerFile
maxSchemasPerFile,
streamingImplementation
) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

Expand All @@ -109,6 +114,7 @@ object GenScala {
targetScala3,
headTagForNames,
jsonLib.getOrElse("circe"),
streamingImplementation.getOrElse("fs2"),
validateNonDiscriminatedOneOfs,
maxSchemasPerFile.getOrElse(400)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ object JsonSerdeLib extends Enumeration {
val Circe, Jsoniter, Zio = Value
type JsonSerdeLib = Value
}
/** Streaming backends supported by the codegen for `application/octet-stream`
  * content: akka, fs2, pekko and zio. Follows the same `Enumeration` pattern as
  * the neighbouring `JsonSerdeLib` object for consistency (a sealed ADT / Scala 3
  * `enum` would be preferred in new code).
  */
object StreamingImplementation extends Enumeration {
  // Type alias so call sites can write `StreamingImplementation.StreamingImplementation`.
  type StreamingImplementation = Value
  val Akka, FS2, Pekko, Zio = Value
}

object BasicGenerator {

Expand All @@ -34,6 +38,7 @@ object BasicGenerator {
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean,
jsonSerdeLib: String,
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int
): Map[String, String] = {
Expand All @@ -47,9 +52,20 @@ object BasicGenerator {
)
JsonSerdeLib.Circe
}
val normalisedStreamingImplementation = streamingImplementation.toLowerCase match {
case "akka" => StreamingImplementation.Akka
case "fs2" => StreamingImplementation.FS2
case "pekko" => StreamingImplementation.Pekko
case "zio" => StreamingImplementation.Zio
case _ =>
System.err.println(
s"!!! Unrecognised value $streamingImplementation for streaming impl -- should be one of akka, fs2, pekko or zio. Defaulting to fs2 !!!"
)
StreamingImplementation.FS2
}

val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sttp.tapir.codegen
import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase}
import sttp.tapir.codegen.JsonSerdeLib.JsonSerdeLib
import sttp.tapir.codegen.StreamingImplementation
import sttp.tapir.codegen.StreamingImplementation.StreamingImplementation
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
Expand All @@ -10,7 +12,8 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaEnum,
OpenapiSchemaMap,
OpenapiSchemaRef,
OpenapiSchemaSimpleType
OpenapiSchemaSimpleType,
OpenapiSchemaString
}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType, SpecificationExtensionRenderer}
import sttp.tapir.codegen.util.JavaEscape
Expand Down Expand Up @@ -55,12 +58,13 @@ class EndpointGenerator {
doc: OpenapiDocument,
useHeadTagForObjectNames: Boolean,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
): EndpointDefs = {
val components = Option(doc.components).flatten
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) =
doc.paths
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib))
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation))
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _)
val endpointDecls = endpointsByFile.map { case GeneratedEndpointsForFile(k, ge) =>
val definitions = ge
Expand All @@ -84,7 +88,8 @@ class EndpointGenerator {
components: Option[OpenapiComponent],
useHeadTagForObjectNames: Boolean,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
)(p: OpenapiPath): GeneratedEndpoints = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)
Expand All @@ -106,14 +111,15 @@ class EndpointGenerator {
}

val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize))
val (inParams, maybeLocalEnums) = ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib)
val (inParams, maybeLocalEnums) =
ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib, streamingImplementation)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(inParams)}
|${indent(2)(outs(m.responses))}
|${indent(2)(outs(m.responses, streamingImplementation))}
|${indent(2)(tags(m.tags))}
|$attributeString
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")
Expand Down Expand Up @@ -211,7 +217,8 @@ class EndpointGenerator {
requestBody: Option[OpenapiRequestBody],
endpointName: String,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
)(implicit location: Location): (String, Option[String]) = {
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
Expand Down Expand Up @@ -267,7 +274,7 @@ class EndpointGenerator {
val rqBody = requestBody.flatMap { b =>
if (b.content.isEmpty) None
else if (b.content.size != 1) bail(s"We can handle only one requestBody content! Saw ${b.content.map(_.contentType)}")
else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, b.required)})")
else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, streamingImplementation, b.required)})")
}

(params ++ rqBody).mkString("\n") -> maybeEnumDefns.foldLeft(Option.empty[String]) {
Expand Down Expand Up @@ -298,7 +305,7 @@ class EndpointGenerator {
// treats redirects as ok
private val okStatus = """([23]\d\d)""".r
private val errorStatus = """([45]\d\d)""".r
private def outs(responses: Seq[OpenapiResponse])(implicit location: Location) = {
private def outs(responses: Seq[OpenapiResponse], streamingImplementation: StreamingImplementation)(implicit location: Location) = {
// .errorOut(stringBody)
// .out(jsonBody[List[Book]])

Expand All @@ -315,13 +322,13 @@ class EndpointGenerator {
case content +: Nil =>
resp.code match {
case "200" =>
s".out(${contentTypeMapper(content.contentType, content.schema)}$d)"
s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)"
case okStatus(s) =>
s".out(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))"
s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))"
case "default" =>
s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d)"
s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)"
case errorStatus(s) =>
s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))"
s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))"
case x =>
bail(s"Statuscode mapping is incomplete! Cannot handle $x")
}
Expand All @@ -333,7 +340,12 @@ class EndpointGenerator {
.mkString("\n")
}

private def contentTypeMapper(contentType: String, schema: OpenapiSchemaType, required: Boolean = true)(implicit location: Location) = {
private def contentTypeMapper(
contentType: String,
schema: OpenapiSchemaType,
streamingImplementation: StreamingImplementation,
required: Boolean = true
)(implicit location: Location) = {
contentType match {
case "text/plain" =>
"stringBody"
Expand Down Expand Up @@ -362,6 +374,31 @@ class EndpointGenerator {
s"multipartBody[$t]"
case x => bail(s"$contentType only supports schema ref or binary. Found $x")
}
case "application/octet-stream" =>
val capability = streamingImplementation match {
case StreamingImplementation.Akka => "sttp.capabilities.akka.AkkaStreams"
case StreamingImplementation.FS2 => "sttp.capabilities.fs2.Fs2Streams[cats.effect.IO]"
case StreamingImplementation.Pekko => "sttp.capabilities.pekko.PekkoStreams"
case StreamingImplementation.Zio => "sttp.capabilities.zio.ZioStreams"
}
schema match {
case _: OpenapiSchemaString =>
s"streamTextBody($capability)(CodecFormat.OctetStream())"
case schema =>
val outT = schema match {
case st: OpenapiSchemaSimpleType =>
val (t, _) = mapSchemaSimpleTypeToType(st)
t
case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"List[$t]"
case OpenapiSchemaMap(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"Map[String, $t]"
case x => bail(s"Can't create this param as output (found $x)")
}
s"streamBody($capability)(Schema.binary[$outT], CodecFormat.OctetStream())"
}

case x => bail(s"Not all content types supported! Found $x")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = useHeadTagForObjectNames,
jsonSerdeLib = jsonSerdeLib,
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)
}
def gen(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,13 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
case Left(value) => throw new Exception(value)
case Right(doc) =>
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
generatedCode should include("val getTestAsdId =")
generatedCode should include(""".in(query[Option[String]]("fgh-id"))""")
Expand Down Expand Up @@ -142,7 +148,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None) shouldCompile ()
}

Expand Down Expand Up @@ -188,7 +200,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
Expand Down Expand Up @@ -253,7 +271,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)("TapirGeneratedEndpoints")
generatedCode should include(
"""file: sttp.model.Part[java.io.File]"""
Expand All @@ -274,7 +293,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
val expectedAttrDecls = Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@ package sttp.tapir.sbt

import sbt._

/** Aggregation of all sbt plugin settings for one codegen run (see
  * `openapiOpenApiConfiguration`); field meanings mirror the individual
  * `openapi*` setting keys declared in [[OpenapiCodegenKeys]].
  *
  * @param swaggerFile the OpenAPI (swagger) file with the api definitions
  * @param packageName name of the package for the generated sources
  * @param objectName name of the generated top-level object
  * @param useHeadTagForObjectName whether to group endpoints into objects by their head tag
  * @param jsonSerdeLib json library used for generated serdes (e.g. "circe")
  * @param streamingImplementation capability assumed for streaming bodies: akka, fs2, pekko or zio
  * @param validateNonDiscriminatedOneOfs whether to fail if variants of a oneOf without a
  *                                       discriminator cannot be disambiguated
  * @param maxSchemasPerFile maximum number of schemas to generate in a single file
  * @param additionalPackages additional packageName -> swaggerFile pairs for generating
  *                           from multiple schemas
  */
case class OpenApiConfiguration(
swaggerFile: File,
packageName: String,
objectName: String,
useHeadTagForObjectName: Boolean,
jsonSerdeLib: String,
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int,
additionalPackages: List[(String, File)]
)

trait OpenapiCodegenKeys {
lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.")
lazy val openapiPackage = settingKey[String]("The name for the generated package.")
Expand All @@ -13,7 +25,10 @@ trait OpenapiCodegenKeys {
lazy val openapiValidateNonDiscriminatedOneOfs =
settingKey[Boolean]("Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated..")
lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file")
lazy val openapiAdditionalPackages = taskKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.")
lazy val openapiOpenApiConfiguration =
settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.")

lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.")
}
Expand Down
Loading

0 comments on commit ef22dd9

Please sign in to comment.