Skip to content

Commit

Permalink
add UnmarshalJSON and MarshalJSON to Decimal (#191)
Browse files Browse the repository at this point in the history
* add UnmarshalJSON and MarshalJSON to Decimal

* add more tests
  • Loading branch information
plopezlpz authored Mar 29, 2023
1 parent ec09c68 commit caa4fab
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
48 changes: 48 additions & 0 deletions ion/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,51 @@ func (d *Decimal) String() string {
return b.String()
}
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
str := string(decimalBytes)
if str == "null" {
return nil
}
str = strings.Replace(str, "E", "D", 1)
str = strings.Replace(str, "e", "d", 1)
parsed, err := ParseDecimal(str)
if err != nil {
return fmt.Errorf("error unmarshalling decimal '%s': %w", str, err)
}
*d = *parsed
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (d *Decimal) MarshalJSON() ([]byte, error) {
absN := new(big.Int).Abs(d.n).String()
scale := int(-d.scale)
sign := d.n.Sign()

var str string
if scale == 0 {
str = absN
} else if scale > 0 {
// add zeroes to the right
str = absN + strings.Repeat("0", scale)
} else {
// add zeroes to the left
absScale := -scale
nLen := len(absN)

if absScale >= nLen {
str = "0." + strings.Repeat("0", absScale-nLen) + absN
} else {
str = absN[:nLen-absScale] + "." + absN[nLen-absScale:]
}
str = strings.TrimRight(str, "0")
str = strings.TrimSuffix(str, ".")
}

if sign == -1 {
str = "-" + str
}
return []byte(str), nil
}
70 changes: 70 additions & 0 deletions ion/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ion

import (
"encoding/json"
"fmt"
"math/big"
"testing"
Expand Down Expand Up @@ -329,3 +330,72 @@ func TestUpscale(t *testing.T) {
actual := d.upscale(4).String()
assert.Equal(t, "10.0000", actual)
}

func TestMarshalJSON(t *testing.T) {
test := func(a string, expected string) {
t.Run("("+a+")", func(t *testing.T) {
ad, err := ParseDecimal(a)
require.NoError(t, err)

am, err := ad.MarshalJSON()
require.NoError(t, err)

assert.Equal(t, []byte(expected), am)
})
}
test("123000", "123000")

test("1.01", "1.01")
test("0.01", "0.01")
test("0.0", "0")
test("0.123456789012345678901234567890", "0.12345678901234567890123456789") // Trims trailing zeros
test("123456789012345678901234567890.123456789012345678901234567890", "123456789012345678901234567890.12345678901234567890123456789") // Trims trailing zeros

test("1d-2", "0.01")
test("1d-3", "0.001")
test("1d2", "100")

test("-1d3", "-1000")
test("-1d-3", "-0.001")
test("-0.0", "0")
test("-0.1", "-0.1")
}

func TestUnmarshalJSON(t *testing.T) {
test := func(a string, expected string) {
t.Run("("+a+")", func(t *testing.T) {
expectedDec := MustParseDecimal(expected)

var r struct {
D *Decimal `json:"d"`
}
err := json.Unmarshal([]byte(`{"d":`+a+`}`), &r)
require.NoError(t, err)

assert.Truef(t, expectedDec.Equal(r.D), "expected %v, got %v", expected, r.D)
})
}

test("123000", "123000")
test("123.1", "123.1")
test("123.10", "123.1")
test("-123000", "-123000")
test("-123.1", "-123.1")
test("-123.10", "-123.1")

test("1e+2", "100")
test("1e2", "100")
test("1E2", "100")
test("1E+2", "100")

test("-1e+2", "-100")
test("-1e2", "-100")
test("-1E2", "-100")
test("-1E+2", "-100")

test("1e-2", "0.01")
test("1E-2", "0.01")

test("-1e-2", "-0.01")
test("-1E-2", "-0.01")
}

0 comments on commit caa4fab

Please sign in to comment.