From 20a526c048c8a24a99b3c60465688c4b92e44078 Mon Sep 17 00:00:00 2001 From: hughsimpson Date: Fri, 19 Apr 2024 12:05:23 +0100 Subject: [PATCH] codegen: Semiauto schema derivation (#3671) --- doc/generator/sbt-openapi-codegen.md | 7 +- .../scala/sttp/tapir/codegen/GenScala.scala | 11 +- .../sttp/tapir/codegen/BasicGenerator.scala | 66 +++-- .../codegen/ClassDefinitionGenerator.scala | 14 +- .../tapir/codegen/JsonSerdeGenerator.scala | 44 +++- .../sttp/tapir/codegen/SchemaGenerator.scala | 247 ++++++++++++++++++ .../tapir/codegen/BasicGeneratorSpec.scala | 28 +- .../ClassDefinitionGeneratorSpec.scala | 38 ++- .../tapir/codegen/EndpointGeneratorSpec.scala | 6 +- .../sttp/tapir/sbt/OpenapiCodegenKeys.scala | 1 + .../sttp/tapir/sbt/OpenapiCodegenPlugin.scala | 6 +- .../sttp/tapir/sbt/OpenapiCodegenTask.scala | 4 +- .../oneOf-json-roundtrip/Expected.scala.txt | 3 +- .../ExpectedJsonSerdes.scala.txt | 53 ++++ .../ExpectedSchemas.scala.txt | 40 +++ .../oneOf-json-roundtrip/build.sbt | 27 +- .../Expected.scala.txt | 1 + .../Expected.scala.txt | 1 + 18 files changed, 518 insertions(+), 79 deletions(-) create mode 100644 openapi-codegen/core/src/main/scala/sttp/tapir/codegen/SchemaGenerator.scala create mode 100644 openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt create mode 100644 openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt diff --git a/doc/generator/sbt-openapi-codegen.md b/doc/generator/sbt-openapi-codegen.md index 1cf48acd8b..50e299f4e8 100644 --- a/doc/generator/sbt-openapi-codegen.md +++ b/doc/generator/sbt-openapi-codegen.md @@ -35,16 +35,17 @@ defined case-classes and endpoint definitions. 
The generator currently supports these settings, you can override them in the `build.sbt`; ```eval_rst -===================================== ==================================== ======================================================================================= +===================================== ==================================== ================================================================================================== setting default value description -===================================== ==================================== ======================================================================================= +===================================== ==================================== ================================================================================================== openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions. openapiPackage sttp.tapir.generated The name for the generated package. openapiObject TapirGeneratedEndpoints The name for the generated object. openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag. openapiJsonSerdeLib circe The json serde library to use. openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated. -===================================== ==================================== ======================================================================================= +openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits). 
+===================================== ==================================== ================================================================================================== ``` The general usage is; diff --git a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala index 3ad36d321e..82e34a35d5 100644 --- a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala +++ b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala @@ -54,6 +54,10 @@ object GenScala { "v" ) .orFalse + private val maxSchemasPerFileOpt: Opts[Option[Int]] = + Opts + .option[Int]("maxSchemasPerFile", "Maximum number of schemas to generate in a single file.", "m") + .orNone private val jsonLibOpt: Opts[Option[String]] = Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone @@ -71,8 +75,8 @@ object GenScala { } val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) { - (fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt) - .mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs) => + (fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt, maxSchemasPerFileOpt) + .mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs, maxSchemasPerFile) => val objectName = maybeObjectName.getOrElse(DefaultObjectName) def generateCode(doc: OpenapiDocument): IO[Unit] = for { @@ -84,7 +88,8 @@ object GenScala { targetScala3, headTagForNames, jsonLib.getOrElse("circe"), - validateNonDiscriminatedOneOfs + validateNonDiscriminatedOneOfs, + maxSchemasPerFile.getOrElse(400) ) ) destFiles <- contents.toVector.traverse { case (fileName, 
content) => writeGeneratedFile(destDir, fileName, content) } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index fd8c3f7eb1..7d1b8fce5a 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -34,7 +34,8 @@ object BasicGenerator { targetScala3: Boolean, useHeadTagForObjectNames: Boolean, jsonSerdeLib: String, - validateNonDiscriminatedOneOfs: Boolean + validateNonDiscriminatedOneOfs: Boolean, + maxSchemasPerFile: Int ): Map[String, String] = { val normalisedJsonLib = jsonSerdeLib.toLowerCase match { case "circe" => JsonSerdeLib.Circe @@ -47,7 +48,7 @@ object BasicGenerator { } val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames) - val GeneratedClassDefinitions(classDefns, extras) = + val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) = classGenerator .classDefs( doc = doc, @@ -56,15 +57,19 @@ object BasicGenerator { jsonSerdeLib = normalisedJsonLib, jsonParamRefs = jsonParamRefs, fullModelPath = s"$packagePath.$objName", - validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs + validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs, + maxSchemasPerFile = maxSchemasPerFile ) - .getOrElse(GeneratedClassDefinitions("", None)) - val isSplit = extras.nonEmpty - val internalImports = - if (isSplit) - s"""import $packagePath.$objName._ - |import ${objName}JsonSerdes._""".stripMargin - else s"import $objName._" + .getOrElse(GeneratedClassDefinitions("", None, Nil)) + val hasJsonSerdes = jsonSerdes.nonEmpty + + val maybeJsonImport = if (hasJsonSerdes) s"\nimport $packagePath.${objName}JsonSerdes._" else "" + val maybeSchemaImport = + if (schemas.size > 1) (1 to schemas.size).map(i => s"import ${objName}Schemas$i._").mkString("\n", 
"\n", "") + else if (schemas.size == 1) s"\nimport ${objName}Schemas._" + else "" + val internalImports = s"import $packagePath.$objName._$maybeJsonImport$maybeSchemaImport" + val taggedObjs = endpointsByTag.collect { case (Some(headTag), body) if body.nonEmpty => val taggedObj = @@ -81,14 +86,39 @@ object BasicGenerator { |}""".stripMargin headTag -> taggedObj } - val extraObj = extras.map { body => + + val jsonSerdeObj = jsonSerdes.map { body => s"""package $packagePath | |object ${objName}JsonSerdes { | import $packagePath.$objName._ + | import sttp.tapir.generic.auto._ |${indent(2)(body)} |}""".stripMargin } + + val schemaObjs = if (schemas.size > 1) schemas.zipWithIndex.map { case (body, idx) => + val priorImports = (0 until idx).map { i => s"import $packagePath.${objName}Schemas${i + 1}._" }.mkString("\n") + val name = s"${objName}Schemas${idx + 1}" + name -> s"""package $packagePath + | + |object $name { + | import $packagePath.$objName._ + | import sttp.tapir.generic.auto._ + |${indent(2)(priorImports)} + |${indent(2)(body)} + |}""".stripMargin + } + else if (schemas.size == 1) + Seq(s"${objName}Schemas" -> s"""package $packagePath + | + |object ${objName}Schemas { + | import $packagePath.$objName._ + | import sttp.tapir.generic.auto._ + |${indent(2)(schemas.head)} + |}""".stripMargin) + else Nil + val endpointsInMain = endpointsByTag.getOrElse(None, "") val maybeSpecificationExtensionKeys = doc.paths @@ -100,21 +130,21 @@ object BasicGenerator { val values = pairs.map(_._2) val `type` = SpecificationExtensionRenderer.renderCombinedType(values) val name = strippedToCamelCase(keyName) - val uncapitalisedName = name.head.toLower + name.tail - val capitalisedName = name.head.toUpper + name.tail + val uncapitalisedName = uncapitalise(name) + val capitalisedName = uncapitalisedName.capitalize s"""type ${capitalisedName}Extension = ${`type`} |val ${uncapitalisedName}ExtensionKey = new 
sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension") |""".stripMargin } .mkString("\n") - val serdeImport = if (isSplit && endpointsInMain.nonEmpty) s"\nimport $packagePath.${objName}JsonSerdes._" else "" - val mainObj = s"""| + val extraImports = if (endpointsInMain.nonEmpty) s"$maybeJsonImport$maybeSchemaImport" else "" + val mainObj = s""" |package $packagePath | |object $objName { | - |${indent(2)(imports(normalisedJsonLib) + serdeImport)} + |${indent(2)(imports(normalisedJsonLib) + extraImports)} | |${indent(2)(classDefns)} | @@ -124,7 +154,7 @@ object BasicGenerator { | |} |""".stripMargin - taggedObjs ++ extraObj.map(s"${objName}JsonSerdes" -> _) + (objName -> mainObj) + taggedObjs ++ jsonSerdeObj.map(s"${objName}JsonSerdes" -> _) ++ schemaObjs + (objName -> mainObj) } private[codegen] def imports(jsonSerdeLib: JsonSerdeLib.JsonSerdeLib): String = { @@ -184,4 +214,6 @@ object BasicGenerator { .zipWithIndex .map { case (part, 0) => part; case (part, _) => part.capitalize } .mkString + + def uncapitalise(name: String): String = name.head.toLower +: name.tail } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/ClassDefinitionGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/ClassDefinitionGenerator.scala index 58babaab19..2fd3dc2395 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/ClassDefinitionGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/ClassDefinitionGenerator.scala @@ -7,7 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._ import scala.annotation.tailrec -case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String]) +case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String], schemaRepr: Seq[String]) class ClassDefinitionGenerator { @@ -18,7 +18,8 @@ class ClassDefinitionGenerator { jsonSerdeLib: JsonSerdeLib.JsonSerdeLib = JsonSerdeLib.Circe, 
jsonParamRefs: Set[String] = Set.empty, fullModelPath: String = "", - validateNonDiscriminatedOneOfs: Boolean = true + validateNonDiscriminatedOneOfs: Boolean = true, + maxSchemasPerFile: Int = 400 ): Option[GeneratedClassDefinitions] = { val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq @@ -40,7 +41,8 @@ class ClassDefinitionGenerator { val adtTypes = adtInheritanceMap.flatMap(_._2).toSeq.distinct.map(name => s"sealed trait $name").mkString("", "\n", "\n") val enumQuerySerdeHelper = if (!generatesQueryParamEnums) "" else enumQuerySerdeHelperDefn(targetScala3) - val postDefns = JsonSerdeGenerator.serdeDefs( + val schemas = SchemaGenerator.generateSchemas(doc, allSchemas, fullModelPath, jsonSerdeLib, maxSchemasPerFile) + val jsonSerdes = JsonSerdeGenerator.serdeDefs( doc, jsonSerdeLib, jsonParamRefs, @@ -63,8 +65,8 @@ class ClassDefinitionGenerator { val helpers = (enumQuerySerdeHelper + adtTypes).linesIterator .filterNot(_.forall(_.isWhitespace)) .mkString("\n") - // Json serdes live in a separate file from the class defns - defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, postDefns)) + // Json serdes & schemas live in separate files from the class defns + defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, jsonSerdes, schemas)) } private def mkMapParentsByChild(allOneOfSchemas: Seq[(String, OpenapiSchemaOneOf)]): Map[String, Seq[String]] = @@ -219,7 +221,7 @@ class ClassDefinitionGenerator { | case ${obj.items.map(_.value).mkString(", ")} |}""".stripMargin :: Nil } else { - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName = BasicGenerator.uncapitalise(name) val members = obj.items.map { i => s"case object ${i.value} extends $name" } val maybeCodecExtension = jsonSerdeLib match { case _ if !jsonParamRefs.contains(name) && 
!queryParamRefs.contains(name) => "" diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/JsonSerdeGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/JsonSerdeGenerator.scala index 08510df925..e1add60704 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/JsonSerdeGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/JsonSerdeGenerator.scala @@ -7,6 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaArray, OpenapiSchemaBoolean, OpenapiSchemaEnum, + OpenapiSchemaField, OpenapiSchemaMap, OpenapiSchemaNumericType, OpenapiSchemaObject, @@ -86,7 +87,7 @@ object JsonSerdeGenerator { // if lhs has some required non-nullable fields with no default that rhs will never contain, then right cannot be mistaken for left if ((requiredL.keySet -- anyR.keySet).nonEmpty) false else { - // otherwise, if any required field on rhs can't look like the similarly-named field on lhs, then r can't look like l + // otherwise, if any field on rhs required by lhs can't look like the similarly-named field on lhs, then r can't look like l val rForRequiredL = anyR.filter(requiredL.keySet contains _._1) requiredL.forall { case (k, lhsV) => rCanLookLikeL(lhsV.`type`, rForRequiredL(k).`type`) } } @@ -118,8 +119,10 @@ object JsonSerdeGenerator { // Enum serdes are generated at the declaration site case (_, _: OpenapiSchemaEnum) => None // We generate the serde if it's referenced in any json model - case (name, _: OpenapiSchemaObject | _: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) => - Some(genCirceNamedSerde(name)) + case (name, schema: OpenapiSchemaObject) if allTransitiveJsonParamRefs.contains(name) => + Some(genCirceObjectSerde(name, schema)) + case (name, schema: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) => + Some(genCirceMapSerde(name, schema)) case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) => 
Some(genCirceAdtSerde(allSchemas, schema, name, validateNonDiscriminatedOneOfs)) case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None @@ -128,11 +131,28 @@ object JsonSerdeGenerator { .map(_.mkString("\n")) } - private def genCirceNamedSerde(name: String): String = { - val uncapitalisedName = name.head.toLower +: name.tail - s"""implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name] + private def genCirceObjectSerde(name: String, schema: OpenapiSchemaObject): String = { + val subs = schema.properties.collect { + case (k, OpenapiSchemaField(`type`: OpenapiSchemaObject, _)) => genCirceObjectSerde(s"$name${k.capitalize}", `type`) + case (k, OpenapiSchemaField(OpenapiSchemaArray(`type`: OpenapiSchemaObject, _), _)) => + genCirceObjectSerde(s"$name${k.capitalize}Item", `type`) + case (k, OpenapiSchemaField(OpenapiSchemaMap(`type`: OpenapiSchemaObject, _), _)) => + genCirceObjectSerde(s"$name${k.capitalize}Item", `type`) + } match { + case Nil => "" + case s => s.mkString("", "\n", "\n") + } + val uncapitalisedName = BasicGenerator.uncapitalise(name) + s"""${subs}implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name] |implicit lazy val ${uncapitalisedName}JsonEncoder: io.circe.Encoder[$name] = io.circe.generic.semiauto.deriveEncoder[$name]""".stripMargin } + private def genCirceMapSerde(name: String, schema: OpenapiSchemaMap): String = { + val subs = schema.items match { + case `type`: OpenapiSchemaObject => Some(genCirceObjectSerde(s"${name}ObjectsItem", `type`)) + case _ => None + } + subs.fold("")("\n" + _) + } private def genCirceAdtSerde( allSchemas: Map[String, OpenapiSchemaType], @@ -140,7 +160,7 @@ object JsonSerdeGenerator { name: String, validateNonDiscriminatedOneOfs: Boolean ): String = { - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName 
= BasicGenerator.uncapitalise(name) schema match { case OpenapiSchemaOneOf(_, Some(discriminator)) => @@ -256,7 +276,7 @@ object JsonSerdeGenerator { } private def genJsoniterClassSerde(supertypes: Seq[OpenapiSchemaOneOf])(name: String): String = { - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName = BasicGenerator.uncapitalise(name) if (supertypes.exists(_.discriminator.isDefined)) throw new NotImplementedError( s"A class cannot be used both in a oneOf with discriminator and at the top level when using jsoniter serdes at $name" @@ -266,13 +286,13 @@ object JsonSerdeGenerator { } private def genJsoniterEnumSerde(name: String): String = { - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName = BasicGenerator.uncapitalise(name) s""" |implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[${name}] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniteEnumConfig.withDiscriminatorFieldName(scala.None))""".stripMargin } private def genJsoniterNamedSerde(name: String): String = { - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName = BasicGenerator.uncapitalise(name) s""" |implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[$name] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniterBaseConfig)""".stripMargin } @@ -285,7 +305,7 @@ object JsonSerdeGenerator { validateNonDiscriminatedOneOfs: Boolean ): String = { val fullPathPrefix = maybeFullModelPath.map(_ + ".").getOrElse("") - val uncapitalisedName = name.head.toLower +: name.tail + val uncapitalisedName = BasicGenerator.uncapitalise(name) schema match { case OpenapiSchemaOneOf(_, Some(discriminator)) => def subtypeNames = schema.types.map { @@ -321,7 +341,7 @@ object JsonSerdeGenerator { if (validateNonDiscriminatedOneOfs) checkForSoundness(allSchemas)(schema.types.map(_.asInstanceOf[OpenapiSchemaRef])) val childNameAndSerde = schemas.collect { case ref: OpenapiSchemaRef => val name = 
ref.stripped - name -> s"${name.head.toLower +: name.tail}JsonCodec" + name -> s"${BasicGenerator.uncapitalise(name)}JsonCodec" } val childSerdes = childNameAndSerde.map(_._2) val doDecode = childSerdes.mkString("List(\n ", ",\n ", ")\n") + diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/SchemaGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/SchemaGenerator.scala new file mode 100644 index 0000000000..53a3b44c3e --- /dev/null +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/SchemaGenerator.scala @@ -0,0 +1,247 @@ +package sttp.tapir.codegen + +import sttp.tapir.codegen.BasicGenerator.indent +import sttp.tapir.codegen.JsonSerdeLib.JsonSerdeLib +import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument +import sttp.tapir.codegen.openapi.models.OpenapiSchemaType +import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ + Discriminator, + OpenapiSchemaAllOf, + OpenapiSchemaAny, + OpenapiSchemaAnyOf, + OpenapiSchemaArray, + OpenapiSchemaConstantString, + OpenapiSchemaEnum, + OpenapiSchemaField, + OpenapiSchemaMap, + OpenapiSchemaNot, + OpenapiSchemaObject, + OpenapiSchemaOneOf, + OpenapiSchemaRef, + OpenapiSchemaSimpleType +} + +import scala.collection.mutable + +object SchemaGenerator { + + def generateSchemas( + doc: OpenapiDocument, + allSchemas: Map[String, OpenapiSchemaType], + fullModelPath: String, + jsonSerdeLib: JsonSerdeLib, + maxSchemasPerFile: Int + ): Seq[String] = { + def schemaContainsAny(schema: OpenapiSchemaType): Boolean = schema match { + case _: OpenapiSchemaAny => true + case OpenapiSchemaArray(items, _) => schemaContainsAny(items) + case OpenapiSchemaMap(items, _) => schemaContainsAny(items) + case OpenapiSchemaObject(fs, _, _) => fs.values.map(_.`type`).exists(schemaContainsAny) + case OpenapiSchemaOneOf(types, _) => types.exists(schemaContainsAny) + case OpenapiSchemaAllOf(types) => types.exists(schemaContainsAny) + case OpenapiSchemaAnyOf(types) => 
types.exists(schemaContainsAny) + case OpenapiSchemaNot(item) => schemaContainsAny(item) + case _: OpenapiSchemaSimpleType | _: OpenapiSchemaEnum | _: OpenapiSchemaConstantString | _: OpenapiSchemaRef => false + } + val schemasWithAny = allSchemas.filter { case (_, schema) => + schemaContainsAny(schema) + } + val maybeAnySchema: Option[(OpenapiSchemaType, String)] = + if (schemasWithAny.isEmpty) None + else if (jsonSerdeLib == JsonSerdeLib.Circe) + Some( + OpenapiSchemaAny( + false + ) -> "implicit lazy val anyTapirSchema: sttp.tapir.Schema[io.circe.Json] = sttp.tapir.Schema.any[io.circe.Json]" + ) + else throw new NotImplementedError("any not implemented for json libs other than circe") + val openApiSchemasWithTapirSchemas = doc.components + .map(_.schemas.map { + case (name, _: OpenapiSchemaEnum) => + name -> s"implicit lazy val ${BasicGenerator.uncapitalise(name)}TapirSchema: sttp.tapir.Schema[$name] = sttp.tapir.Schema.derived" + case (name, obj: OpenapiSchemaObject) => name -> schemaForObject(name, obj) + case (name, schema: OpenapiSchemaMap) => name -> schemaForMap(name, schema) + case (name, schema: OpenapiSchemaOneOf) => + name -> genADTSchema(name, schema, if (fullModelPath.isEmpty) None else Some(fullModelPath)) + case (n, x) => throw new NotImplementedError(s"Only objects, enums, maps and oneOf supported! (for $n found ${x})") + }) + .toSeq + .flatMap(maybeAnySchema.toSeq ++ _) + .toMap + + // The algorithm here is to avoid mutual references between objects. It goes like this: + // 1) Find all 'rings' -- that is, sets of mutually-recursive object references that will need to be defined in the same object + // (e.g. 
the schemas for `case class A(maybeB: Option[B])` and `case class B(maybeA: Option[A])` would need to be defined together) + val groupedByRing = constructRings(allSchemas) + // 2) Order the definitions, such that objects appear before any places they're referenced + val orderedLayers = orderLayers(groupedByRing) + // 3) Group the definitions into at most `maxSchemasPerFile`, whilst avoiding splitting groups across files + val foldedLayers = foldLayers(maxSchemasPerFile)(orderedLayers) + // Our output will now only need to import the 'earlier' files into the 'later' files, and _not_ vice versa + maybeAnySchema.map(_._2).toSeq ++ foldedLayers.map(ring => ring.map(openApiSchemasWithTapirSchemas apply _._1).mkString("\n")) + } + // Group files into chunks of size < maxLayerSize + private def foldLayers(maxSchemasPerFile: Int)(layers: Seq[Seq[(String, OpenapiSchemaType)]]): Seq[Seq[(String, OpenapiSchemaType)]] = { + val maxLayerSize = maxSchemasPerFile + layers.foldLeft(Seq.empty[Seq[(String, OpenapiSchemaType)]]) { (acc, next) => + if (acc.isEmpty) Seq(next) + else if (acc.last.size + next.size >= maxLayerSize) acc :+ next + else { + val first :+ last = acc + first :+ (last ++ next) + } + } + } + // Need to order rings so that leaf schemas are defined before parents + private def orderLayers(layers: Seq[Seq[(String, OpenapiSchemaType)]]): Seq[Seq[(String, OpenapiSchemaType)]] = { + def getDirectChildren(schema: OpenapiSchemaType): Set[String] = schema match { + case r: OpenapiSchemaRef => Set(r.stripped) + case _: OpenapiSchemaSimpleType | _: OpenapiSchemaEnum | _: OpenapiSchemaConstantString => Set.empty[String] + case OpenapiSchemaArray(items, _) => getDirectChildren(items) + case OpenapiSchemaNot(items) => getDirectChildren(items) + case OpenapiSchemaMap(items, _) => getDirectChildren(items) + case OpenapiSchemaOneOf(items, _) => items.flatMap(getDirectChildren).toSet + case OpenapiSchemaAnyOf(items) => items.flatMap(getDirectChildren).toSet + case 
OpenapiSchemaAllOf(items) => items.flatMap(getDirectChildren).toSet + case OpenapiSchemaObject(kvs, _, _) => kvs.values.flatMap(f => getDirectChildren(f.`type`)).toSet + } + val withDirectChildren = layers.map { layer => + layer.map { case (k, v) => (k, v, getDirectChildren(v)) } + } + val initialSet: mutable.Set[Seq[(String, OpenapiSchemaType, Set[String])]] = mutable.Set(withDirectChildren: _*) + val acquired = mutable.Set.empty[String] + val res = mutable.ArrayBuffer.empty[Seq[(String, OpenapiSchemaType, Set[String])]] + while (initialSet.nonEmpty) { + // Find all schema 'rings' that depend only on 'already acquired' schemas and/or other members of the same ring + val nextLayers = initialSet.filter(g => g.forall(_._3.forall(c => acquired.contains(c) || g.map(_._1).contains(c)))) + // remove these from the initial set, add to the 'acquired' set & res seq + initialSet --= nextLayers + // sorting here for output stability + res ++= nextLayers.toSeq.sortBy(_.head._1) + acquired ++= nextLayers.flatMap(_.map(_._1)).toSet + if (initialSet.nonEmpty && nextLayers.isEmpty) + throw new IllegalStateException("Cannot order layers until mutually-recursive references have been resolved.") + } + + res.map(_.map { case (k, v, _) => k -> v }) + } + // finds all mutually-recursive references, grouping mutually-recursive schemas into a single 'layer' seq + private def constructRings(allSchemas: Map[String, OpenapiSchemaType]): Seq[Seq[(String, OpenapiSchemaType)]] = { + val initialSet: mutable.Set[(String, OpenapiSchemaType)] = mutable.Set(allSchemas.toSeq: _*) + val res = mutable.ArrayBuffer.empty[Seq[(String, OpenapiSchemaType)]] + while (initialSet.nonEmpty) { + val nextRing = mutable.ArrayBuffer.empty[(String, OpenapiSchemaType)] + def recurse(next: (String, OpenapiSchemaType)): Unit = { + val (nextName, nextSchema) = next + nextRing += next + initialSet -= next + // Find all simple reference loops for a single candidate + val refs = getReferencesToXInY(allSchemas, nextName, 
nextSchema, Set.empty, Seq(nextName)) + val newRefs = refs.flatMap(r => initialSet.find(_._1 == r)) + // New candidates may themselves have mutually-recursive references to other candidates that _don't_ have + // 'loop' references to initial candidate, so we need to recurse here - e.g. for + // `case class A(maybeB: Option[B])`, `case class B(maybeA: Option[A], maybeC: Option[C])`, `case class C(maybeB: Option[B])` + // we have the loops A -> B -> A, and B -> C -> B, but the loop A -> B -> C -> B -> A would not be detected by `getReferencesToXInY` + // Fusing all simple loops should be a valid way of constructing the equivalence set. + newRefs foreach recurse + } + // Select next candidate. Order lexicographically for stable output + val next = initialSet.minBy(_._1) + recurse(next) + res += nextRing.sortBy(_._1) + } + res.toSeq + } + // find all simple reference loops starting at a single schema (e.g. A -> B -> C -> A) + private def getReferencesToXInY( + allSchemas: Map[String, OpenapiSchemaType], + referrent: String, // The stripped ref of the schema we're looking for references to + referenceCandidate: OpenapiSchemaType, // candidate for mutually-recursive reference + checked: Set[String], // refs we've already checked + maybeRefs: Seq[String] // chain of refs from referrent -> [...maybeRefs] -> referenceCandidate + ): Set[String] = referenceCandidate match { + case ref: OpenapiSchemaRef => + val stripped = ref.stripped + // in this case, we have a chain of references from referrent -> [...maybeRefs] -> referrent, creating a mutually-recursive loop + if (stripped == referrent) maybeRefs.toSet + // if already checked, skip + else if (checked contains stripped) Set.empty + // else add the ref to 'maybeRefs' chain and descend + else { + allSchemas + .get(ref.stripped) + .map(getReferencesToXInY(allSchemas, referrent, _, checked + stripped, maybeRefs :+ stripped)) + .toSet + .flatten + } + // these types cannot contain a reference + case _: 
OpenapiSchemaSimpleType | _: OpenapiSchemaEnum | _: OpenapiSchemaConstantString => Set.empty + // descend into the sole child type + case OpenapiSchemaArray(items, _) => getReferencesToXInY(allSchemas, referrent, items, checked, maybeRefs) + case OpenapiSchemaNot(items) => getReferencesToXInY(allSchemas, referrent, items, checked, maybeRefs) + case OpenapiSchemaMap(items, _) => getReferencesToXInY(allSchemas, referrent, items, checked, maybeRefs) + // descend into all child types + case OpenapiSchemaOneOf(items, _) => items.flatMap(getReferencesToXInY(allSchemas, referrent, _, checked, maybeRefs)).toSet + case OpenapiSchemaAllOf(items) => items.flatMap(getReferencesToXInY(allSchemas, referrent, _, checked, maybeRefs)).toSet + case OpenapiSchemaAnyOf(items) => items.flatMap(getReferencesToXInY(allSchemas, referrent, _, checked, maybeRefs)).toSet + case OpenapiSchemaObject(kvs, _, _) => + kvs.values.flatMap(v => getReferencesToXInY(allSchemas, referrent, v.`type`, checked, maybeRefs)).toSet + } + + private def schemaForObject(name: String, schema: OpenapiSchemaObject): String = { + val subs = schema.properties.collect { + case (k, OpenapiSchemaField(`type`: OpenapiSchemaObject, _)) => schemaForObject(s"$name${k.capitalize}", `type`) + case (k, OpenapiSchemaField(OpenapiSchemaArray(`type`: OpenapiSchemaObject, _), _)) => + schemaForObject(s"$name${k.capitalize}Item", `type`) + case (k, OpenapiSchemaField(OpenapiSchemaMap(`type`: OpenapiSchemaObject, _), _)) => + schemaForObject(s"$name${k.capitalize}Item", `type`) + } match { + case Nil => "" + case s => s.mkString("", "\n", "\n") + } + s"${subs}implicit lazy val ${BasicGenerator.uncapitalise(name)}TapirSchema: sttp.tapir.Schema[$name] = sttp.tapir.Schema.derived" + } + private def schemaForMap(name: String, schema: OpenapiSchemaMap): String = { + val subs = schema.items match { + case `type`: OpenapiSchemaObject => Some(schemaForObject(s"${name}ObjectsItem", `type`)) + case _ => None + } + subs.fold("")("\n" + _) + } 
+ private def genADTSchema(name: String, schema: OpenapiSchemaOneOf, fullModelPath: Option[String]): String = { + val schemaImpl = schema match { + case OpenapiSchemaOneOf(_, None) => "sttp.tapir.Schema.derived" + case OpenapiSchemaOneOf(_, Some(Discriminator(propertyName, maybeMapping))) => + val mapping = + maybeMapping.map(_.map { case (propName, fullRef) => propName -> fullRef.stripPrefix("#/components/schemas/") }).getOrElse { + schema.types.map { + case ref: OpenapiSchemaRef => ref.stripped -> ref.stripped + case other => + throw new IllegalArgumentException(s"oneOf subtypes must be refs to explicit schema models, found $other for $name") + }.toMap + } + val fullModelPrefix = fullModelPath.map(_ + ".") getOrElse "" + val fields = mapping + .map { case (propValue, fullRef) => + val fullClassName = fullModelPrefix + fullRef + s""""$propValue" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("$fullClassName"))""" + } + .mkString(",\n") + s"""{ + | val derived = implicitly[sttp.tapir.generic.Derived[sttp.tapir.Schema[$name]]].value + | derived.schemaType match { + | case s: sttp.tapir.SchemaType.SCoproduct[_] => derived.copy(schemaType = s.addDiscriminatorField( + | sttp.tapir.FieldName("$propertyName"), + | sttp.tapir.Schema.string, + | Map( + |${indent(8)(fields)} + | ) + | )) + | case _ => throw new IllegalStateException("Derived schema for $name should be a coproduct") + | } + |}""".stripMargin + } + + s"implicit lazy val ${BasicGenerator.uncapitalise(name)}TapirSchema: sttp.tapir.Schema[$name] = ${schemaImpl}" + } +} diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala index 1b6afb8a58..5f8ae2b027 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala @@ -16,7 +16,8 @@ class BasicGeneratorSpec extends 
CompileCheckTestBase { targetScala3 = false, useHeadTagForObjectNames = useHeadTagForObjectNames, jsonSerdeLib = jsonSerdeLib, - validateNonDiscriminatedOneOfs = true + validateNonDiscriminatedOneOfs = true, + maxSchemasPerFile = 400 ) } def gen( @@ -30,14 +31,13 @@ class BasicGeneratorSpec extends CompileCheckTestBase { jsonSerdeLib = jsonSerdeLib ) val main = genned("TapirGeneratedEndpoints") - val maybeJson = genned.get("TapirGeneratedEndpointsJsonSerdes") - main + maybeJson.map("\n" + _).getOrElse("") + val schemaKeys = genned.keys.filter(_.startsWith("TapirGeneratedEndpointsSchemas")).toSeq.sorted + val maybeExtra = (schemaKeys.map(genned) ++ genned.get("TapirGeneratedEndpointsJsonSerdes")).mkString("\n") + main + "\n" + maybeExtra } def testJsonLib(jsonSerdeLib: String) = { it should s"generate the bookshop example using ${jsonSerdeLib} serdes" in { - val res = gen(TestHelpers.myBookshopDoc, useHeadTagForObjectNames = false, jsonSerdeLib = jsonSerdeLib) -// println(res) - res shouldCompile () + gen(TestHelpers.myBookshopDoc, useHeadTagForObjectNames = false, jsonSerdeLib = jsonSerdeLib) shouldCompile () } it should s"split outputs by tag if useHeadTagForObjectNames = true using ${jsonSerdeLib} serdes" in { @@ -46,20 +46,22 @@ class BasicGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = true, jsonSerdeLib = jsonSerdeLib ) - val schemas = generated("TapirGeneratedEndpoints") + val models = generated("TapirGeneratedEndpoints") val serdes = generated("TapirGeneratedEndpointsJsonSerdes") + val schemas = generated("TapirGeneratedEndpointsSchemas") val endpoints = generated("Bookshop") // schema file on its own should compile - schemas shouldCompile () + models shouldCompile () // schema file should contain no endpoint definitions - schemas.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 0 + models.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 0 // schema file with serde file should compile - (schemas + "\n" + 
serdes) shouldCompile () + (models + "\n" + serdes) shouldCompile () + // schema file with serde file & schema file should compile + (models + "\n" + serdes + "\n" + schemas) shouldCompile () // Bookshop file should contain all endpoint definitions endpoints.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 3 - // endpoint file depends on schema file. For simplicity of testing, just strip the package declaration from the - // endpoint file, and concat the two, before testing for compilation - (schemas + "\n" + serdes + "\n" + endpoints) shouldCompile () + // endpoint file depends on models, serdes & schemas + (models + "\n" + serdes + "\n" + schemas + "\n" + endpoints) shouldCompile () } it should s"compile endpoints with enum query params using ${jsonSerdeLib} serdes" in { diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala index c728ea631d..729179b6d6 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala @@ -287,11 +287,15 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { ) val gen = new ClassDefinitionGenerator() + def concatted(res: GeneratedClassDefinitions): String = { + (res.classRepr + res.serdeRepr.fold("")("\n" + _)).linesIterator.filterNot(_.trim.isEmpty).mkString("\n") + } val res = gen .classDefs(doc, true, jsonParamRefs = Set("Test")) - .map(_.classRepr.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")) - val resWithQueryParamCodec = gen.classDefs(doc, true, queryParamRefs = Set("Test"), jsonParamRefs = Set("Test")) - .map(_.classRepr.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")) + .map(concatted) + val resWithQueryParamCodec = gen + .classDefs(doc, true, queryParamRefs = Set("Test"), jsonParamRefs = Set("Test")) + 
.map(concatted) // can't just check whether these compile, because our tests only run on scala 2.12 - so instead just eyeball it... res shouldBe Some("""enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec { | case enum1, enum2 @@ -481,15 +485,26 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { } it should "generate ADTs for oneOf schemas (jsoniter)" in { + val imports = + """import sttp.tapir.generic.auto._ + |""".stripMargin val gen = new ClassDefinitionGenerator() def testOK(doc: OpenapiDocument) = { - val GeneratedClassDefinitions(res, extra) = - gen.classDefs(doc, false, jsonSerdeLib = JsonSerdeLib.Jsoniter, jsonParamRefs = Set("ReqWithVariants")).get + val GeneratedClassDefinitions(res, jsonSerdes, schemas) = + gen + .classDefs( + doc, + false, + jsonSerdeLib = JsonSerdeLib.Jsoniter, + jsonParamRefs = Set("ReqWithVariants"), + fullModelPath = "foo.bar.baz" + ) + .get - val fullRes = (res + "\n" + extra.get) + val fullRes = imports + res + "\n" + jsonSerdes.get res shouldCompile () fullRes shouldCompile () - extra.get should include( + jsonSerdes.get should include( """implicit lazy val reqWithVariantsCodec: com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec[ReqWithVariants] = com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker.make(com.github.plokhotnyuk.jsoniter_scala.macros.CodecMakerConfig.withAllowRecursiveTypes(true).withTransientEmpty(false).withRequireCollectionFields(true).withRequireDiscriminatorFirst(false).withDiscriminatorFieldName(Some("type")))""" ) } @@ -503,13 +518,16 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { } it should "generate ADTs for oneOf schemas (circe)" in { + val imports = + """import sttp.tapir.generic.auto._ + |""".stripMargin val gen = new ClassDefinitionGenerator() def testOK(doc: OpenapiDocument) = { - val GeneratedClassDefinitions(res, extra) = + val GeneratedClassDefinitions(res, jsonSerdes, schemas) = gen.classDefs(doc, false, jsonSerdeLib = 
JsonSerdeLib.Circe, jsonParamRefs = Set("ReqWithVariants")).get - val fullRes = (res + "\n" + extra.get) - fullRes shouldCompile () + val fullRes = (res + "\n" + jsonSerdes.get) + (imports + fullRes) shouldCompile () val expectedLines = Seq( """implicit lazy val reqWithVariantsJsonEncoder: io.circe.Encoder[ReqWithVariants]""", """case x: ReqSubtype1 => io.circe.Encoder[ReqSubtype1].apply(x).mapObject(_.add("type", io.circe.Json.fromString("ReqSubtype1")))""", diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index 184a21b4f9..a0cdfd8588 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -240,7 +240,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { targetScala3 = false, useHeadTagForObjectNames = false, jsonSerdeLib = "circe", - validateNonDiscriminatedOneOfs = true + validateNonDiscriminatedOneOfs = true, + maxSchemasPerFile = 400 )("TapirGeneratedEndpoints") generatedCode should include( """file: sttp.model.Part[java.io.File]""" @@ -260,7 +261,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { targetScala3 = false, useHeadTagForObjectNames = false, jsonSerdeLib = "circe", - validateNonDiscriminatedOneOfs = true + validateNonDiscriminatedOneOfs = true, + maxSchemasPerFile = 400 )("TapirGeneratedEndpoints") generatedCode shouldCompile () val expectedAttrDecls = Seq( diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala index 17ba665f4f..0299642f55 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala @@ -12,6 +12,7 @@ trait 
OpenapiCodegenKeys { lazy val openapiJsonSerdeLib = settingKey[String]("The lib to use for json serdes. Supports 'circe' and 'jsoniter'.") lazy val openapiValidateNonDiscriminatedOneOfs = settingKey[Boolean]("Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated..") + lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file") lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.") } diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala index beea28b176..9c766a6f1b 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala @@ -28,7 +28,8 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiObject := "TapirGeneratedEndpoints", openapiUseHeadTagForObjectName := false, openapiJsonSerdeLib := "circe", - openapiValidateNonDiscriminatedOneOfs := true + openapiValidateNonDiscriminatedOneOfs := true, + openapiMaxSchemasPerFile := 400 ) private def codegen = Def.task { @@ -41,6 +42,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiUseHeadTagForObjectName, openapiJsonSerdeLib, openapiValidateNonDiscriminatedOneOfs, + openapiMaxSchemasPerFile, sourceManaged, streams, scalaVersion @@ -52,6 +54,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { useHeadTagForObjectName: Boolean, jsonSerdeLib: String, validateNonDiscriminatedOneOfs: Boolean, + maxSchemasPerFile: Int, srcDir: File, taskStreams: TaskStreams, sv: String @@ -63,6 +66,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { useHeadTagForObjectName, jsonSerdeLib, validateNonDiscriminatedOneOfs, + maxSchemasPerFile, srcDir, taskStreams.cacheDirectory, sv.startsWith("3") diff --git 
a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala index 42d19986b5..c4d55c257f 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala @@ -12,6 +12,7 @@ case class OpenapiCodegenTask( useHeadTagForObjectName: Boolean, jsonSerdeLib: String, validateNonDiscriminatedOneOfs: Boolean, + maxSchemasPerFile: Int, dir: File, cacheDir: File, targetScala3: Boolean @@ -53,7 +54,8 @@ case class OpenapiCodegenTask( targetScala3, useHeadTagForObjectName, jsonSerdeLib, - validateNonDiscriminatedOneOfs + validateNonDiscriminatedOneOfs, + maxSchemasPerFile ) .map { case (objectName, fileBody) => val file = directory / s"$objectName.scala" diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt index 3c97b166d9..21d39490ba 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt @@ -7,8 +7,9 @@ object TapirGeneratedEndpoints { import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ import io.circe.generic.semiauto._ - + import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._ + import TapirGeneratedEndpointsSchemas._ sealed trait ADTWithoutDiscriminator sealed trait ADTWithDiscriminator diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt new file mode 100644 index 0000000000..d688bbf900 --- /dev/null +++ 
b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedJsonSerdes.scala.txt @@ -0,0 +1,53 @@ +package sttp.tapir.generated + +object TapirGeneratedEndpointsJsonSerdes { + import sttp.tapir.generated.TapirGeneratedEndpoints._ + import sttp.tapir.generic.auto._ + implicit lazy val aDTWithDiscriminatorJsonEncoder: io.circe.Encoder[ADTWithDiscriminator] = io.circe.Encoder.instance { + case x: SubtypeWithD1 => io.circe.Encoder[SubtypeWithD1].apply(x).mapObject(_.add("type", io.circe.Json.fromString("SubA"))) + case x: SubtypeWithD2 => io.circe.Encoder[SubtypeWithD2].apply(x).mapObject(_.add("type", io.circe.Json.fromString("SubB"))) + } + implicit lazy val aDTWithDiscriminatorJsonDecoder: io.circe.Decoder[ADTWithDiscriminator] = io.circe.Decoder { (c: io.circe.HCursor) => + for { + discriminator <- c.downField("type").as[String] + res <- discriminator match { + case "SubA" => c.as[SubtypeWithD1] + case "SubB" => c.as[SubtypeWithD2] + } + } yield res + } + implicit lazy val subtypeWithoutD1JsonDecoder: io.circe.Decoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithoutD1] + implicit lazy val subtypeWithoutD1JsonEncoder: io.circe.Encoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithoutD1] + implicit lazy val subtypeWithD1JsonDecoder: io.circe.Decoder[SubtypeWithD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithD1] + implicit lazy val subtypeWithD1JsonEncoder: io.circe.Encoder[SubtypeWithD1] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithD1] + implicit lazy val aDTWithDiscriminatorNoMappingJsonEncoder: io.circe.Encoder[ADTWithDiscriminatorNoMapping] = io.circe.Encoder.instance { + case x: SubtypeWithD1 => io.circe.Encoder[SubtypeWithD1].apply(x).mapObject(_.add("type", io.circe.Json.fromString("SubtypeWithD1"))) + case x: SubtypeWithD2 => io.circe.Encoder[SubtypeWithD2].apply(x).mapObject(_.add("type", io.circe.Json.fromString("SubtypeWithD2"))) + } + implicit lazy val 
aDTWithDiscriminatorNoMappingJsonDecoder: io.circe.Decoder[ADTWithDiscriminatorNoMapping] = io.circe.Decoder { (c: io.circe.HCursor) => + for { + discriminator <- c.downField("type").as[String] + res <- discriminator match { + case "SubtypeWithD1" => c.as[SubtypeWithD1] + case "SubtypeWithD2" => c.as[SubtypeWithD2] + } + } yield res + } + implicit lazy val subtypeWithoutD3JsonDecoder: io.circe.Decoder[SubtypeWithoutD3] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithoutD3] + implicit lazy val subtypeWithoutD3JsonEncoder: io.circe.Encoder[SubtypeWithoutD3] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithoutD3] + implicit lazy val subtypeWithoutD2JsonDecoder: io.circe.Decoder[SubtypeWithoutD2] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithoutD2] + implicit lazy val subtypeWithoutD2JsonEncoder: io.circe.Encoder[SubtypeWithoutD2] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithoutD2] + implicit lazy val subtypeWithD2JsonDecoder: io.circe.Decoder[SubtypeWithD2] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithD2] + implicit lazy val subtypeWithD2JsonEncoder: io.circe.Encoder[SubtypeWithD2] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithD2] + implicit lazy val aDTWithoutDiscriminatorJsonEncoder: io.circe.Encoder[ADTWithoutDiscriminator] = io.circe.Encoder.instance { + case x: SubtypeWithoutD1 => io.circe.Encoder[SubtypeWithoutD1].apply(x) + case x: SubtypeWithoutD2 => io.circe.Encoder[SubtypeWithoutD2].apply(x) + case x: SubtypeWithoutD3 => io.circe.Encoder[SubtypeWithoutD3].apply(x) + } + implicit lazy val aDTWithoutDiscriminatorJsonDecoder: io.circe.Decoder[ADTWithoutDiscriminator] = + List[io.circe.Decoder[ADTWithoutDiscriminator]]( + io.circe.Decoder[SubtypeWithoutD1].asInstanceOf[io.circe.Decoder[ADTWithoutDiscriminator]], + io.circe.Decoder[SubtypeWithoutD2].asInstanceOf[io.circe.Decoder[ADTWithoutDiscriminator]], + io.circe.Decoder[SubtypeWithoutD3].asInstanceOf[io.circe.Decoder[ADTWithoutDiscriminator]] + ).reduceLeft(_ or _) +} 
diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt new file mode 100644 index 0000000000..dd0c6da56c --- /dev/null +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/ExpectedSchemas.scala.txt @@ -0,0 +1,40 @@ +package sttp.tapir.generated + +object TapirGeneratedEndpointsSchemas { + import sttp.tapir.generated.TapirGeneratedEndpoints._ + import sttp.tapir.generic.auto._ + implicit lazy val subtypeWithD1TapirSchema: sttp.tapir.Schema[SubtypeWithD1] = sttp.tapir.Schema.derived + implicit lazy val subtypeWithD2TapirSchema: sttp.tapir.Schema[SubtypeWithD2] = sttp.tapir.Schema.derived + implicit lazy val subtypeWithoutD1TapirSchema: sttp.tapir.Schema[SubtypeWithoutD1] = sttp.tapir.Schema.derived + implicit lazy val subtypeWithoutD2TapirSchema: sttp.tapir.Schema[SubtypeWithoutD2] = sttp.tapir.Schema.derived + implicit lazy val subtypeWithoutD3TapirSchema: sttp.tapir.Schema[SubtypeWithoutD3] = sttp.tapir.Schema.derived + implicit lazy val aDTWithDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithDiscriminator] = { + val derived = implicitly[sttp.tapir.generic.Derived[sttp.tapir.Schema[ADTWithDiscriminator]]].value + derived.schemaType match { + case s: sttp.tapir.SchemaType.SCoproduct[_] => derived.copy(schemaType = s.addDiscriminatorField( + sttp.tapir.FieldName("type"), + sttp.tapir.Schema.string, + Map( + "SubA" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD1")), + "SubB" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD2")) + ) + )) + case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminator should be a coproduct") + } + } + implicit lazy val aDTWithDiscriminatorNoMappingTapirSchema: 
sttp.tapir.Schema[ADTWithDiscriminatorNoMapping] = { + val derived = implicitly[sttp.tapir.generic.Derived[sttp.tapir.Schema[ADTWithDiscriminatorNoMapping]]].value + derived.schemaType match { + case s: sttp.tapir.SchemaType.SCoproduct[_] => derived.copy(schemaType = s.addDiscriminatorField( + sttp.tapir.FieldName("type"), + sttp.tapir.Schema.string, + Map( + "SubtypeWithD1" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD1")), + "SubtypeWithD2" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD2")) + ) + )) + case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminatorNoMapping should be a coproduct") + } + } + implicit lazy val aDTWithoutDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithoutDiscriminator] = sttp.tapir.Schema.derived +} diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt index 42b7170a62..953265c708 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt @@ -19,16 +19,23 @@ libraryDependencies ++= Seq( import scala.io.Source TaskKey[Unit]("check") := { - val generatedCode = - Source.fromFile("target/scala-2.13/src_managed/main/sbt-openapi-codegen/TapirGeneratedEndpoints.scala").getLines.mkString("\n") - val expected = Source.fromFile("Expected.scala.txt").getLines.mkString("\n") - val generatedTrimmed = generatedCode.linesIterator.zipWithIndex.filterNot(_._1.forall(_.isWhitespace)).map{ case (a, i) => a.trim -> i }.toSeq - val expectedTrimmed = expected.linesIterator.filterNot(_.forall(_.isWhitespace)).map(_.trim).toSeq - if (generatedTrimmed.size != expectedTrimmed.size) - sys.error(s"expected ${expectedTrimmed.size} 
non-empty lines, found ${generatedTrimmed.size}") - generatedTrimmed.zip(expectedTrimmed).foreach { case ((a, i), b) => - if (a != b) sys.error(s"Generated code did not match (expected '$b' on line $i, found '$a')") + def check(generatedFileName: String, expectedFileName: String) = { + val generatedCode = + Source.fromFile(s"target/scala-2.13/src_managed/main/sbt-openapi-codegen/$generatedFileName").getLines.mkString("\n") + val expectedCode = Source.fromFile(expectedFileName).getLines.mkString("\n") + val generatedTrimmed = + generatedCode.linesIterator.zipWithIndex.filterNot(_._1.forall(_.isWhitespace)).map { case (a, i) => a.trim -> i }.toSeq + val expectedTrimmed = expectedCode.linesIterator.filterNot(_.forall(_.isWhitespace)).map(_.trim).toSeq + if (generatedTrimmed.size != expectedTrimmed.size) + sys.error(s"expected ${expectedTrimmed.size} non-empty lines, found ${generatedTrimmed.size}") + generatedTrimmed.zip(expectedTrimmed).foreach { case ((a, i), b) => + if (a != b) sys.error(s"Generated code in file $generatedCode did not match (expected '$b' on line $i, found '$a')") + } } - println("Skipping swagger roundtrip for petstore") + Seq( + "TapirGeneratedEndpoints.scala" -> "Expected.scala.txt", + "TapirGeneratedEndpointsJsonSerdes.scala" -> "ExpectedJsonSerdes.scala.txt", + "TapirGeneratedEndpointsSchemas.scala" -> "ExpectedSchemas.scala.txt" + ).foreach { case (generated, expected) => check(generated, expected) } () } diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_jsoniter/Expected.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_jsoniter/Expected.scala.txt index 371bacf275..d093903e80 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_jsoniter/Expected.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_jsoniter/Expected.scala.txt @@ -10,6 +10,7 @@ object TapirGeneratedEndpoints 
{ import com.github.plokhotnyuk.jsoniter_scala.core._ import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._ + import TapirGeneratedEndpointsSchemas._ sealed trait ADTWithoutDiscriminator sealed trait ADTWithDiscriminator diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_scala3/Expected.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_scala3/Expected.scala.txt index de77335ab3..c8b9dbf7e5 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_scala3/Expected.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip_scala3/Expected.scala.txt @@ -9,6 +9,7 @@ object TapirGeneratedEndpoints { import io.circe.generic.semiauto._ import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._ + import TapirGeneratedEndpointsSchemas._ sealed trait ADTWithoutDiscriminator sealed trait ADTWithDiscriminator