diff --git a/core/src/main/scala/sttp/tapir/Schema.scala b/core/src/main/scala/sttp/tapir/Schema.scala index 8a39d2616b..4404c05065 100644 --- a/core/src/main/scala/sttp/tapir/Schema.scala +++ b/core/src/main/scala/sttp/tapir/Schema.scala @@ -361,6 +361,26 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros { val Attribute: AttributeKey[Tuple] = new AttributeKey[Tuple]("sttp.tapir.Schema.Tuple") } + /** For coproduct schemas, when there's a discriminator field, used to attach the encoded value of the discriminator field. Such value is + * added to the discriminator field schemas in each of the coproduct's subtypes. When rendering OpenAPI/JSON schema, these values are + * converted to `const` constraints on fields. + */ + case class EncodedDiscriminatorValue(v: String) + object EncodedDiscriminatorValue { + /* + Implementation note: the discriminator value constraint is in fact an enum validator with a single possible enum value. Hence an + alternative design would be to add such validators to discriminator fields, instead of an attribute. However, this has two drawbacks: + 1. when adding discriminator fields using `addDiscriminatorField`, we don't have access to the decoded discriminator value - only + to the encoded one, via reverse mapping lookup + 2. the validator doesn't necessarily make sense, as it can't be used to validate the deserialiszd object. Usually the discriminator + fields don't even exist on the high-level representations. + That's why instead of re-using the validators, we decided to use a specialised attribute. + */ + + val Attribute: AttributeKey[EncodedDiscriminatorValue] = + new AttributeKey[EncodedDiscriminatorValue]("sttp.tapir.Schema.EncodedDiscriminatorValue") + } + /** @param typeParameterShortNames * full name of type parameters, name is legacy and kept only for backward compatibility */ diff --git a/core/src/main/scala/sttp/tapir/SchemaType.scala b/core/src/main/scala/sttp/tapir/SchemaType.scala index 06b395ce9e..e413f145cc 100644 --- a/core/src/main/scala/sttp/tapir/SchemaType.scala +++ b/core/src/main/scala/sttp/tapir/SchemaType.scala @@ -129,11 +129,39 @@ object SchemaType { discriminatorSchema: Schema[D] = Schema.string, discriminatorMapping: Map[String, SRef[_]] = Map.empty ): SCoproduct[T] = { + // used to add encoded discriminator value attributes + val reverseDiscriminatorByNameMapping: Map[SName, String] = discriminatorMapping.toList.map { case (v, ref) => (ref.name, v) }.toMap + SCoproduct( subtypes.map { - case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _) - if st.fields.forall(_.name != discriminatorName) => - s.copy(schemaType = st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None))) + case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _) => + // first, ensuring that the discriminator field is added to the schema type - it might already be present + var targetSt = + if (st.fields.forall(_.name != discriminatorName)) + st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None)) + else st + + // next, modifying the discriminator field, by adding the value attribute (if a value can be found) + targetSt = targetSt.copy(fields = targetSt.fields.map { field => + if (field.name == discriminatorName) { + val discriminatorValue = s.name.flatMap { subtypeName => + reverseDiscriminatorByNameMapping.get(subtypeName) + } + + discriminatorValue match { + case Some(v) => + SProductField( + field.name, + field.schema.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(v)), + field.get + ) + case None => field + } + + } else field + }) + + s.copy(schemaType = targetSt) case s => s }, Some(SDiscriminator(discriminatorName, discriminatorMapping)) diff --git a/core/src/test/scala/sttp/tapir/SchemaMacroTest.scala b/core/src/test/scala/sttp/tapir/SchemaMacroTest.scala index 2b3df4fa97..c5b77ab228 100644 --- a/core/src/test/scala/sttp/tapir/SchemaMacroTest.scala +++ b/core/src/test/scala/sttp/tapir/SchemaMacroTest.scala @@ -304,7 +304,14 @@ class SchemaMacroTest extends AnyFlatSpec with Matchers with TableDrivenProperty schemaType.subtypes.foreach { childSchema => val childProduct = childSchema.schemaType.asInstanceOf[SProduct[_]] - childProduct.fields.find(_.name.name == "kind") shouldBe Some(SProductField(FieldName("kind"), Schema.string, (_: Any) => None)) + val discValue = if (childSchema.name.get.fullName == "sttp.tapir.SchemaMacroTestData.User") "user" else "org" + childProduct.fields.find(_.name.name == "kind") shouldBe Some( + SProductField( + FieldName("kind"), + Schema.string.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(discValue)), + (_: Any) => None + ) + ) } } diff --git a/core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala b/core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala index 311ee9e6b1..fd4d50636e 100644 --- a/core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala +++ b/core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala @@ -245,7 +245,13 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers { schemaType.asInstanceOf[SCoproduct[Entity]].subtypes should contain theSameElementsAs List( Schema( SProduct[Organization]( - List(field(FieldName("name"), Schema(SString())), field(FieldName("who_am_i"), Schema(SString()))) + List( + field(FieldName("name"), Schema(SString())), + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Organization")) + ) + ) ), Some(SName("sttp.tapir.generic.Organization")) ), @@ -254,7 +260,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers { List( field(FieldName("first"), Schema(SString())), field(FieldName("age"), Schema(SInteger(), format = Some("int32"))), - field(FieldName("who_am_i"), Schema(SString())) + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Person")) + ) ) ), Some(SName("sttp.tapir.generic.Person")) @@ -262,7 +271,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers { Schema( SProduct[UnknownEntity.type]( List( - field(FieldName("who_am_i"), Schema(SString())) + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("UnknownEntity")) + ) ) ), Some(SName("sttp.tapir.generic.UnknownEntity")) diff --git a/docs/apispec-docs/src/main/scala/sttp/tapir/docs/apispec/schema/TSchemaToASchema.scala b/docs/apispec-docs/src/main/scala/sttp/tapir/docs/apispec/schema/TSchemaToASchema.scala index d778a19728..3f7dc4f483 100644 --- a/docs/apispec-docs/src/main/scala/sttp/tapir/docs/apispec/schema/TSchemaToASchema.scala +++ b/docs/apispec-docs/src/main/scala/sttp/tapir/docs/apispec/schema/TSchemaToASchema.scala @@ -112,9 +112,13 @@ private[docs] class TSchemaToASchema( // The primary motivation for using schema name as fallback title is to improve Swagger UX with // `oneOf` schemas in OpenAPI 3.1. See https://github.com/softwaremill/tapir/issues/3447 for details. def fallbackTitle = tschema.name.map(fallbackSchemaTitle) + + val const = tschema.attribute(TSchema.EncodedDiscriminatorValue.Attribute).map(_.v).map(v => ExampleSingleValue(v)) + oschema - .copy(title = titleFromAttr orElse fallbackTitle) + .copy(title = titleFromAttr.orElse(fallbackTitle)) .copy(uniqueItems = tschema.attribute(UniqueItems.Attribute).map(_.uniqueItems)) + .copy(const = const) } private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = { diff --git a/docs/asyncapi-docs/src/test/resources/expected_coproduct_with_discriminator.yml b/docs/asyncapi-docs/src/test/resources/expected_coproduct_with_discriminator.yml new file mode 100644 index 0000000000..bbbe58cfd0 --- /dev/null +++ b/docs/asyncapi-docs/src/test/resources/expected_coproduct_with_discriminator.yml @@ -0,0 +1,69 @@ +asyncapi: 2.6.0 +info: + title: discriminator + version: '1.0' +channels: + /animals: + subscribe: + operationId: onAnimals + message: + $ref: '#/components/messages/Animal' + publish: + operationId: sendAnimals + message: + $ref: '#/components/messages/GetAnimal' + bindings: + ws: + method: GET +components: + schemas: + GetAnimal: + title: GetAnimal + type: object + required: + - name + properties: + name: + type: string + Animal: + title: Animal + oneOf: + - $ref: '#/components/schemas/Cat' + - $ref: '#/components/schemas/Dog' + discriminator: pet + Cat: + title: Cat + type: object + required: + - name + - pet + properties: + name: + type: string + pet: + type: string + const: Cat + Dog: + title: Dog + type: object + required: + - name + - breed + - pet + properties: + name: + type: string + breed: + type: string + pet: + type: string + const: Dog + messages: + GetAnimal: + payload: + $ref: '#/components/schemas/GetAnimal' + contentType: application/json + Animal: + payload: + $ref: '#/components/schemas/Animal' + contentType: application/json diff --git a/docs/asyncapi-docs/src/test/scala/sttp/tapir/docs/asyncapi/VerifyAsyncAPIYamlTest.scala b/docs/asyncapi-docs/src/test/scala/sttp/tapir/docs/asyncapi/VerifyAsyncAPIYamlTest.scala index bd60d01d71..82667676e4 100644 --- a/docs/asyncapi-docs/src/test/scala/sttp/tapir/docs/asyncapi/VerifyAsyncAPIYamlTest.scala +++ b/docs/asyncapi-docs/src/test/scala/sttp/tapir/docs/asyncapi/VerifyAsyncAPIYamlTest.scala @@ -133,7 +133,11 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers { .out( webSocketBody[Fruit, CodecFormat.Json, Int, CodecFormat.Json](AkkaStreams) // TODO: missing `RequestInfo.example(example: EndpointIO.Example)` and friends - .pipe(e => e.copy(requestsInfo = e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple")))) + .pipe(e => + e.copy(requestsInfo = + e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple")) + ) + ) ) val expectedYaml = loadYaml("expected_json_example_name_summary.yml") @@ -232,6 +236,22 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers { noIndentation(yaml) shouldBe loadYaml("expected_flags_header.yml") } + test("should work with discriminators") { + case class GetAnimal(name: String) + sealed trait Animal + case class Cat(name: String) extends Animal + case class Dog(name: String, breed: String) extends Animal + implicit val configuration: sttp.tapir.generic.Configuration = sttp.tapir.generic.Configuration.default.withDiscriminator("pet") + + val animalEndpoint = endpoint.get + .in("animals") + .out(webSocketBody[GetAnimal, CodecFormat.Json, Animal, CodecFormat.Json](AkkaStreams)) + + val yaml = AsyncAPIInterpreter().toAsyncAPI(animalEndpoint, "discriminator", "1.0").toYaml + + noIndentation(yaml) shouldBe loadYaml("expected_coproduct_with_discriminator.yml") + } + private def loadYaml(fileName: String): String = { noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n")) } diff --git a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator.yml b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator.yml index 9cfb4bee8a..d51d6e65e6 100644 --- a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator.yml +++ b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator.yml @@ -33,6 +33,7 @@ components: properties: name: type: string + const: sml Person: title: Person type: object @@ -42,6 +43,7 @@ components: properties: name: type: string + const: john age: type: integer format: int32 diff --git a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_nested.yml b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_nested.yml index c7a3eaccca..5491f27c23 100644 --- a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_nested.yml +++ b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_nested.yml @@ -41,6 +41,7 @@ components: properties: name: type: string + const: sml Person: title: Person type: object @@ -50,6 +51,7 @@ components: properties: name: type: string + const: john age: type: integer format: int32 diff --git a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_with_enum_circe.yml b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_with_enum_circe.yml index b12ded75fa..61a4cd44fa 100644 --- a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_with_enum_circe.yml +++ b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_discriminator_with_enum_circe.yml @@ -37,3 +37,4 @@ components: - red shapeType: type: string + const: Square diff --git a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_independent.yml b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_independent.yml index ef7ba4f4db..cff3fb0580 100644 --- a/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_independent.yml +++ b/docs/openapi-docs/src/test/resources/coproduct/expected_coproduct_independent.yml @@ -74,6 +74,7 @@ components: type: string kind: type: string + const: organization Person: title: Person type: object @@ -100,4 +101,5 @@ components: type: integer format: int32 kind: - type: string \ No newline at end of file + type: string + const: person diff --git a/json/pickler/src/test/scala/sttp/tapir/json/pickler/SchemaDerivationTest.scala b/json/pickler/src/test/scala/sttp/tapir/json/pickler/SchemaDerivationTest.scala index af77c3ee4c..466f503f81 100644 --- a/json/pickler/src/test/scala/sttp/tapir/json/pickler/SchemaDerivationTest.scala +++ b/json/pickler/src/test/scala/sttp/tapir/json/pickler/SchemaDerivationTest.scala @@ -12,8 +12,8 @@ import sttp.tapir.{AttributeKey, FieldName, Schema, SchemaType, Validator} import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} -class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { - import SchemaGenericAutoTest._ +class SchemaDerivationTest extends AsyncFlatSpec with Matchers with Inside { + import SchemaDerivationTest._ import generic.auto._ def implicitlySchema[T: Pickler]: Schema[T] = summon[Pickler[T]].schema @@ -210,7 +210,7 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { val schema = implicitlySchema[Test1] // when - schema.name shouldBe Some(SName("sttp.tapir.json.pickler.SchemaGenericAutoTest..Test1")) + schema.name shouldBe Some(SName("sttp.tapir.json.pickler.SchemaDerivationTest..Test1")) schema.schemaType shouldBe SProduct[Test1]( List( field(FieldName("f1"), implicitlySchema[String]), @@ -276,7 +276,13 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { schemaType.asInstanceOf[SCoproduct[Entity]].subtypes should contain theSameElementsAs List( Schema( SProduct[Organization]( - List(field(FieldName("name"), Schema(SString())), field(FieldName("who_am_i"), Schema(SString()))) + List( + field(FieldName("name"), Schema(SString())), + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Organization")) + ) + ) ), Some(SName("sttp.tapir.json.pickler.Organization")) ), @@ -285,7 +291,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { List( field(FieldName("first"), Schema(SString())), field(FieldName("age"), Schema(SInteger(), format = Some("int32"))), - field(FieldName("who_am_i"), Schema(SString())) + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Person")) + ) ) ), Some(SName("sttp.tapir.json.pickler.Person")) @@ -293,7 +302,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { Schema( SProduct[UnknownEntity.type]( List( - field(FieldName("who_am_i"), Schema(SString())) + field( + FieldName("who_am_i"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("UnknownEntity")) + ) ) ), Some(SName("sttp.tapir.json.pickler.UnknownEntity")) @@ -423,7 +435,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { List( field(FieldName("name"), stringSchema.copy(description = Some("cat name"))), field(FieldName("catFood"), stringSchema.copy(description = Some("cat food"))), - field(FieldName("$type"), Schema(SString())) + field( + FieldName("$type"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Cat")) + ) ) ), Some(SName("sttp.tapir.SchemaMacroTestData.Cat")) @@ -434,7 +449,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { List( field(FieldName("name"), stringSchema.copy(description = Some("name"))), field(FieldName("dogFood"), stringSchema.copy(description = Some("dog food"))), - field(FieldName("$type"), Schema(SString())) + field( + FieldName("$type"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Dog")) + ) ) ), Some(SName("sttp.tapir.SchemaMacroTestData.Dog")) @@ -445,7 +463,10 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { List( field(FieldName("name"), stringSchema.copy(description = Some("name"))), field(FieldName("likesNuts"), booleanSchema.copy(description = Some("likes nuts?"))), - field(FieldName("$type"), Schema(SString())) + field( + FieldName("$type"), + Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Hamster")) + ) ) ), Some(SName("sttp.tapir.SchemaMacroTestData.Hamster")) @@ -470,7 +491,7 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers with Inside { } } -object SchemaGenericAutoTest { +object SchemaDerivationTest { import generic.auto._ def implicitlySchema[A: Pickler]: Schema[A] = summon[Pickler[A]].schema diff --git a/project/Versions.scala b/project/Versions.scala index ce9540db41..1e67bb0f77 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -11,7 +11,7 @@ object Versions { val sttp = "3.9.7" val sttpModel = "1.7.11" val sttpShared = "1.3.19" - val sttpApispec = "0.11.0" + val sttpApispec = "0.11.2" val akkaHttp = "10.2.10" val akkaStreams = "2.6.20" val pekkoHttp = "1.0.1"