diff --git a/src/awkward/_v2/_typetracer.py b/src/awkward/_v2/_typetracer.py index a25d969a4a..7cf34c49ba 100644 --- a/src/awkward/_v2/_typetracer.py +++ b/src/awkward/_v2/_typetracer.py @@ -203,8 +203,22 @@ class TypeTracerArray: def from_array(cls, array, dtype=None): if isinstance(array, ak._v2.index.Index): array = array.data + + # not array-like + if not hasattr(array, "shape"): + sequence = list(array) + array = numpy.array(sequence) + if array.dtype == np.dtype("O"): + raise ak._v2._util.error( + ValueError( + f"bug in Awkward Array: attempt to construct `TypeTracerArray` " + f"from a sequence of non-primitive types: {sequence}" + ) + ) + if dtype is None: dtype = array.dtype + return cls(dtype, shape=array.shape) def __init__(self, dtype, shape=None): @@ -432,9 +446,6 @@ def copy(self): return self -unset = object() - - class TypeTracer(ak.nplike.NumpyLike): known_data = False known_shape = False @@ -483,21 +494,21 @@ def raw(self, array, nplike): ############################ array creation - def array(self, data, dtype=unset, **kwargs): + def array(self, data, dtype=None, **kwargs): # data[, dtype=[, copy=]] - if dtype is unset: + if dtype is None: dtype = data.dtype return TypeTracerArray.from_array(data, dtype=dtype) - def asarray(self, array, dtype=unset, **kwargs): + def asarray(self, array, dtype=None, **kwargs): # array[, dtype=][, order=] - if dtype is unset: + if dtype is None: dtype = array.dtype return TypeTracerArray.from_array(array, dtype=dtype) - def ascontiguousarray(self, array, dtype=unset, **kwargs): + def ascontiguousarray(self, array, dtype=None, **kwargs): # array[, dtype=] - if dtype is unset: + if dtype is None: dtype = array.dtype return TypeTracerArray.from_array(array, dtype=dtype) @@ -520,23 +531,26 @@ def empty(self, shape, dtype=np.float64, **kwargs): # shape/len[, dtype=] return TypeTracerArray(dtype, shape) - def full(self, shape, value, dtype=unset, **kwargs): + def full(self, shape, value, dtype=None, **kwargs): # shape/len, value[, dtype=] - if dtype is unset: + if dtype is None: dtype = numpy.array(value).dtype return TypeTracerArray(dtype, shape) - def zeros_like(self, *args, **kwargs): - # array - raise ak._v2._util.error(NotImplementedError) + def zeros_like(self, a, dtype=None, **kwargs): + if dtype is None: + dtype = a.dtype - def ones_like(self, *args, **kwargs): - # array - raise ak._v2._util.error(NotImplementedError) + if isinstance(a, UnknownScalar): + return UnknownScalar(dtype) - def full_like(self, *args, **kwargs): - # array, fill_value - raise ak._v2._util.error(NotImplementedError) + return TypeTracerArray(dtype, a.shape) + + def ones_like(self, a, dtype=None, **kwargs): + return self.zeros_like(a, dtype) + + def full_like(self, a, fill_value, dtype=None, **kwargs): + return self.zeros_like(a, dtype) def arange(self, *args, **kwargs): # stop[, dtype=] @@ -854,7 +868,7 @@ def argmax(self, *args, **kwargs): raise ak._v2._util.error(NotImplementedError) def array_str( - self, array, max_line_width=unset, precision=unset, suppress_small=unset + self, array, max_line_width=None, precision=None, suppress_small=None ): # array, max_line_width, precision=None, suppress_small=None return "[?? ... ??]" diff --git a/tests/v2/test_1504-typetracer-like.py b/tests/v2/test_1504-typetracer-like.py new file mode 100644 index 0000000000..403f0e4f44 --- /dev/null +++ b/tests/v2/test_1504-typetracer-like.py @@ -0,0 +1,48 @@ +import awkward as ak +import numpy as np +import pytest + + +typetracer = ak._v2._typetracer.TypeTracer.instance() + + +@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None]) +@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None]) +def test_ones_like(dtype, like_dtype): + array = ak._v2.contents.numpyarray.NumpyArray( + np.array([99, 88, 77, 66, 66], dtype=dtype) + ) + ones = ak._v2.ones_like(array.typetracer, dtype=like_dtype, highlevel=False) + assert ones.typetracer.shape == array.shape + assert ones.typetracer.dtype == like_dtype or array.dtype + + +@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None]) +@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None]) +def test_zeros_like(dtype, like_dtype): + array = ak._v2.contents.numpyarray.NumpyArray( + np.array([99, 88, 77, 66, 66], dtype=dtype) + ) + + full = ak._v2.zeros_like(array.typetracer, dtype=like_dtype, highlevel=False) + assert full.typetracer.shape == array.shape + assert full.typetracer.dtype == like_dtype or array.dtype + + +@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None]) +@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None]) +@pytest.mark.parametrize("value", [1.0, -20, np.iinfo(np.uint64).max]) +def test_full_like(dtype, like_dtype, value): + array = ak._v2.contents.numpyarray.NumpyArray( + np.array([99, 88, 77, 66, 66], dtype=dtype) + ) + full = ak._v2.full_like(array.typetracer, value, dtype=like_dtype, highlevel=False) + assert full.typetracer.shape == array.shape + assert full.typetracer.dtype == like_dtype or array.dtype + + +def test_full_like_cast(): + with pytest.raises(ValueError): + ak._v2._typetracer.TypeTracerArray.from_array( + [1, ak._v2._typetracer.UnknownScalar(np.uint8)] + )