Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Python 3.12's type aliases #606

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/source/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Most combinations of the following types are supported (with a few restrictions)
- `typing.Literal`
- `typing.NewType`
- `typing.Final`
- `typing.TypeAliasType`
- `typing.TypeAlias`
- `typing.NamedTuple` / `collections.namedtuple`
- `typing.TypedDict`
- `typing.Generic`
Expand Down Expand Up @@ -1170,6 +1172,36 @@ support here is purely to aid static analysis tools like mypy_ or pyright_.
File "<stdin>", line 1, in <module>
msgspec.ValidationError: Expected `int`, got `str`

Type Aliases
------------

For complex types, it is often convenient to define the type once and reuse it
by name later.

.. code-block:: python

Point = tuple[float, float]

Here ``Point`` is a "type alias" for ``tuple[float, float]``. ``msgspec``
will substitute ``tuple[float, float]`` wherever the ``Point`` type
is used in an annotation.

``msgspec`` supports the following equivalent forms:

.. code-block:: python

# Using variable assignment
Point = tuple[float, float]

# Using variable assignment, annotated as a `TypeAlias`
Point: TypeAlias = tuple[float, float]

# Using Python 3.12's new `type` statement. This only works on Python 3.12+
type Point = tuple[float, float]

To learn more about type aliases, see the `type aliases section of Python's
typing documentation <https://docs.python.org/3/library/typing.html#type-aliases>`__.

Generic Types
-------------

Expand Down
107 changes: 83 additions & 24 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
#include "ryu.h"
#include "atof.h"

/* Python version checks.
 *
 * Each PYxy_PLUS macro expands to a comparison of PY_VERSION_HEX against a
 * fixed hex constant, so it can be used directly in `#if` to guard code that
 * needs at least that CPython release. The constants encode the major version
 * in the top byte and the minor version in the next byte (0x0309.. is 3.9,
 * 0x030a.. is 3.10, 0x030b.. is 3.11, 0x030c.. is 3.12); the low two bytes
 * are zero, so every build of that minor release satisfies the check. */
#define PY39_PLUS (PY_VERSION_HEX >= 0x03090000)
#define PY310_PLUS (PY_VERSION_HEX >= 0x030a0000)
#define PY311_PLUS (PY_VERSION_HEX >= 0x030b0000)
#define PY312_PLUS (PY_VERSION_HEX >= 0x030c0000)

/* Hint to the compiler not to store `x` in a register since it is likely to
* change. Results in much higher performance on GCC, with smaller benefits on
* clang */
Expand All @@ -36,18 +42,18 @@ ms_popcount(uint64_t i) { \
}
#endif

#if PY_VERSION_HEX < 0x03090000
#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL)
#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL)
#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size)
#else
#if PY39_PLUS
#define CALL_ONE_ARG(f, a) PyObject_CallOneArg(f, a)
#define CALL_NO_ARGS(f) PyObject_CallNoArgs(f)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodOneArg(o, n, a)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodNoArgs(o, n)
#define SET_SIZE(obj, size) Py_SET_SIZE(obj, size)
#else
#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL)
#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL)
#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size)
#endif

#define DIV_ROUND_CLOSEST(n, d) ((((n) < 0) == ((d) < 0)) ? (((n) + (d)/2)/(d)) : (((n) - (d)/2)/(d)))
Expand Down Expand Up @@ -157,7 +163,7 @@ fast_long_extract_parts(PyObject *vv, bool *neg, uint64_t *scale) {
uint64_t prev, x = 0;
bool negative;

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
/* CPython 3.12 changed int internal representation */
int sign = 1 - (v->long_value.lv_tag & _PyLong_SIGN_MASK);
negative = sign == -1;
Expand Down Expand Up @@ -405,6 +411,9 @@ typedef struct {
PyObject *str___dataclass_fields__;
PyObject *str___attrs_attrs__;
PyObject *str___supertype__;
#if PY312_PLUS
PyObject *str___value__;
#endif
PyObject *str___bound__;
PyObject *str___constraints__;
PyObject *str_int;
Expand All @@ -427,8 +436,11 @@ typedef struct {
PyObject *get_typeddict_info;
PyObject *get_dataclass_info;
PyObject *rebuild;
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
PyObject *types_uniontype;
#endif
#if PY312_PLUS
PyObject *typing_typealiastype;
#endif
PyObject *astimezone;
PyObject *re_compile;
Expand Down Expand Up @@ -2122,7 +2134,7 @@ PyTypeObject NoDefault_Type = {
.tp_basicsize = 0
};

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
PyObject _NoDefault_Object = {
_PyObject_EXTRA_INIT
{ _Py_IMMORTAL_REFCNT },
Expand Down Expand Up @@ -2226,7 +2238,7 @@ PyTypeObject Unset_Type = {
.tp_basicsize = 0
};

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
PyObject _Unset_Object = {
_PyObject_EXTRA_INIT
{ _Py_IMMORTAL_REFCNT },
Expand Down Expand Up @@ -4459,6 +4471,21 @@ typenode_origin_args_metadata(
t = temp;
continue;
}
/* Check for parametrized TypeAliasType if Python 3.12+ */
#if PY312_PLUS
if (Py_TYPE(origin) == (PyTypeObject *)(state->mod->typing_typealiastype)) {
PyObject *value = PyObject_GetAttr(origin, state->mod->str___value__);
if (value == NULL) goto error;
PyObject *temp = PyObject_GetItem(value, args);
Py_DECREF(value);
if (temp == NULL) goto error;
Py_CLEAR(args);
Py_CLEAR(origin);
Py_DECREF(t);
t = temp;
continue;
}
#endif
}
else {
/* Custom non-parametrized generics won't have __args__
Expand Down Expand Up @@ -4487,14 +4514,23 @@ typenode_origin_args_metadata(
t = supertype;
continue;
}
else {
PyErr_Clear();
break;
PyErr_Clear();

/* Check for TypeAliasType if Python 3.12+ */
#if PY312_PLUS
if (Py_TYPE(t) == (PyTypeObject *)(state->mod->typing_typealiastype)) {
PyObject *value = PyObject_GetAttr(t, state->mod->str___value__);
if (value == NULL) goto error;
Py_DECREF(t);
t = value;
continue;
}
#endif
break;
}
}

#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
if (Py_TYPE(t) == (PyTypeObject *)(state->mod->types_uniontype)) {
/* Handle types.UnionType unions (`int | float | ...`) */
args = PyObject_GetAttr(t, state->mod->str___args__);
Expand Down Expand Up @@ -4692,13 +4728,18 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
}
}
else if (origin == state->mod->typing_union) {
if (Py_EnterRecursiveCall(" while analyzing a type")) {
out = -1;
goto done;
}
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(args); i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
/* Ignore UnsetType in unions */
if (arg == (PyObject *)(&Unset_Type)) continue;
out = typenode_collect_type(state, arg);
if (out < 0) break;
}
Py_LeaveRecursiveCall();
}
else if (origin == state->mod->typing_literal) {
if (state->literals == NULL) {
Expand Down Expand Up @@ -4761,6 +4802,8 @@ TypeNode_Convert(PyObject *obj) {
state.mod = msgspec_get_global_state();
state.context = obj;

if (Py_EnterRecursiveCall(" while analyzing a type")) return NULL;

/* Traverse `obj` to collect all type annotations at this level */
if (typenode_collect_type(&state, obj) < 0) goto done;
/* Handle structs in a second pass */
Expand All @@ -4773,6 +4816,7 @@ TypeNode_Convert(PyObject *obj) {
out = typenode_from_collect_state(&state);
done:
typenode_collect_clear_state(&state);
Py_LeaveRecursiveCall();
return out;
}

Expand Down Expand Up @@ -9717,14 +9761,14 @@ ms_encode_err_type_unsupported(PyTypeObject *type) {
*************************************************************************/

#define MS_HAS_TZINFO(o) (((_PyDateTime_BaseTZInfo *)(o))->hastzinfo)
#if PY_VERSION_HEX < 0x030a00f0
#if PY310_PLUS
#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o)
#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o)
#else
#define MS_DATE_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \
((PyDateTime_DateTime *)(o))->tzinfo : Py_None)
#define MS_TIME_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \
((PyDateTime_Time *)(o))->tzinfo : Py_None)
#else
#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o)
#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o)
#endif

#ifndef TIMEZONE_CACHE_SIZE
Expand Down Expand Up @@ -15472,7 +15516,7 @@ static struct PyMethodDef Decoder_methods[] = {
"decode", (PyCFunction) Decoder_decode, METH_FASTCALL,
Decoder_decode__doc__,
},
#if PY_VERSION_HEX >= 0x03090000
#if PY39_PLUS
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
#endif
{NULL, NULL} /* sentinel */
Expand Down Expand Up @@ -18512,7 +18556,7 @@ static struct PyMethodDef JSONDecoder_methods[] = {
"decode_lines", (PyCFunction) JSONDecoder_decode_lines, METH_FASTCALL,
JSONDecoder_decode_lines__doc__,
},
#if PY_VERSION_HEX >= 0x03090000
#if PY39_PLUS
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
#endif
{NULL, NULL} /* sentinel */
Expand Down Expand Up @@ -21029,6 +21073,9 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->str___dataclass_fields__);
Py_CLEAR(st->str___attrs_attrs__);
Py_CLEAR(st->str___supertype__);
#if PY312_PLUS
Py_CLEAR(st->str___value__);
#endif
Py_CLEAR(st->str___bound__);
Py_CLEAR(st->str___constraints__);
Py_CLEAR(st->str_int);
Expand All @@ -21051,8 +21098,11 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->get_typeddict_info);
Py_CLEAR(st->get_dataclass_info);
Py_CLEAR(st->rebuild);
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
Py_CLEAR(st->types_uniontype);
#endif
#if PY312_PLUS
Py_CLEAR(st->typing_typealiastype);
#endif
Py_CLEAR(st->astimezone);
Py_CLEAR(st->re_compile);
Expand Down Expand Up @@ -21118,8 +21168,11 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(st->get_typeddict_info);
Py_VISIT(st->get_dataclass_info);
Py_VISIT(st->rebuild);
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
Py_VISIT(st->types_uniontype);
#endif
#if PY312_PLUS
Py_VISIT(st->typing_typealiastype);
#endif
Py_VISIT(st->astimezone);
Py_VISIT(st->re_compile);
Expand Down Expand Up @@ -21315,6 +21368,9 @@ PyInit__core(void)
SET_REF(typing_final, "Final");
SET_REF(typing_generic, "Generic");
SET_REF(typing_generic_alias, "_GenericAlias");
#if PY312_PLUS
SET_REF(typing_typealiastype, "TypeAliasType");
#endif
Py_DECREF(temp_module);

temp_module = PyImport_ImportModule("msgspec._utils");
Expand All @@ -21328,7 +21384,7 @@ PyInit__core(void)
SET_REF(rebuild, "rebuild");
Py_DECREF(temp_module);

#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
temp_module = PyImport_ImportModule("types");
if (temp_module == NULL) return NULL;
SET_REF(types_uniontype, "UnionType");
Expand Down Expand Up @@ -21411,6 +21467,9 @@ PyInit__core(void)
CACHED_STRING(str___dataclass_fields__, "__dataclass_fields__");
CACHED_STRING(str___attrs_attrs__, "__attrs_attrs__");
CACHED_STRING(str___supertype__, "__supertype__");
#if PY312_PLUS
CACHED_STRING(str___value__, "__value__");
#endif
CACHED_STRING(str___bound__, "__bound__");
CACHED_STRING(str___constraints__, "__constraints__");
CACHED_STRING(str_int, "int");
Expand Down
9 changes: 9 additions & 0 deletions msgspec/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
except Exception:
_types_UnionType = type("UnionType", (), {}) # type: ignore

try:
from typing import TypeAliasType as _TypeAliasType # type: ignore
except Exception:
_TypeAliasType = type("TypeAliasType", (), {}) # type: ignore

import msgspec
from msgspec import NODEFAULT, UNSET, UnsetType as _UnsetType

Expand Down Expand Up @@ -628,6 +633,8 @@ def _origin_args_metadata(t):
t = origin
elif origin == Final:
t = t.__args__[0]
elif type(origin) is _TypeAliasType:
t = origin.__value__[t.__args__]
else:
args = getattr(t, "__args__", None)
origin = _CONCRETE_TYPES.get(origin, origin)
Expand All @@ -636,6 +643,8 @@ def _origin_args_metadata(t):
supertype = getattr(t, "__supertype__", None)
if supertype is not None:
t = supertype
elif type(t) is _TypeAliasType:
t = t.__value__
else:
origin = t
args = None
Expand Down
Loading
Loading