Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Improve enum support #3861

Merged
merged 15 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ jsoniter "com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala
Currently, string-like enums in Scala 2 depend upon the enumeratum library (`"com.beachape" %% "enumeratum"`).
For Scala 3 we derive native enums, and depend on `"io.github.bishabosha" %% "enum-extensions"` for generating query
param serdes.
Other forms of OpenApi enum are not currently supported.

Models containing binary data cannot be re-used between json and multi-part form endpoints, due to having different
representation types for the binary data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ object BasicGenerator {
JsonSerdeLib.Circe
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
Expand All @@ -59,7 +60,8 @@ object BasicGenerator {
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs,
maxSchemasPerFile = maxSchemasPerFile
maxSchemasPerFile = maxSchemasPerFile,
enumsDefinedOnEndpointParams = enumsDefinedOnEndpointParams
)
.getOrElse(GeneratedClassDefinitions("", None, Nil))
val hasJsonSerdes = jsonSerdes.nonEmpty
Expand Down Expand Up @@ -140,13 +142,50 @@ object BasicGenerator {
.mkString("\n")

val extraImports = if (endpointsInMain.nonEmpty) s"$maybeJsonImport$maybeSchemaImport" else ""
val queryParamSupport =
"""
|case class CommaSeparatedValues[T](values: List[T])
|case class ExplodedValues[T](values: List[T])
|trait QueryParamSupport[T] {
| def decode(s: String): sttp.tapir.DecodeResult[T]
| def encode(t: T): String
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
|}
|implicit def makeUnexplodedQuerySeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], CommaSeparatedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(s => CommaSeparatedValues(s.toList)))(_.values.map(support.encode).mkString(","))
|}
|implicit def makeUnexplodedQueryOptSeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], Option[CommaSeparatedValues[T]], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode{
| case None => DecodeResult.Value(None)
| case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
| }(_.map(_.values.map(support.encode).mkString(",")))
|}
|implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
|}
|implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
|}
|""".stripMargin
val mainObj = s"""
|package $packagePath
|
|object $objName {
|
|${indent(2)(imports(normalisedJsonLib) + extraImports)}
|
|${indent(2)(queryParamSupport)}
|
|${indent(2)(classDefns)}
|
|${indent(2)(maybeSpecificationExtensionKeys)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class ClassDefinitionGenerator {
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
validateNonDiscriminatedOneOfs: Boolean = true,
maxSchemasPerFile: Int = 400
maxSchemasPerFile: Int = 400,
enumsDefinedOnEndpointParams: Boolean = false
): Option[GeneratedClassDefinitions] = {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
val adtInheritanceMap: Map[String, Seq[String]] = mkMapParentsByChild(allOneOfSchemas)
val generatesQueryParamEnums =
val generatesQueryParamEnums = enumsDefinedOnEndpointParams ||
allSchemas
.collect { case (name, _: OpenapiSchemaEnum) => name }
.exists(queryParamRefs.contains)
Expand All @@ -49,14 +50,15 @@ class ClassDefinitionGenerator {
allTransitiveJsonParamRefs,
fullModelPath,
validateNonDiscriminatedOneOfs,
adtInheritanceMap
adtInheritanceMap,
targetScala3
)
val defns = doc.components
.map(_.schemas.flatMap {
case (name, obj: OpenapiSchemaObject) =>
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap)
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap, jsonSerdeLib, targetScala3)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
EnumGenerator.generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (_, _: OpenapiSchemaOneOf) => Nil
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
Expand Down Expand Up @@ -95,50 +97,55 @@ class ClassDefinitionGenerator {
.groupBy(_._1)
.mapValues(_.map(_._2))

private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|
|def makeQueryCodecForEnum[T: enumextensions.EnumMirror]: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec
| .listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util
| .Try(enumMap[T](using enumextensions.EnumMirror[T])(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.name)
|""".stripMargin
else
"""def makeQueryCodecForEnum[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.entryName)
|""".stripMargin
private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = {
if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
| .Try(eMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|""".stripMargin
else
"""
|case class EnumQueryParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends QueryParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| def encode(t: T): String = t.entryName
|}
|def queryCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): QueryParamSupport[T] =
| EnumQueryParamSupport(enumName, T)
|""".stripMargin
}

@tailrec
final def recursiveFindAllReferencedSchemaTypes(
Expand Down Expand Up @@ -191,63 +198,14 @@ class ClassDefinitionGenerator {
Seq(s"""type $name = Map[String, $valueSchemaName]""")
}

// Uses enumeratum for scala 2, but generates scala 3 enums instead where it can
private[codegen] def generateEnum(
name: String,
obj: OpenapiSchemaEnum,
targetScala3: Boolean,
queryParamRefs: Set[String],
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
jsonParamRefs: Set[String]
): Seq[String] = if (targetScala3) {
val maybeCompanion =
if (queryParamRefs contains name)
s"""
|object $name {
| given stringList${name}Codec: sttp.tapir.Codec[List[String], $name, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum[$name]
|}""".stripMargin
else ""
val maybeCodecExtensions = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
case _ if !jsonParamRefs.contains(name) => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Circe if !queryParamRefs.contains(name) => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec"
case JsonSerdeLib.Circe => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Jsoniter => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
}
s"""$maybeCompanion
|enum $name$maybeCodecExtensions {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val uncapitalisedName = BasicGenerator.uncapitalise(name)
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
val maybeCodecExtension = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
case JsonSerdeLib.Circe => s" with enumeratum.CirceEnum[$name]"
case JsonSerdeLib.Jsoniter => ""
}
val maybeQueryCodecDefn =
if (queryParamRefs contains name)
s"""
| implicit val ${uncapitalisedName}QueryCodec: sttp.tapir.Codec[List[String], ${name}, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum("${name}", ${name})""".stripMargin
else ""
s"""
|sealed trait $name extends enumeratum.EnumEntry
|object $name extends enumeratum.Enum[$name]$maybeCodecExtension {
| val values = findValues
|${indent(2)(members.mkString("\n"))}$maybeQueryCodecDefn
|}""".stripMargin :: Nil
}

private[codegen] def generateClass(
allSchemas: Map[String, OpenapiSchemaType],
name: String,
obj: OpenapiSchemaObject,
jsonParamRefs: Set[String],
adtInheritanceMap: Map[String, Seq[String]]
adtInheritanceMap: Map[String, Seq[String]],
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
targetScala3: Boolean
): Seq[String] = {
val isJson = jsonParamRefs contains name
def rec(name: String, obj: OpenapiSchemaObject, acc: List[String]): Seq[String] = {
Expand All @@ -268,24 +226,25 @@ class ClassDefinitionGenerator {
.flatten
.toList

val properties = obj.properties.map { case (key, OpenapiSchemaField(schemaType, maybeDefault)) =>
val tpe = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson)
val (properties, maybeEnums) = obj.properties.map { case (key, OpenapiSchemaField(schemaType, maybeDefault)) =>
val (tpe, maybeEnum) = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson, jsonSerdeLib, targetScala3)
val fixedKey = fixKey(key)
val optional = schemaType.nullable || !obj.required.contains(key)
val maybeExplicitDefault =
maybeDefault.map(" = " + DefaultValueRenderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val default = maybeExplicitDefault getOrElse (if (optional) " = None" else "")
s"$fixedKey: $tpe$default"
}
s"$fixedKey: $tpe$default" -> maybeEnum
}.unzip

val parents = adtInheritanceMap.getOrElse(name, Nil) match {
case Nil => ""
case ps => ps.mkString(" extends ", " with ", "")
}

val enumDefn = maybeEnums.flatten.toList
s"""|case class $name (
|${indent(2)(properties.mkString(",\n"))}
|)$parents""".stripMargin :: innerClasses ::: acc
|)$parents""".stripMargin :: innerClasses ::: enumDefn ::: acc
}

rec(addName("", name), obj, Nil)
Expand All @@ -296,28 +255,52 @@ class ClassDefinitionGenerator {
key: String,
required: Boolean,
schemaType: OpenapiSchemaType,
isJson: Boolean
): String = {
val (tpe, optional) = schemaType match {
isJson: Boolean,
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
targetScala3: Boolean
): (String, Option[String]) = {
val ((tpe, optional), maybeEnum) = schemaType match {
case simpleType: OpenapiSchemaSimpleType =>
mapSchemaSimpleTypeToType(simpleType, multipartForm = !isJson)
mapSchemaSimpleTypeToType(simpleType, multipartForm = !isJson) -> None

case objectType: OpenapiSchemaObject =>
addName(parentName, key) -> objectType.nullable
(addName(parentName, key) -> objectType.nullable, None)

case mapType: OpenapiSchemaMap =>
val innerType = mapSchemaTypeToType(addName(parentName, key), "item", required = true, mapType.items, isJson = isJson)
s"Map[String, $innerType]" -> mapType.nullable
val (innerType, maybeEnum) =
mapSchemaTypeToType(addName(parentName, key), "item", required = true, mapType.items, isJson = isJson, jsonSerdeLib, targetScala3)
(s"Map[String, $innerType]" -> mapType.nullable, maybeEnum)

case arrayType: OpenapiSchemaArray =>
val innerType = mapSchemaTypeToType(addName(parentName, key), "item", required = true, arrayType.items, isJson = isJson)
s"Seq[$innerType]" -> arrayType.nullable
val (innerType, maybeEnum) =
mapSchemaTypeToType(
addName(parentName, key),
"item",
required = true,
arrayType.items,
isJson = isJson,
jsonSerdeLib,
targetScala3
)
(s"Seq[$innerType]" -> arrayType.nullable, maybeEnum)

case e: OpenapiSchemaEnum =>
val enumName = addName(parentName.capitalize, key)
val enumDefn = EnumGenerator.generateEnum(
enumName,
e,
targetScala3,
Set.empty,
jsonSerdeLib,
if (isJson) Set(enumName) else Set.empty
)
(enumName -> e.nullable, Some(enumDefn.mkString("\n")))

case _ =>
throw new NotImplementedError(s"We can't serialize some of the properties yet! $parentName $key $schemaType")
}

if (optional || !required) s"Option[$tpe]" else tpe
(if (optional || !required) s"Option[$tpe]" else tpe, maybeEnum)
}

private def addName(parentName: String, key: String) = parentName + key.replace('_', ' ').replace('-', ' ').capitalize.replace(" ", "")
Expand Down
Loading
Loading