Skip to content

Commit

Permalink
Add an encoded discriminator value attribute for coproducts, use it t…
Browse files Browse the repository at this point in the history
…o render const constraints (#3955)
  • Loading branch information
adamw authored Jul 26, 2024
1 parent 4822f59 commit f2c22af
Show file tree
Hide file tree
Showing 13 changed files with 209 additions and 21 deletions.
20 changes: 20 additions & 0 deletions core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,26 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros {
val Attribute: AttributeKey[Tuple] = new AttributeKey[Tuple]("sttp.tapir.Schema.Tuple")
}

/** For coproduct schemas, when there's a discriminator field, used to attach the encoded value of the discriminator field. Such value is
* added to the discriminator field schemas in each of the coproduct's subtypes. When rendering OpenAPI/JSON schema, these values are
* converted to `const` constraints on fields.
*/
case class EncodedDiscriminatorValue(v: String)
object EncodedDiscriminatorValue {
/*
Implementation note: the discriminator value constraint is in fact an enum validator with a single possible enum value. Hence an
alternative design would be to add such validators to discriminator fields, instead of an attribute. However, this has two drawbacks:
1. when adding discriminator fields using `addDiscriminatorField`, we don't have access to the decoded discriminator value - only
to the encoded one, via reverse mapping lookup
2. the validator doesn't necessarily make sense, as it can't be used to validate the deserialiszd object. Usually the discriminator
fields don't even exist on the high-level representations.
That's why instead of re-using the validators, we decided to use a specialised attribute.
*/

val Attribute: AttributeKey[EncodedDiscriminatorValue] =
new AttributeKey[EncodedDiscriminatorValue]("sttp.tapir.Schema.EncodedDiscriminatorValue")
}

/** @param typeParameterShortNames
* full name of type parameters, name is legacy and kept only for backward compatibility
*/
Expand Down
34 changes: 31 additions & 3 deletions core/src/main/scala/sttp/tapir/SchemaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,39 @@ object SchemaType {
discriminatorSchema: Schema[D] = Schema.string,
discriminatorMapping: Map[String, SRef[_]] = Map.empty
): SCoproduct[T] = {
// used to add encoded discriminator value attributes
val reverseDiscriminatorByNameMapping: Map[SName, String] = discriminatorMapping.toList.map { case (v, ref) => (ref.name, v) }.toMap

SCoproduct(
subtypes.map {
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _)
if st.fields.forall(_.name != discriminatorName) =>
s.copy(schemaType = st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None)))
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _) =>
// first, ensuring that the discriminator field is added to the schema type - it might already be present
var targetSt =
if (st.fields.forall(_.name != discriminatorName))
st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None))
else st

// next, modifying the discriminator field, by adding the value attribute (if a value can be found)
targetSt = targetSt.copy(fields = targetSt.fields.map { field =>
if (field.name == discriminatorName) {
val discriminatorValue = s.name.flatMap { subtypeName =>
reverseDiscriminatorByNameMapping.get(subtypeName)
}

discriminatorValue match {
case Some(v) =>
SProductField(
field.name,
field.schema.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(v)),
field.get
)
case None => field
}

} else field
})

s.copy(schemaType = targetSt)
case s => s
},
Some(SDiscriminator(discriminatorName, discriminatorMapping))
Expand Down
9 changes: 8 additions & 1 deletion core/src/test/scala/sttp/tapir/SchemaMacroTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,14 @@ class SchemaMacroTest extends AnyFlatSpec with Matchers with TableDrivenProperty

schemaType.subtypes.foreach { childSchema =>
val childProduct = childSchema.schemaType.asInstanceOf[SProduct[_]]
childProduct.fields.find(_.name.name == "kind") shouldBe Some(SProductField(FieldName("kind"), Schema.string, (_: Any) => None))
val discValue = if (childSchema.name.get.fullName == "sttp.tapir.SchemaMacroTestData.User") "user" else "org"
childProduct.fields.find(_.name.name == "kind") shouldBe Some(
SProductField(
FieldName("kind"),
Schema.string.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(discValue)),
(_: Any) => None
)
)
}
}

Expand Down
18 changes: 15 additions & 3 deletions core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
schemaType.asInstanceOf[SCoproduct[Entity]].subtypes should contain theSameElementsAs List(
Schema(
SProduct[Organization](
List(field(FieldName("name"), Schema(SString())), field(FieldName("who_am_i"), Schema(SString())))
List(
field(FieldName("name"), Schema(SString())),
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Organization"))
)
)
),
Some(SName("sttp.tapir.generic.Organization"))
),
Expand All @@ -254,15 +260,21 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
List(
field(FieldName("first"), Schema(SString())),
field(FieldName("age"), Schema(SInteger(), format = Some("int32"))),
field(FieldName("who_am_i"), Schema(SString()))
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Person"))
)
)
),
Some(SName("sttp.tapir.generic.Person"))
),
Schema(
SProduct[UnknownEntity.type](
List(
field(FieldName("who_am_i"), Schema(SString()))
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("UnknownEntity"))
)
)
),
Some(SName("sttp.tapir.generic.UnknownEntity"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ private[docs] class TSchemaToASchema(
// The primary motivation for using schema name as fallback title is to improve Swagger UX with
// `oneOf` schemas in OpenAPI 3.1. See https://github.com/softwaremill/tapir/issues/3447 for details.
def fallbackTitle = tschema.name.map(fallbackSchemaTitle)

val const = tschema.attribute(TSchema.EncodedDiscriminatorValue.Attribute).map(_.v).map(v => ExampleSingleValue(v))

oschema
.copy(title = titleFromAttr orElse fallbackTitle)
.copy(title = titleFromAttr.orElse(fallbackTitle))
.copy(uniqueItems = tschema.attribute(UniqueItems.Attribute).map(_.uniqueItems))
.copy(const = const)
}

private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
asyncapi: 2.6.0
info:
title: discriminator
version: '1.0'
channels:
/animals:
subscribe:
operationId: onAnimals
message:
$ref: '#/components/messages/Animal'
publish:
operationId: sendAnimals
message:
$ref: '#/components/messages/GetAnimal'
bindings:
ws:
method: GET
components:
schemas:
GetAnimal:
title: GetAnimal
type: object
required:
- name
properties:
name:
type: string
Animal:
title: Animal
oneOf:
- $ref: '#/components/schemas/Cat'
- $ref: '#/components/schemas/Dog'
discriminator: pet
Cat:
title: Cat
type: object
required:
- name
- pet
properties:
name:
type: string
pet:
type: string
const: Cat
Dog:
title: Dog
type: object
required:
- name
- breed
- pet
properties:
name:
type: string
breed:
type: string
pet:
type: string
const: Dog
messages:
GetAnimal:
payload:
$ref: '#/components/schemas/GetAnimal'
contentType: application/json
Animal:
payload:
$ref: '#/components/schemas/Animal'
contentType: application/json
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
.out(
webSocketBody[Fruit, CodecFormat.Json, Int, CodecFormat.Json](AkkaStreams)
// TODO: missing `RequestInfo.example(example: EndpointIO.Example)` and friends
.pipe(e => e.copy(requestsInfo = e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))))
.pipe(e =>
e.copy(requestsInfo =
e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))
)
)
)

val expectedYaml = loadYaml("expected_json_example_name_summary.yml")
Expand Down Expand Up @@ -232,6 +236,22 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
noIndentation(yaml) shouldBe loadYaml("expected_flags_header.yml")
}

test("should work with discriminators") {
case class GetAnimal(name: String)
sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String, breed: String) extends Animal
implicit val configuration: sttp.tapir.generic.Configuration = sttp.tapir.generic.Configuration.default.withDiscriminator("pet")

val animalEndpoint = endpoint.get
.in("animals")
.out(webSocketBody[GetAnimal, CodecFormat.Json, Animal, CodecFormat.Json](AkkaStreams))

val yaml = AsyncAPIInterpreter().toAsyncAPI(animalEndpoint, "discriminator", "1.0").toYaml

noIndentation(yaml) shouldBe loadYaml("expected_coproduct_with_discriminator.yml")
}

private def loadYaml(fileName: String): String = {
noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ components:
properties:
name:
type: string
const: sml
Person:
title: Person
type: object
Expand All @@ -42,6 +43,7 @@ components:
properties:
name:
type: string
const: john
age:
type: integer
format: int32
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ components:
properties:
name:
type: string
const: sml
Person:
title: Person
type: object
Expand All @@ -50,6 +51,7 @@ components:
properties:
name:
type: string
const: john
age:
type: integer
format: int32
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ components:
- red
shapeType:
type: string
const: Square
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ components:
type: string
kind:
type: string
const: organization
Person:
title: Person
type: object
Expand All @@ -100,4 +101,5 @@ components:
type: integer
format: int32
kind:
type: string
type: string
const: person
Loading

0 comments on commit f2c22af

Please sign in to comment.