Skip to content

Commit

Permalink
Add float_hook to json decoder
Browse files Browse the repository at this point in the history
This adds support for a `float_hook` to the json decoder. If set, this
hook will be called to decode any untyped JSON float values from their
raw string representations. This may be used to change the default float
parsing from returning ``float`` values to returning ``decimal.Decimal``
values again.

Since this is an uncommon option, it's only available on the Decoder,
rather than the top-level ``msgspec.json.decode`` function.
  • Loading branch information
jcrist committed Aug 10, 2023
1 parent ef7808b commit 5de1be2
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 5 deletions.
62 changes: 57 additions & 5 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -11092,6 +11092,22 @@ parse_number_nonfinite(
return ms_post_decode_float(val, type, path, strict, true);
}

/* Decode a JSON float literal by delegating to a user-supplied callback.
 *
 * The `size` bytes starting at `buf` are copied into a freshly allocated
 * Python str, which is passed as the single argument to `float_hook`.
 *
 * Returns the object produced by the hook, or NULL if either the str
 * allocation or the hook call fails. When the hook call fails,
 * `ms_maybe_wrap_validation_error` is invoked with `path` before returning.
 *
 * NOTE(review): the str is created with a max code point of 127, so `buf`
 * is presumably expected to hold only ASCII bytes -- confirm against callers.
 */
static MS_NOINLINE PyObject *
json_float_hook(
    const char *buf, Py_ssize_t size, PathNode *path, PyObject *float_hook
) {
    /* Allocate an ASCII-range str of the exact length, then fill it in */
    PyObject *text = PyUnicode_New(size, 127);
    if (text == NULL) {
        return NULL;
    }
    memcpy(ascii_get_buffer(text), buf, size);

    /* The temporary str is only needed for the duration of the call */
    PyObject *result = CALL_ONE_ARG(float_hook, text);
    Py_DECREF(text);

    if (result != NULL) {
        return result;
    }

    /* Hook call failed: let the error be wrapped using `path` */
    ms_maybe_wrap_validation_error(path);
    return NULL;
}

static MS_INLINE PyObject *
parse_number_inline(
const unsigned char *p,
Expand All @@ -11101,6 +11117,7 @@ parse_number_inline(
TypeNode *type,
PathNode *path,
bool strict,
PyObject *float_hook,
bool from_str
) {
uint64_t mantissa = 0;
Expand Down Expand Up @@ -11286,6 +11303,9 @@ parse_number_inline(
(char *)start, p - start, true, path, NULL
);
}
else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) {
return json_float_hook((char *)start, p - start, path, float_hook);
}
else {
if (MS_UNLIKELY(exponent > 288 || exponent < -307)) {
/* Exponent is out of bounds */
Expand Down Expand Up @@ -11363,6 +11383,7 @@ maybe_parse_number(
type,
path,
strict,
NULL,
true
);
return (*out != NULL || errmsg == NULL);
Expand Down Expand Up @@ -15403,6 +15424,7 @@ typedef struct JSONDecoderState {
/* Configuration */
TypeNode *type;
PyObject *dec_hook;
PyObject *float_hook;
bool strict;

/* Temporary scratch space */
Expand All @@ -15425,10 +15447,11 @@ typedef struct JSONDecoder {
TypeNode *type;
char strict;
PyObject *dec_hook;
PyObject *float_hook;
} JSONDecoder;

PyDoc_STRVAR(JSONDecoder__doc__,
"Decoder(type='Any', *, strict=True, dec_hook=None)\n"
"Decoder(type='Any', *, strict=True, dec_hook=None, float_hook=None)\n"
"--\n"
"\n"
"A JSON decoder.\n"
Expand All @@ -15449,19 +15472,28 @@ PyDoc_STRVAR(JSONDecoder__doc__,
" signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type`` is the\n"
" expected message type, and ``obj`` is the decoded representation composed\n"
" of only basic JSON types. This hook should transform ``obj`` into type\n"
" ``type``, or raise a ``NotImplementedError`` if unsupported."
" ``type``, or raise a ``NotImplementedError`` if unsupported.\n"
"float_hook : callable, optional\n"
" An optional callback for handling decoding untyped float literals. Should\n"
" have the signature ``float_hook(val: str) -> Any``, where ``val`` is the\n"
" raw string value of the JSON float. This hook is called to decode any\n"
" \"untyped\" float value (e.g. ``typing.Any`` typed). The default is\n"
" equivalent to ``float_hook=float``, where all untyped JSON floats are\n"
" decoded as python floats. Specifying ``float_hook=decimal.Decimal``\n"
" will decode all untyped JSON floats as decimals instead."
);
static int
JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
{
char *kwlist[] = {"type", "strict", "dec_hook", NULL};
char *kwlist[] = {"type", "strict", "dec_hook", "float_hook", NULL};
MsgspecState *st = msgspec_get_global_state();
PyObject *type = st->typing_any;
PyObject *dec_hook = NULL;
PyObject *float_hook = NULL;
int strict = 1;

if (!PyArg_ParseTupleAndKeywords(
args, kwds, "|O$pO", kwlist, &type, &strict, &dec_hook)
args, kwds, "|O$pOO", kwlist, &type, &strict, &dec_hook, &float_hook)
) {
return -1;
}
Expand All @@ -15479,6 +15511,19 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
}
self->dec_hook = dec_hook;

/* Handle float_hook */
if (float_hook == Py_None) {
float_hook = NULL;
}
if (float_hook != NULL) {
if (!PyCallable_Check(float_hook)) {
PyErr_SetString(PyExc_TypeError, "float_hook must be callable");
return -1;
}
Py_INCREF(float_hook);
}
self->float_hook = float_hook;

/* Handle strict */
self->strict = strict;

Expand All @@ -15498,6 +15543,7 @@ JSONDecoder_traverse(JSONDecoder *self, visitproc visit, void *arg)
if (out != 0) return out;
Py_VISIT(self->orig_type);
Py_VISIT(self->dec_hook);
Py_VISIT(self->float_hook);
return 0;
}

Expand All @@ -15508,6 +15554,7 @@ JSONDecoder_dealloc(JSONDecoder *self)
TypeNode_Free(self->type);
Py_XDECREF(self->orig_type);
Py_XDECREF(self->dec_hook);
Py_XDECREF(self->float_hook);
Py_TYPE(self)->tp_free((PyObject *)self);
}

Expand Down Expand Up @@ -17551,7 +17598,7 @@ json_maybe_decode_number(JSONDecoderState *self, TypeNode *type, PathNode *path)
PyObject *out = parse_number_inline(
self->input_pos, self->input_end,
&pout, &errmsg,
type, path, self->strict, false
type, path, self->strict, self->float_hook, false
);
self->input_pos = (unsigned char *)pout;

Expand Down Expand Up @@ -18014,6 +18061,7 @@ msgspec_json_format(PyObject *self, PyObject *args, PyObject *kwargs)

/* Init decoder */
dec.dec_hook = NULL;
dec.float_hook = NULL;
dec.type = NULL;
dec.scratch = NULL;
dec.scratch_capacity = 0;
Expand Down Expand Up @@ -18095,6 +18143,7 @@ JSONDecoder_decode(JSONDecoder *self, PyObject *const *args, Py_ssize_t nargs)
.type = self->type,
.strict = self->strict,
.dec_hook = self->dec_hook,
.float_hook = self->float_hook,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down Expand Up @@ -18161,6 +18210,7 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na
.type = self->type,
.strict = self->strict,
.dec_hook = self->dec_hook,
.float_hook = self->float_hook,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down Expand Up @@ -18237,6 +18287,7 @@ static PyMemberDef JSONDecoder_members[] = {
{"type", T_OBJECT_EX, offsetof(JSONDecoder, orig_type), READONLY, "The Decoder type"},
{"strict", T_BOOL, offsetof(JSONDecoder, strict), READONLY, "The Decoder strict setting"},
{"dec_hook", T_OBJECT, offsetof(JSONDecoder, dec_hook), READONLY, "The Decoder dec_hook"},
{"float_hook", T_OBJECT, offsetof(JSONDecoder, float_hook), READONLY, "The Decoder float_hook"},
{NULL},
};

Expand Down Expand Up @@ -18334,6 +18385,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO
JSONDecoderState state = {
.strict = strict,
.dec_hook = dec_hook,
.float_hook = NULL,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down
5 changes: 5 additions & 0 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ T = TypeVar("T")

enc_hook_sig = Optional[Callable[[Any], Any]]
dec_hook_sig = Optional[Callable[[type, Any], Any]]
float_hook_sig = Optional[Callable[[str], Any]]

class Encoder:
enc_hook: enc_hook_sig
Expand All @@ -41,13 +42,15 @@ class Decoder(Generic[T]):
type: Type[T]
strict: bool
dec_hook: dec_hook_sig
float_hook: float_hook_sig

@overload
def __init__(
self: Decoder[Any],
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
@overload
def __init__(
Expand All @@ -56,6 +59,7 @@ class Decoder(Generic[T]):
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
@overload
def __init__(
Expand All @@ -64,6 +68,7 @@ class Decoder(Generic[T]):
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
def decode(self, data: Union[bytes, str]) -> T: ...
def decode_lines(self, data: Union[bytes, str]) -> list[T]: ...
Expand Down
9 changes: 9 additions & 0 deletions tests/basic_typing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import datetime
import decimal
import pickle
from typing import Any, Dict, Final, List, Type, Union

Expand Down Expand Up @@ -826,6 +827,14 @@ def dec_hook(typ: Type, obj: Any) -> Any:
msgspec.json.Decoder(dec_hook=dec_hook)


def check_json_Decoder_float_hook() -> None:
    # Typing example: ``float_hook`` accepts ``None`` or a callable taking a
    # ``str``, and the attribute read back from the decoder is optional.
    msgspec.json.Decoder(float_hook=None)
    msgspec.json.Decoder(float_hook=float)

    decoder = msgspec.json.Decoder(float_hook=decimal.Decimal)
    hook = decoder.float_hook
    if hook is not None:
        hook("1.5")


def check_json_Decoder_strict() -> None:
dec = msgspec.json.Decoder(List[int], strict=False)
reveal_type(dec.strict) # assert "bool" in typ
Expand Down
50 changes: 50 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import datetime
import decimal
import enum
import gc
import itertools
Expand Down Expand Up @@ -524,6 +525,19 @@ def test_decode_lines_bad_call(self):
with pytest.raises(TypeError):
dec.decode(1)

def test_decoder_init_float_hook(self):
    """``float_hook`` defaults to None, accepts None or a callable, and
    rejects non-callable values with a TypeError."""
    # Omitted and explicit None both leave the hook unset
    assert msgspec.json.Decoder().float_hook is None
    assert msgspec.json.Decoder(float_hook=None).float_hook is None

    # A callable is stored and exposed unchanged
    hooked = msgspec.json.Decoder(float_hook=decimal.Decimal)
    assert hooked.float_hook is decimal.Decimal

    # Non-callables are rejected at construction time
    with pytest.raises(TypeError):
        msgspec.json.Decoder(float_hook=1)


class TestBoolAndNone:
def test_encode_none(self):
Expand Down Expand Up @@ -1567,6 +1581,42 @@ def test_decode_float_err_expected_int(self, s):
):
msgspec.json.decode(s, type=int)

def test_float_hook_untyped(self):
    """With no target type, a JSON float is decoded through ``float_hook``."""
    decoder = msgspec.json.Decoder(float_hook=decimal.Decimal)
    decoded = decoder.decode(b"1.33")

    expected = decimal.Decimal("1.33")
    assert decoded == expected
    assert type(decoded) is decimal.Decimal

def test_float_hook_typed(self):
    """``float_hook`` only applies to untyped (``Any``) float values.

    Fields typed ``float`` or ``Decimal`` decode as usual, an ``Any`` field
    holding a float literal goes through the hook, and an ``Any`` field
    holding an integer literal is left alone.
    """

    class MyFloat(NamedTuple):
        x: str

    class Ex(msgspec.Struct):
        a: float
        b: decimal.Decimal
        c: Any
        d: Any

    payload = b'{"a": 1.5, "b": 1.3, "c": 1.3, "d": 123}'
    decoder = msgspec.json.Decoder(Ex, float_hook=MyFloat)
    decoded = decoder.decode(payload)

    expected = Ex(1.5, decimal.Decimal("1.3"), MyFloat("1.3"), 123)
    assert decoded == expected

def test_float_hook_error(self):
    """An exception raised by ``float_hook`` surfaces as a ValidationError
    that carries the original message and the path of the failing field."""

    def failing_hook(val):
        raise ValueError("Oh no!")

    class Ex(msgspec.Struct):
        a: float
        b: Any

    decoder = msgspec.json.Decoder(Ex, float_hook=failing_hook)

    # No untyped float in the input, so the hook is never invoked
    ok = decoder.decode(b'{"a": 1.5, "b": 2}')
    assert ok == Ex(a=1.5, b=2)

    # An untyped float triggers the hook, whose error is wrapped
    with pytest.raises(msgspec.ValidationError) as rec:
        decoder.decode(b'{"a": 1.5, "b": 2.5}')

    message = str(rec.value)
    assert "Oh no!" in message
    assert "at `$.b`" in message


class TestDecimal:
"""Most decimal tests are in test_common.py, the ones here are for json
Expand Down

0 comments on commit 5de1be2

Please sign in to comment.