Skip to content

Commit

Permalink
feat: add support for dictionary encoding from Arrow (#2630)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Aug 9, 2023
1 parent 6c52131 commit e2eb69a
Show file tree
Hide file tree
Showing 13 changed files with 418 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_from_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def from_categorical(array, *, highlevel=True, behavior=None):
from categorical is cheap.)
See also #ak.is_categorical, #ak.categories, #ak.to_categorical,
#ak.from_categorical.
#ak.str.to_categorical, #ak.from_categorical.
"""
# Dispatch
yield (array,)
Expand Down
6 changes: 6 additions & 0 deletions src/awkward/operations/ak_to_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from awkward._behavior import behavior_of
from awkward._categorical import as_hashable
from awkward._dispatch import high_level_function
from awkward._errors import deprecate
from awkward._layout import wrap_layout
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
Expand Down Expand Up @@ -92,6 +93,11 @@ def to_categorical(array, *, highlevel=True, behavior=None):


def _impl(array, highlevel, behavior):
deprecate(
"The general purpose `ak.to_categorical` has been replaced by `ak.str.to_categorical`",
"2.5.0",
)

def action(layout, **kwargs):
if layout.purelist_depth == 1:
if layout.is_indexed and layout.is_option:
Expand Down
20 changes: 15 additions & 5 deletions src/awkward/operations/str/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

# https://arrow.apache.org/docs/python/api/compute.html#string-predicates
# This set of string-manipulation routines is strongly inspired by Arrow:


# string predicates
# https://arrow.apache.org/docs/python/api/compute.html#string-predicates
from awkward.operations.str.akstr_is_alnum import *
from awkward.operations.str.akstr_is_alpha import *
from awkward.operations.str.akstr_is_decimal import *
Expand All @@ -16,6 +18,7 @@
from awkward.operations.str.akstr_is_ascii import *

# string transforms
# https://arrow.apache.org/docs/python/api/compute.html#string-transforms
from awkward.operations.str.akstr_capitalize import *
from awkward.operations.str.akstr_length import *
from awkward.operations.str.akstr_lower import *
Expand All @@ -29,11 +32,13 @@
from awkward.operations.str.akstr_replace_substring_regex import *

# string padding
# https://arrow.apache.org/docs/python/api/compute.html#string-padding
from awkward.operations.str.akstr_center import *
from awkward.operations.str.akstr_lpad import *
from awkward.operations.str.akstr_rpad import *

# string trimming
# https://arrow.apache.org/docs/python/api/compute.html#string-trimming
from awkward.operations.str.akstr_ltrim import *
from awkward.operations.str.akstr_ltrim_whitespace import *
from awkward.operations.str.akstr_rtrim import *
Expand All @@ -42,25 +47,26 @@
from awkward.operations.str.akstr_trim_whitespace import *

# string splitting
# https://arrow.apache.org/docs/python/api/compute.html#string-splitting
from awkward.operations.str.akstr_split_whitespace import *
from awkward.operations.str.akstr_split_pattern import *
from awkward.operations.str.akstr_split_pattern_regex import *

# string component extraction

# https://arrow.apache.org/docs/python/api/compute.html#string-component-extraction
from awkward.operations.str.akstr_extract_regex import *

# string joining

# https://arrow.apache.org/docs/python/api/compute.html#string-joining
from awkward.operations.str.akstr_join import *
from awkward.operations.str.akstr_join_element_wise import *

# string slicing

# https://arrow.apache.org/docs/python/api/compute.html#string-slicing
from awkward.operations.str.akstr_slice import *

# containment tests

# https://arrow.apache.org/docs/python/api/compute.html#containment-tests
from awkward.operations.str.akstr_count_substring import *
from awkward.operations.str.akstr_count_substring_regex import *
from awkward.operations.str.akstr_ends_with import *
Expand All @@ -73,6 +79,10 @@
from awkward.operations.str.akstr_match_substring_regex import *
from awkward.operations.str.akstr_starts_with import *

# dictionary-encoding (exclusively for strings)
# https://arrow.apache.org/docs/python/api/compute.html#associative-transforms
from awkward.operations.str.akstr_to_categorical import *


def _get_ufunc_action(
utf8_function,
Expand Down
79 changes: 79 additions & 0 deletions src/awkward/operations/str/akstr_to_categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

__all__ = ("to_categorical",)

import awkward as ak
from awkward._behavior import behavior_of
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout


@high_level_function(module="ak.str")
def to_categorical(array, *, highlevel=True, behavior=None):
    """
    Args:
        array: Array-like data (anything #ak.to_layout recognizes).
        highlevel (bool): If True, return an #ak.Array; otherwise, return
            a low-level #ak.contents.Content subclass.
        behavior (None or dict): Custom #ak.behavior for the output array, if
            high-level.

    Returns a dictionary-encoded version of the given array of strings.
    Creates a categorical dataset, which has the following properties:

    * only distinct values (categories) are stored in their entirety,
    * pointers to those distinct values are represented by integers
      (an #ak.contents.IndexedArray or #ak.contents.IndexedOptionArray
      labeled with parameter `"__array__" = "categorical"`).

    This is equivalent to R's "factor", and Pandas's "categorical".
    It differs from generic uses of #ak.contents.IndexedArray and
    #ak.contents.IndexedOptionArray in Awkward Arrays by the guarantee of no
    duplicate categories and the `"categorical"` parameter.

    Unlike Arrow's `dictionary_encode`, this function has no `null_handling`
    argument. This function's behavior is like `null_handling="mask"`
    (Arrow's default). It is not possible to encode null values in Awkward
    Array, as #ak.contents.IndexedOptionArray cannot contain an option type
    node.

    Note: this function does not raise an error if the `array` does not
    contain any string or bytestring data.

    Requires the pyarrow library and calls
    [pyarrow.compute.dictionary_encode](https://arrow.apache.org/docs/python/generated/pyarrow.compute.dictionary_encode.html).
    """
    # Dispatch
    yield (array,)

    # Implementation
    return _impl(array, highlevel, behavior)


def _impl(array, highlevel, behavior):
    """
    Dictionary-encode each string/bytestring list node of `array` with pyarrow.

    Args:
        array: Data accepted by `ak.operations.to_layout`.
        highlevel (bool): Forwarded to `wrap_layout` to select the output kind.
        behavior (None or dict): Merged with the behavior of `array` through
            `behavior_of` and used for the traversal and the final wrapping.

    The layout is walked with `ak._do.recursively_apply`. Each list node whose
    `"__array__"` parameter is `"string"` or `"bytestring"` is converted to
    Arrow, passed through `pyarrow.compute.dictionary_encode`, and converted
    back to an Awkward layout. Every other node makes the callback return None.
    """
    from awkward._connect.pyarrow import import_pyarrow_compute

    pc = import_pyarrow_compute("ak.str.to_categorical")
    behavior = behavior_of(array, behavior=behavior)

    string_kinds = {"string", "bytestring"}

    def action(layout, **kwargs):
        # Guard clauses: only string-like list nodes are encoded; anything
        # else yields None, exactly as an unmatched `if` would.
        if not layout.is_list:
            return None
        if layout.parameter("__array__") not in string_kinds:
            return None

        arrow_strings = ak.to_arrow(layout, extensionarray=False)
        encoded = ak.from_arrow(
            pc.dictionary_encode(arrow_strings),
            highlevel=False,
        )
        # The round trip through Arrow always yields an option-type index,
        # although this node holds no nulls. It is therefore safe to rebuild
        # it as a plain (non-option) indexed node.
        assert isinstance(encoded, ak.contents.IndexedOptionArray)
        return ak.contents.IndexedArray(
            encoded.index,
            encoded.content,
            parameters=encoded._parameters,
        )

    root = ak.operations.to_layout(array)
    out = ak._do.recursively_apply(root, action, behavior)

    return wrap_layout(out, behavior, highlevel)
104 changes: 69 additions & 35 deletions tests/test_0401_add_categorical_type_for_arrow_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ def test_option_two_extra():
def test_to_categorical_numbers():
array = ak.Array([1.1, 2.2, 3.3, 1.1, 2.2, 3.3, 1.1, 2.2, 3.3])
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
assert to_list(categorical.layout.content) == [1.1, 2.2, 3.3]
Expand All @@ -225,7 +228,10 @@ def test_to_categorical_numbers():
def test_to_categorical_nested():
array = ak.Array([["one", "two", "three"], [], ["one", "two"], ["three"]])
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
not_categorical = ak.operations.ak_from_categorical.from_categorical(categorical)
Expand All @@ -242,7 +248,10 @@ def test_to_categorical():
["one", "two", "three", "one", "two", "three", "one", "two", "three"]
)
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
assert to_list(categorical.layout.content) == ["one", "two", "three"]
Expand Down Expand Up @@ -273,7 +282,10 @@ def test_to_categorical_none():
]
)
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
assert to_list(categorical.layout.content) == ["one", "two", "three"]
Expand Down Expand Up @@ -323,7 +335,10 @@ def test_to_categorical_masked():
)
array = ak.Array(ak.contents.ByteMaskedArray(mask, content, valid_when=False))
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
assert to_list(categorical.layout.content) == ["one", "two", "three"]
Expand Down Expand Up @@ -366,7 +381,10 @@ def test_to_categorical_masked_again():
ak.contents.ByteMaskedArray.simplified(mask, indexedarray, valid_when=False)
)
assert not ak.operations.ak_is_categorical.is_categorical(array)
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)
assert ak.operations.ak_is_categorical.is_categorical(categorical)
assert to_list(array) == categorical.to_list()
assert to_list(categorical.layout.content) == ["one", "two", "three"]
Expand All @@ -381,46 +399,59 @@ def test_to_categorical_masked_again():

@pytest.mark.skip(reason="Fix issues for categorical type")
def test_typestr():
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array([1.1, 2.2, 2.2, 3.3])
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array([1.1, 2.2, 2.2, 3.3])
)
)
)
== "4 * categorical[type=float64]"
)
== "4 * categorical[type=float64]"
)
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array([1.1, 2.2, None, 2.2, 3.3])

with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array([1.1, 2.2, None, 2.2, 3.3])
)
)
)
== "5 * categorical[type=?float64]"
)
== "5 * categorical[type=?float64]"
)
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array(["one", "two", "two", "three"])
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array(["one", "two", "two", "three"])
)
)
)
== "4 * categorical[type=string]"
)
== "4 * categorical[type=string]"
)
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array(["one", "two", None, "two", "three"])
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
assert (
str(
ak.operations.type(
ak.operations.ak_to_categorical.to_categorical(
ak.Array(["one", "two", None, "two", "three"])
)
)
)
== "5 * categorical[type=?string]"
)
== "5 * categorical[type=?string]"
)


def test_zip():
Expand All @@ -431,7 +462,10 @@ def test_zip():
{"x": 2.2, "y": "two"},
{"x": 3.3, "y": "three"},
]
y = ak.operations.ak_to_categorical.to_categorical(y)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
y = ak.operations.ak_to_categorical.to_categorical(y)
assert ak.zip({"x": x, "y": y}).to_list() == [
{"x": 1.1, "y": "one"},
{"x": 2.2, "y": "two"},
Expand Down
7 changes: 5 additions & 2 deletions tests/test_0404_array_validity_check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/master/LICENSE

import numpy as np
import pytest # noqa: F401
import pytest

import awkward as ak

Expand Down Expand Up @@ -213,7 +213,10 @@ def test_subranges_equal():

def test_categorical():
array = ak.highlevel.Array(["1chchc", "1chchc", "2sss", "3", "4", "5"])
categorical = ak.operations.ak_to_categorical.to_categorical(array)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(array)

assert ak.operations.is_valid(categorical) is True
assert ak._do.is_unique(categorical.layout) is False
Expand Down
5 changes: 4 additions & 1 deletion tests/test_0674_categorical_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
def test_categorical_is_valid():
# validate a categorical array by its content
arr = ak.Array([2019, 2020, 2021, 2020, 2019])
categorical = ak.operations.ak_to_categorical.to_categorical(arr)
with pytest.warns(
DeprecationWarning, match=r"has been replaced by.*ak\.str\.to_categorical"
):
categorical = ak.operations.ak_to_categorical.to_categorical(arr)
assert ak.operations.is_valid(categorical)


Expand Down
Loading

0 comments on commit e2eb69a

Please sign in to comment.