From 5de1be20abd7b5c5841dbe749efa6c21728bb52d Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 10 Aug 2023 15:14:34 -0500 Subject: [PATCH] Add `float_hook` to json decoder This adds support for a `float_hook` to the json decoder. If set, this hook will be called to decode any untyped JSON float values from their raw string representations. This may be used to change the default float parsing from returning ``float`` values to return ``decimal.Decimal`` values again. Since this is an uncommon option, it's only available on the Decoder, rather than the top-level ``msgspec.json.decode`` function. --- msgspec/_core.c | 62 +++++++++++++++++++++++++++++++--- msgspec/json.pyi | 5 +++ tests/basic_typing_examples.py | 9 +++++ tests/test_json.py | 50 +++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 5 deletions(-) diff --git a/msgspec/_core.c b/msgspec/_core.c index 1f5d5230..59a08d08 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -11092,6 +11092,22 @@ parse_number_nonfinite( return ms_post_decode_float(val, type, path, strict, true); } +static MS_NOINLINE PyObject * +json_float_hook( + const char *buf, Py_ssize_t size, PathNode *path, PyObject *float_hook +) { + PyObject *str = PyUnicode_New(size, 127); + if (str == NULL) return NULL; + memcpy(ascii_get_buffer(str), buf, size); + PyObject *out = CALL_ONE_ARG(float_hook, str); + Py_DECREF(str); + if (out == NULL) { + ms_maybe_wrap_validation_error(path); + return NULL; + } + return out; +} + static MS_INLINE PyObject * parse_number_inline( const unsigned char *p, @@ -11101,6 +11117,7 @@ parse_number_inline( TypeNode *type, PathNode *path, bool strict, + PyObject *float_hook, bool from_str ) { uint64_t mantissa = 0; @@ -11286,6 +11303,9 @@ parse_number_inline( (char *)start, p - start, true, path, NULL ); } + else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) { + return json_float_hook((char *)start, p - start, path, float_hook); + } else { if (MS_UNLIKELY(exponent > 288 || exponent < -307)) { /* Exponent is out of bounds */ @@ -11363,6 +11383,7 @@ maybe_parse_number( type, path, strict, + NULL, true ); return (*out != NULL || errmsg == NULL); @@ -15403,6 +15424,7 @@ typedef struct JSONDecoderState { /* Configuration */ TypeNode *type; PyObject *dec_hook; + PyObject *float_hook; bool strict; /* Temporary scratch space */ @@ -15425,10 +15447,11 @@ typedef struct JSONDecoder { TypeNode *type; char strict; PyObject *dec_hook; + PyObject *float_hook; } JSONDecoder; PyDoc_STRVAR(JSONDecoder__doc__, -"Decoder(type='Any', *, strict=True, dec_hook=None)\n" +"Decoder(type='Any', *, strict=True, dec_hook=None, float_hook=None)\n" "--\n" "\n" "A JSON decoder.\n" @@ -15449,19 +15472,28 @@ PyDoc_STRVAR(JSONDecoder__doc__, " signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type`` is the\n" " expected message type, and ``obj`` is the decoded representation composed\n" " of only basic JSON types. This hook should transform ``obj`` into type\n" -" ``type``, or raise a ``NotImplementedError`` if unsupported." +" ``type``, or raise a ``NotImplementedError`` if unsupported.\n" +"float_hook : callable, optional\n" +" An optional callback for handling decoding untyped float literals. Should\n" +" have the signature ``float_hook(val: str) -> Any``, where ``val`` is the\n" +" raw string value of the JSON float. This hook is called to decode any\n" +" \"untyped\" float value (e.g. ``typing.Any`` typed). The default is\n" +" equivalent to ``float_hook=float``, where all untyped JSON floats are\n" +" decoded as python floats. Specifying ``float_hook=decimal.Decimal``\n" +" will decode all untyped JSON floats as decimals instead." ); static int JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds) { - char *kwlist[] = {"type", "strict", "dec_hook", NULL}; + char *kwlist[] = {"type", "strict", "dec_hook", "float_hook", NULL}; MsgspecState *st = msgspec_get_global_state(); PyObject *type = st->typing_any; PyObject *dec_hook = NULL; + PyObject *float_hook = NULL; int strict = 1; if (!PyArg_ParseTupleAndKeywords( - args, kwds, "|O$pO", kwlist, &type, &strict, &dec_hook) + args, kwds, "|O$pOO", kwlist, &type, &strict, &dec_hook, &float_hook) ) { return -1; } @@ -15479,6 +15511,19 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds) } self->dec_hook = dec_hook; + /* Handle float_hook */ + if (float_hook == Py_None) { + float_hook = NULL; + } + if (float_hook != NULL) { + if (!PyCallable_Check(float_hook)) { + PyErr_SetString(PyExc_TypeError, "float_hook must be callable"); + return -1; + } + Py_INCREF(float_hook); + } + self->float_hook = float_hook; + /* Handle strict */ self->strict = strict; @@ -15498,6 +15543,7 @@ JSONDecoder_traverse(JSONDecoder *self, visitproc visit, void *arg) if (out != 0) return out; Py_VISIT(self->orig_type); Py_VISIT(self->dec_hook); + Py_VISIT(self->float_hook); return 0; } @@ -15508,6 +15554,7 @@ JSONDecoder_dealloc(JSONDecoder *self) TypeNode_Free(self->type); Py_XDECREF(self->orig_type); Py_XDECREF(self->dec_hook); + Py_XDECREF(self->float_hook); Py_TYPE(self)->tp_free((PyObject *)self); } @@ -17551,7 +17598,7 @@ json_maybe_decode_number(JSONDecoderState *self, TypeNode *type, PathNode *path) PyObject *out = parse_number_inline( self->input_pos, self->input_end, &pout, &errmsg, - type, path, self->strict, false + type, path, self->strict, self->float_hook, false ); self->input_pos = (unsigned char *)pout; @@ -18014,6 +18061,7 @@ msgspec_json_format(PyObject *self, PyObject *args, PyObject *kwargs) /* Init decoder */ dec.dec_hook = NULL; + dec.float_hook = NULL; dec.type = NULL; dec.scratch = NULL; dec.scratch_capacity = 0; @@ -18095,6 +18143,7 @@ JSONDecoder_decode(JSONDecoder *self, PyObject *const *args, Py_ssize_t nargs) .type = self->type, .strict = self->strict, .dec_hook = self->dec_hook, + .float_hook = self->float_hook, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 @@ -18161,6 +18210,7 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na .type = self->type, .strict = self->strict, .dec_hook = self->dec_hook, + .float_hook = self->float_hook, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 @@ -18237,6 +18287,7 @@ static PyMemberDef JSONDecoder_members[] = { {"type", T_OBJECT_EX, offsetof(JSONDecoder, orig_type), READONLY, "The Decoder type"}, {"strict", T_BOOL, offsetof(JSONDecoder, strict), READONLY, "The Decoder strict setting"}, {"dec_hook", T_OBJECT, offsetof(JSONDecoder, dec_hook), READONLY, "The Decoder dec_hook"}, + {"float_hook", T_OBJECT, offsetof(JSONDecoder, float_hook), READONLY, "The Decoder float_hook"}, {NULL}, }; @@ -18334,6 +18385,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO JSONDecoderState state = { .strict = strict, .dec_hook = dec_hook, + .float_hook = NULL, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 53f9253a..30e87a92 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -18,6 +18,7 @@ T = TypeVar("T") enc_hook_sig = Optional[Callable[[Any], Any]] dec_hook_sig = Optional[Callable[[type, Any], Any]] +float_hook_sig = Optional[Callable[[str], Any]] class Encoder: enc_hook: enc_hook_sig @@ -41,6 +42,7 @@ class Decoder(Generic[T]): type: Type[T] strict: bool dec_hook: dec_hook_sig + float_hook: float_hook_sig @overload def __init__( @@ -48,6 +50,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... @overload def __init__( @@ -56,6 +59,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... @overload def __init__( @@ -64,6 +68,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... def decode(self, data: Union[bytes, str]) -> T: ... def decode_lines(self, data: Union[bytes, str]) -> list[T]: ... diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index f560ec8d..21cd3fa4 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -2,6 +2,7 @@ from __future__ import annotations import datetime +import decimal import pickle from typing import Any, Dict, Final, List, Type, Union @@ -826,6 +827,14 @@ def dec_hook(typ: Type, obj: Any) -> Any: msgspec.json.Decoder(dec_hook=dec_hook) +def check_json_Decoder_float_hook() -> None: + msgspec.json.Decoder(float_hook=None) + msgspec.json.Decoder(float_hook=float) + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + if dec.float_hook is not None: + dec.float_hook("1.5") + + def check_json_Decoder_strict() -> None: dec = msgspec.json.Decoder(List[int], strict=False) reveal_type(dec.strict) # assert "bool" in typ diff --git a/tests/test_json.py b/tests/test_json.py index 7ffa0b03..898c455b 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -2,6 +2,7 @@ import base64 import datetime +import decimal import enum import gc import itertools @@ -524,6 +525,19 @@ def test_decode_lines_bad_call(self): with pytest.raises(TypeError): dec.decode(1) + def test_decoder_init_float_hook(self): + dec = msgspec.json.Decoder() + assert dec.float_hook is None + + dec = msgspec.json.Decoder(float_hook=None) + assert dec.float_hook is None + + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + assert dec.float_hook is decimal.Decimal + + with pytest.raises(TypeError): + dec = msgspec.json.Decoder(float_hook=1) + class TestBoolAndNone: def test_encode_none(self): @@ -1567,6 +1581,42 @@ def test_decode_float_err_expected_int(self, s): ): msgspec.json.decode(s, type=int) + def test_float_hook_untyped(self): + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + res = dec.decode(b"1.33") + assert res == decimal.Decimal("1.33") + assert type(res) is decimal.Decimal + + def test_float_hook_typed(self): + class Ex(msgspec.Struct): + a: float + b: decimal.Decimal + c: Any + d: Any + + class MyFloat(NamedTuple): + x: str + + dec = msgspec.json.Decoder(Ex, float_hook=MyFloat) + res = dec.decode(b'{"a": 1.5, "b": 1.3, "c": 1.3, "d": 123}') + sol = Ex(1.5, decimal.Decimal("1.3"), MyFloat("1.3"), 123) + assert res == sol + + def test_float_hook_error(self): + def float_hook(val): + raise ValueError("Oh no!") + + class Ex(msgspec.Struct): + a: float + b: Any + + dec = msgspec.json.Decoder(Ex, float_hook=float_hook) + assert dec.decode(b'{"a": 1.5, "b": 2}') == Ex(a=1.5, b=2) + with pytest.raises(msgspec.ValidationError) as rec: + dec.decode(b'{"a": 1.5, "b": 2.5}') + assert "Oh no!" in str(rec.value) + assert "at `$.b`" in str(rec.value) + class TestDecimal: """Most decimal tests are in test_common.py, the ones here are for json