diff --git a/src/dask_awkward/lib/unproject_layout.py b/src/dask_awkward/lib/unproject_layout.py index 0e1db0ad..493e3150 100644 --- a/src/dask_awkward/lib/unproject_layout.py +++ b/src/dask_awkward/lib/unproject_layout.py @@ -316,8 +316,15 @@ def _unproject_layout(form, layout, length, backend): # UnmaskedArray, non-UnmaskedArray form elif isinstance(layout, UnmaskedArray) and form.is_option: if isinstance(form, BitMaskedForm): + byte_length = ( + unknown_length if length is unknown_length else math.ceil(length / 8.0) + ) return BitMaskedArray( - ak.index.IndexU8.zeros(length, backend.index_nplike), + ak.index.Index( + backend.index_nplike.full(byte_length, 255, dtype=np.uint8) + if form.valid_when + else backend.index_nplike.zeros(byte_length, dtype=np.uint8) + ), _unproject_layout( form.content, layout.content, layout.content.length, backend ), @@ -328,7 +335,11 @@ def _unproject_layout(form, layout, length, backend): ) elif isinstance(form, ByteMaskedForm): return ByteMaskedArray( - ak.index.Index8.zeros(length, backend.index_nplike), + ak.index.Index( + backend.index_nplike.full(length, 1, dtype=np.int8) + if form.valid_when + else backend.index_nplike.zeros(length, dtype=np.int8) + ), _unproject_layout( form.content, layout.content, layout.content.length, backend ), diff --git a/tests/test_unproject_layout.py b/tests/test_unproject_layout.py index d5641725..552a5e1a 100644 --- a/tests/test_unproject_layout.py +++ b/tests/test_unproject_layout.py @@ -87,3 +87,24 @@ def test_UnionArray(): projected = ak.from_iter([{"x": 1}, {"x": 2}, {"x": 3}], highlevel=False) unprojected = unproject_layout(form, projected) compare_values(projected, unprojected) + + +def test_BitMaskedArray(): + form = ak.forms.BitMaskedForm('u8', ak.forms.NumpyForm('int64'), valid_when=False, lsb_order=True) + projected = ak.contents.UnmaskedArray(ak.from_iter([1, 2, 3, 4], highlevel=False)) + unprojected = unproject_layout(form, projected) + compare_values(projected, unprojected) + + form = ak.forms.BitMaskedForm('u8', ak.forms.NumpyForm('int64'), valid_when=True, lsb_order=True) + unprojected = unproject_layout(form, projected) + compare_values(projected, unprojected) + +def test_ByteMaskedArray(): + form = ak.forms.ByteMaskedForm('u8', ak.forms.NumpyForm('int64'), valid_when=False) + projected = ak.contents.UnmaskedArray(ak.from_iter([1, 2, 3, 4], highlevel=False)) + unprojected = unproject_layout(form, projected) + compare_values(projected, unprojected) + + form = ak.forms.ByteMaskedForm('u8', ak.forms.NumpyForm('int64'), valid_when=True) + unprojected = unproject_layout(form, projected) + compare_values(projected, unprojected)