Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Fix issues with jsoniter in scala3 #3963

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ClassDefinitionGenerator {
jsonSerdeLib,
jsonParamRefs,
allTransitiveJsonParamRefs,
fullModelPath,
validateNonDiscriminatedOneOfs,
adtInheritanceMap.mapValues(_.map(_._1)),
targetScala3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ object EnumGenerator {
case _ if !jsonParamRefs.contains(name) => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Circe if !queryParamRefs.contains(name) => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec"
case JsonSerdeLib.Circe => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter | JsonSerdeLib.Zio if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Jsoniter | JsonSerdeLib.Zio => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter if !queryParamRefs.contains(name) => ""
case JsonSerdeLib.Jsoniter => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Zio if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Zio => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
}
s"""$maybeCompanion
|enum $name$maybeCodecExtensions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ object JsonSerdeGenerator {
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
jsonParamRefs: Set[String],
allTransitiveJsonParamRefs: Set[String],
fullModelPath: String,
validateNonDiscriminatedOneOfs: Boolean,
adtInheritanceMap: Map[String, Seq[String]],
targetScala3: Boolean
Expand All @@ -41,7 +40,6 @@ object JsonSerdeGenerator {
jsonParamRefs,
allTransitiveJsonParamRefs,
adtInheritanceMap,
if (fullModelPath.isEmpty) None else Some(fullModelPath),
validateNonDiscriminatedOneOfs
)
case JsonSerdeLib.Zio => genZioSerdes(doc, allSchemas, allTransitiveJsonParamRefs, validateNonDiscriminatedOneOfs, targetScala3)
Expand Down Expand Up @@ -233,7 +231,6 @@ object JsonSerdeGenerator {
jsonParamRefs: Set[String],
allTransitiveJsonParamRefs: Set[String],
adtInheritanceMap: Map[String, Seq[String]],
fullModelPath: Option[String],
validateNonDiscriminatedOneOfs: Boolean
): Option[String] = {
// For jsoniter-scala, we define explicit serdes for any 'primitive' params (e.g. List[java.util.UUID]) that we reference.
Expand Down Expand Up @@ -271,7 +268,7 @@ object JsonSerdeGenerator {
Some(genJsoniterEnumSerde(name))
// For ADTs, generate the serde if it's referenced in any json model
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
Some(generateJsoniterAdtSerde(allSchemas, name, schema, fullModelPath, validateNonDiscriminatedOneOfs))
Some(generateJsoniterAdtSerde(allSchemas, name, schema, validateNonDiscriminatedOneOfs))
case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None
case (n, x) => throw new NotImplementedError(s"Only objects, enums, maps and oneOf supported! (for $n found ${x})")
})
Expand Down Expand Up @@ -304,10 +301,8 @@ object JsonSerdeGenerator {
allSchemas: Map[String, OpenapiSchemaType],
name: String,
schema: OpenapiSchemaOneOf,
maybeFullModelPath: Option[String],
validateNonDiscriminatedOneOfs: Boolean
): String = {
val fullPathPrefix = maybeFullModelPath.map(_ + ".").getOrElse("")
val uncapitalisedName = BasicGenerator.uncapitalise(name)
schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
Expand All @@ -323,11 +318,11 @@ object JsonSerdeGenerator {
val body = if (schemaToJsonMapping.exists { case (className, jsonValue) => className != jsonValue }) {
val discriminatorMap = indent(2)(
schemaToJsonMapping
.map { case (k, v) => s"""case "$fullPathPrefix$k" => "$v"""" }
.map { case (k, v) => s"""case "$k" => "$v"""" }
.mkString("\n", "\n", "\n")
)
val config =
s"""$jsoniterBaseConfig.withRequireDiscriminatorFirst(false).withDiscriminatorFieldName(Some("${discriminator.propertyName}")).withAdtLeafClassNameMapper{$discriminatorMap}"""
s"""$jsoniterBaseConfig.withRequireDiscriminatorFirst(false).withDiscriminatorFieldName(Some("${discriminator.propertyName}")).withAdtLeafClassNameMapper(x => com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker.simpleClassName(x) match {$discriminatorMap})"""
val serde =
s"implicit lazy val ${uncapitalisedName}Codec: $jsoniterPkgCore.JsonValueCodec[$name] = $jsoniterPkgMacros.JsonCodecMaker.make($config)"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package sttp.tapir.generated

object TapirGeneratedEndpoints {

import sttp.tapir._
import sttp.tapir.model._
import sttp.tapir.generic.auto._
import sttp.tapir.json.jsoniter._
import com.github.plokhotnyuk.jsoniter_scala.macros._
import com.github.plokhotnyuk.jsoniter_scala.core._

import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._
import TapirGeneratedEndpointsSchemas._


case class CommaSeparatedValues[T](values: List[T])
case class ExplodedValues[T](values: List[T])
trait ExtraParamSupport[T] {
def decode(s: String): sttp.tapir.DecodeResult[T]
def encode(t: T): String
}
implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
}
implicit def makeUnexplodedQuerySeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], CommaSeparatedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(values => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(s => CommaSeparatedValues(s.toList)))(_.values.map(support.encode).mkString(","))
}
implicit def makeUnexplodedQueryOptSeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], Option[CommaSeparatedValues[T]], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode{
case None => DecodeResult.Value(None)
case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
}(_.map(_.values.map(support.encode).mkString(",")))
}
implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}


sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD1 (
s: String,
i: Option[Int] = None,
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD1 (
s: String,
i: Option[Int] = None,
d: Option[Double] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD3 (
s: String,
i: Option[Int] = None,
e: Option[AnEnum] = None,
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithoutD2 (
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD2 (
s: String,
a: Option[Seq[String]] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping

enum AnEnum {
case Foo, Bar, Baz
}



lazy val putAdtTest =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))


lazy val generatedEndpoints = List(putAdtTest, postAdtTest)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
lazy val root = (project in file("."))
.enablePlugins(OpenapiCodegenPlugin)
.settings(
scalaVersion := "3.3.3",
version := "0.1",
openapiJsonSerdeLib := "jsoniter"
)

libraryDependencies ++= Seq(
"com.softwaremill.sttp.tapir" %% "tapir-jsoniter-scala" % "1.10.0",
"com.softwaremill.sttp.tapir" %% "tapir-openapi-docs" % "1.10.0",
"com.softwaremill.sttp.apispec" %% "openapi-circe-yaml" % "0.8.0",
"com.beachape" %% "enumeratum" % "1.7.4",
"com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala-core" % "2.30.7",
"com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala-macros" % "2.30.7" % "compile-internal",
"org.scalatest" %% "scalatest" % "3.2.19" % Test,
"com.softwaremill.sttp.tapir" %% "tapir-sttp-stub-server" % "1.10.0" % Test
)

import sttp.tapir.sbt.OpenapiCodegenPlugin.autoImport.{openapiJsonSerdeLib, openapiUseHeadTagForObjectName}

import scala.io.Source

TaskKey[Unit]("check") := {
val generatedCode =
Source.fromFile("target/scala-3.3.3/src_managed/main/sbt-openapi-codegen/TapirGeneratedEndpoints.scala").getLines.mkString("\n")
val expected = Source.fromFile("Expected.scala.txt").getLines.mkString("\n")
val generatedTrimmed =
generatedCode.linesIterator.zipWithIndex.filterNot(_._1.forall(_.isWhitespace)).map { case (a, i) => a.trim -> i }.toSeq
val expectedTrimmed = expected.linesIterator.filterNot(_.forall(_.isWhitespace)).map(_.trim).toSeq
if (generatedTrimmed.size != expectedTrimmed.size)
sys.error(s"expected ${expectedTrimmed.size} non-empty lines, found ${generatedTrimmed.size}")
generatedTrimmed.zip(expectedTrimmed).foreach { case ((a, i), b) =>
if (a != b) sys.error(s"Generated code did not match (expected '$b' on line $i, found '$a')")
}
println("Skipping swagger roundtrip for petstore")
()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version=1.10.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
val pluginVersion = System.getProperty("plugin.version")
if (pluginVersion == null)
throw new RuntimeException("""|
|
|The system property 'plugin.version' is not defined.
|Specify this property using the scriptedLaunchOpts -D.
|
|""".stripMargin)
else addSbtPlugin("com.softwaremill.sttp.tapir" % "sbt-openapi-codegen" % pluginVersion)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
object Main extends App {
import sttp.apispec.openapi.circe.yaml._
import sttp.tapir.generated._
import sttp.tapir.docs.openapi._

val docs = OpenAPIDocsInterpreter().toOpenAPI(TapirGeneratedEndpoints.generatedEndpoints, "My Bookshop", "1.0")

import java.nio.file.{Paths, Files}
import java.nio.charset.StandardCharsets

Files.write(Paths.get("target/swagger.yaml"), docs.toYaml.getBytes(StandardCharsets.UTF_8))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import com.github.plokhotnyuk.jsoniter_scala.core.writeToString
import io.circe.parser.parse
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
import sttp.client3.UriContext
import sttp.client3.testing.SttpBackendStub
import sttp.tapir.generated.{TapirGeneratedEndpoints, TapirGeneratedEndpointsJsonSerdes}
import sttp.tapir.generated.TapirGeneratedEndpoints.*
import sttp.tapir.generated.TapirGeneratedEndpointsSchemas.*
import TapirGeneratedEndpointsJsonSerdes._
import sttp.tapir.server.stub.TapirStubInterpreter

import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global

class JsonRoundtrip extends AnyFreeSpec with Matchers {
"oneOf without discriminator can be round-tripped by generated serdes" in {
val route = TapirGeneratedEndpoints.putAdtTest.serverLogic[Future]({
case foo: SubtypeWithoutD1 =>
Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD1(foo.s + "+SubtypeWithoutD1", foo.i, foo.a))
case foo: SubtypeWithoutD2 => Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD2(foo.a :+ "+SubtypeWithoutD2"))
case foo: SubtypeWithoutD3 =>
Future successful Right[Unit, ADTWithoutDiscriminator](SubtypeWithoutD3(foo.s + "+SubtypeWithoutD3", foo.i, foo.e))
})

val stub = TapirStubInterpreter(SttpBackendStub.asynchronousFuture)
.whenServerEndpoint(route)
.thenRunLogic()
.backend()

def normalise(json: String): String = parse(json).toTry.get.noSpacesSortKeys
locally {
val reqBody = SubtypeWithoutD1("a string", Some(123), Seq("string 1", "string 2"))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD1("a string+SubtypeWithoutD1", Some(123), Seq("string 1", "string 2"))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"s":"a string","i":123,"a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"s":"a string+SubtypeWithoutD1","i":123,"a":["string 1","string 2"]}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

locally {
val reqBody = SubtypeWithoutD2(Seq("string 1", "string 2"))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD2(Seq("string 1", "string 2", "+SubtypeWithoutD2"))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"a":["string 1","string 2","+SubtypeWithoutD2"]}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.body shouldEqual Right(respJsonBody)
resp.code.code === 200
},
1.second
)
}

locally {
val reqBody = SubtypeWithoutD3("a string", Some(123), Some(AnEnum.Foo))
val reqJsonBody = writeToString(reqBody)
val respBody = SubtypeWithoutD3("a string+SubtypeWithoutD3", Some(123), Some(AnEnum.Foo))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"s":"a string","i":123,"e":"Foo"}"""
respJsonBody shouldEqual """{"s":"a string+SubtypeWithoutD3","i":123,"e":"Foo"}"""
Await.result(
sttp.client3.basicRequest
.put(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.body shouldEqual Right(respJsonBody)
resp.code.code === 200
},
1.second
)
}
}
"oneOf with discriminator can be round-tripped by generated serdes" in {
val route = TapirGeneratedEndpoints.postAdtTest.serverLogic[Future]({
case foo: SubtypeWithD1 => Future successful Right[Unit, ADTWithDiscriminator](SubtypeWithD1(foo.s + "+SubtypeWithD1", foo.i, foo.d))
case foo: SubtypeWithD2 => Future successful Right[Unit, ADTWithDiscriminator](SubtypeWithD2(foo.s + "+SubtypeWithD2", foo.a))
})

val stub = TapirStubInterpreter(SttpBackendStub.asynchronousFuture)
.whenServerEndpoint(route)
.thenRunLogic()
.backend()

def normalise(json: String): String = parse(json).toTry.get.noSpacesSortKeys

locally {
val reqBody: ADTWithDiscriminatorNoMapping = SubtypeWithD1("a string", Some(123), Some(23.4))
val reqJsonBody = writeToString(reqBody)
val respBody: ADTWithDiscriminator = SubtypeWithD1("a string+SubtypeWithD1", Some(123), Some(23.4))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"type":"SubtypeWithD1","s":"a string","i":123,"d":23.4}"""
respJsonBody shouldEqual """{"type":"SubA","s":"a string+SubtypeWithD1","i":123,"d":23.4}"""
Await.result(
sttp.client3.basicRequest
.post(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

locally {
val reqBody: ADTWithDiscriminatorNoMapping = SubtypeWithD2("a string", Some(Seq("string 1", "string 2")))
val reqJsonBody = writeToString(reqBody)
val respBody: ADTWithDiscriminator = SubtypeWithD2("a string+SubtypeWithD2", Some(Seq("string 1", "string 2")))
val respJsonBody = writeToString(respBody)
reqJsonBody shouldEqual """{"type":"SubtypeWithD2","s":"a string","a":["string 1","string 2"]}"""
respJsonBody shouldEqual """{"type":"SubB","s":"a string+SubtypeWithD2","a":["string 1","string 2"]}"""
Await.result(
sttp.client3.basicRequest
.post(uri"http://test.com/adt/test")
.body(reqJsonBody)
.send(stub)
.map { resp =>
resp.code.code === 200
resp.body shouldEqual Right(respJsonBody)
},
1.second
)
}

}
}
Loading
Loading