Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix YamlDecoder and NaN tag #330

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions core/shared/src/main/scala/org/virtuslab/yaml/Tag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,33 @@ object Tag {
val corePrimitives = Set(nullTag, boolean, int, float, str)
val coreSchemaValues = (corePrimitives ++ Set(seq, map)).map(_.value)

private val nullPattern = "null|Null|NULL|~".r
private val booleanPattern = "true|True|TRUE|false|False|FALSE".r
private val int10Pattern = "[-+]?[0-9]+".r
private val int8Pattern = "0o[0-7]+".r
private val int16Pattern = "0x[0-9a-fA-F]+".r
private val floatPattern = "[-+]?(\\.[0-9]+|[0-9]+(\\.[0-9]*)?)([eE][-+]?[0-9]+)?".r
private val minusInfinity = "-(\\.inf|\\.Inf|\\.INF)".r
private val plusInfinity = "\\+?(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val nullPattern = "^(null|Null|NULL|~)?$".r
private[yaml] val falsePattern = "false|False|FALSE".r
private[yaml] val truePattern = "true|True|TRUE".r
private val int10Pattern = "[-+]?[0-9]+".r
private val int8Pattern = "0o[0-7]+".r
private val int16Pattern = "0x[0-9a-fA-F]+".r
private val floatPattern = "[-+]?(\\.[0-9]+|[0-9]+(\\.[0-9]*)?)([eE][-+]?[0-9]+)?".r
private[yaml] val minusInfinity = "-(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val plusInfinity = "\\+?(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val nan = "\\.nan|\\.NaN|\\.NAN".r

def resolveTag(value: String, style: Option[ScalarStyle] = None): Tag = {
val assumeString = style.exists(s => s == DoubleQuoted || s == SingleQuoted)
value match {
case null => nullTag
case _ if assumeString => str
case nullPattern(_*) => nullTag
case booleanPattern(_*) => boolean
case int10Pattern(_*) => int
case int8Pattern(_*) => int
case int16Pattern(_*) => int
case floatPattern(_*) => float
case minusInfinity(_*) => float
case plusInfinity(_*) => float
case _ => str
case null => nullTag
case _ if assumeString => str
case nullPattern(_*) => nullTag
case falsePattern(_*) => boolean
case truePattern(_*) => boolean
case int10Pattern(_*) => int
case int8Pattern(_*) => int
case int16Pattern(_*) => int
case floatPattern(_*) => float
case minusInfinity(_*) => float
case plusInfinity(_*) => float
case nan(_*) => float
case _ => str
}
}
}
69 changes: 43 additions & 26 deletions core/shared/src/main/scala/org/virtuslab/yaml/YamlDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,67 +92,83 @@ object YamlDecoder extends YamlDecoderCompanionCrossCompat {
tpe
)

implicit def forInt: YamlDecoder[Int] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value
private def normalizeInt(string: String): String = {
val octal = if (string.startsWith("0o")) string.stripPrefix("0o").prepended('0') else string
octal.replaceAll("_", "")
}

Try(java.lang.Integer.decode(normalizedValue.replaceAll("_", "")).toInt).toEither.left
implicit def forInt: YamlDecoder[Int] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Integer.decode(normalizeInt(value)).toInt).toEither.left
.map(ConstructError.from(_, "Int", s))
}

implicit def forLong: YamlDecoder[Long] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value

Try(java.lang.Long.decode(normalizedValue.replaceAll("_", "")).toLong).toEither.left
Try(java.lang.Long.decode(normalizeInt(value)).toLong).toEither.left
.map(ConstructError.from(_, "Long", s))
}

implicit def forDouble: YamlDecoder[Double] = YamlDecoder { case s @ ScalarNode(value, _) =>
val lowercased = value.toLowerCase
if (lowercased.endsWith("inf")) {
if (value.startsWith("-")) Right(Double.NegativeInfinity)
else Right(Double.PositiveInfinity)
} else if (lowercased.endsWith("nan")) {
if (Tag.nan.matches(value)) {
Right(Double.NaN)
} else if (Tag.plusInfinity.matches(value)) {
Right(Double.PositiveInfinity)
} else if (Tag.minusInfinity.matches(value)) {
Right(Double.NegativeInfinity)
} else {
Try(java.lang.Double.parseDouble(value.replaceAll("_", ""))).toEither.left
.map(ConstructError.from(_, "Double", s))
}
}

def forDoublePrecise: YamlDecoder[Double] = YamlDecoder { case s @ ScalarNode(value, _) =>
forDouble.construct(s).flatMap { n =>
val ns = n.toString
if (ns == value) Right(n) else Left(ConstructError.from(s"Double, decoded $ns", s))
}
}

implicit def forFloat: YamlDecoder[Float] = YamlDecoder { case s @ ScalarNode(value, _) =>
val lowercased = value.toLowerCase
if (lowercased.endsWith("inf")) {
if (value.startsWith("-")) Right(Float.NegativeInfinity)
else Right(Float.PositiveInfinity)
} else if (lowercased.endsWith("nan")) {
if (Tag.nan.matches(value)) {
Right(Float.NaN)
} else if (Tag.plusInfinity.matches(value)) {
Right(Float.PositiveInfinity)
} else if (Tag.minusInfinity.matches(value)) {
Right(Float.NegativeInfinity)
} else {
Try(java.lang.Float.parseFloat(value.replaceAll("_", ""))).toEither.left
.map(ConstructError.from(_, "Float", s))
}
}

implicit def forShort: YamlDecoder[Short] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value
def forFloatPrecise: YamlDecoder[Float] = YamlDecoder { case s @ ScalarNode(value, _) =>
forFloat.construct(s).flatMap { n =>
val ns = n.toString
if (ns == value) Right(n) else Left(ConstructError.from(s"Float, decoded $ns", s))
}
}

Try(java.lang.Short.decode(normalizedValue.replaceAll("_", "")).toShort).toEither.left
implicit def forShort: YamlDecoder[Short] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Short.decode(normalizeInt(value)).toShort).toEither.left
.map(ConstructError.from(_, "Short", s))
}

implicit def forByte: YamlDecoder[Byte] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Byte.decode(value.replaceAll("_", "")).toByte).toEither.left
Try(java.lang.Byte.decode(normalizeInt(value)).toByte).toEither.left
.map(ConstructError.from(_, "Byte", s))
}

implicit def forBoolean: YamlDecoder[Boolean] = YamlDecoder { case s @ ScalarNode(value, _) =>
value.toBooleanOption.toRight(cannotParse(value, "Boolean", s))
if (Tag.falsePattern.matches(value)) {
Right(false)
} else if (Tag.truePattern.matches(value)) {
Right(true)
} else {
Left(cannotParse(value, "Boolean", s))
}
}

implicit def forBigInt: YamlDecoder[BigInt] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(BigInt(value.replaceAll("_", ""))).toEither.left
Try(BigInt(normalizeInt(value))).toEither.left
.map(ConstructError.from(_, "BigInt", s))
}

Expand All @@ -177,8 +193,9 @@ object YamlDecoder extends YamlDecoderCompanionCrossCompat {
.orElse(forBigInt.widen)
.construct(node)
case node @ ScalarNode(_, Tag.float) =>
forDouble
forFloatPrecise
.widen[Any]
.orElse(forDoublePrecise.widen)
.orElse(forBigDecimal.widen)
.construct(node)
case ScalarNode(value, Tag.str) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class DecoderSuite extends munit.FunSuite:
123 -> 321,
"string" -> "aezakmi",
true -> false,
5.5f -> 55.55d
5.5f -> 55.55f
)

assertEquals(yaml.as[Map[Any, Any]], Right(expected))
Expand Down Expand Up @@ -565,3 +565,13 @@ class DecoderSuite extends munit.FunSuite:
assertEquals(foo.b, "from yaml")
assert(!evaluated)
}

test("Fails decoding -XXXinf as Float") {
val yaml = "-XXXinf"

yaml.as[Float] match
case Left(e: ConstructError) =>
assertEquals(e.expected, Some("Float"))
case Right(value) =>
fail(s"Should fail, but got $value")
}