Skip to content

Commit

Permalink
In progress. Not ready for review. Approach 2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697791585
  • Loading branch information
Google-ML-Automation committed Nov 19, 2024
1 parent 2c68569 commit 178ae8f
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 19 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ py_library_providing_imports_info(
":xla_bridge",
":xla_metadata",
"//jax/_src/lib",
"//third_party/py/absl/logging",
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
)

Expand Down
4 changes: 3 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@ def _str_abstractify(x):

def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if dtype != np.dtype('object'):
dtypes.check_valid_dtype(dtype)

return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
Expand Down
56 changes: 39 additions & 17 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@
import builtins
import dataclasses
import functools
import logging
import types
from typing import cast, overload, Any, Literal, Union
from typing import Any, Literal, Union, cast, overload
import warnings

import ml_dtypes
import numpy as np

from jax._src import config
from jax._src import traceback_util
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC
from jax._src.util import StrictABC, set_module
import ml_dtypes
import numpy as np

from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

try:
Expand Down Expand Up @@ -445,6 +445,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('int32'),
np.dtype('int64'),
]
_string_types = [np.dtype('str')]

if _int2_dtype is not None:
_signed_types.insert(0, _int2_dtype)
Expand All @@ -463,18 +464,37 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]

_string_types = [
np.dtype('object'),
]


_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

_jax_dtype_set = {
float0,
*_bool_types,
*_int_types,
*_float_types,
*_complex_types,
# *_string_types,
}

_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
_dtype_kinds: dict[str, set[JAXType]] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {
*_signed_types,
*_unsigned_types,
*_float_types,
*_complex_types,
},
'string': {*_string_types},
}


Expand Down Expand Up @@ -736,8 +756,10 @@ def is_python_scalar(x: Any) -> bool:

def check_valid_dtype(dtype: DType) -> None:
if dtype not in _jax_dtype_set:
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
raise TypeError(
f'Dtype {dtype} is not a valid JAX array '
'type. Only arrays of numeric types and strings are supported by JAX.'
)

def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if dtype != np.dtype('object'):
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


Expand Down
12 changes: 12 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5582,6 +5582,18 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# 2DO: Add a comment.
if isinstance(object, np.ndarray) and (object.dtype == np.dtype(np.object_)):
if (dtype is not None) and (dtype != object.dtype):
raise TypeError(
f"Cannot convert an array with dtype=object to dtype {dtype}"
)
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down

0 comments on commit 178ae8f

Please sign in to comment.