Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Chart.transform_filter(*predicates, **constraints) #3664

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5bdb33d
chore(typing): Add temporary alias for `filter`
dangotbanned Oct 30, 2024
2efd5f8
refactor: Make `empty` a regular keyword arg
dangotbanned Oct 30, 2024
d2868f8
refactor: Remove `**kwargs`
dangotbanned Oct 30, 2024
841e887
docs: Add note on `Predicate`
dangotbanned Oct 30, 2024
63de259
feat(DRAFT): Adds `_transform_filter_impl`
dangotbanned Oct 30, 2024
40146f7
Merge remote-tracking branch 'upstream/main' into transform-filter-pr…
dangotbanned Oct 30, 2024
c22ba58
feat: Adds `transform_filter` implementation
dangotbanned Oct 30, 2024
fc672bc
fix(DRAFT): Add temp ignore for `line_chart_with_cumsum_faceted`
dangotbanned Oct 30, 2024
8554a46
feat(typing): Widen `_FieldEqualType` to include `IntoExpression`
dangotbanned Oct 30, 2024
ff9d33f
fix: Try replacing `Undefined` first
dangotbanned Oct 31, 2024
b497039
test: Add `(*predicates, **constraints)` syntax tests
dangotbanned Oct 31, 2024
54d0cbc
Merge branch 'main' into transform-filter-predicates
dangotbanned Nov 2, 2024
a375ab5
Merge remote-tracking branch 'upstream/main' into transform-filter-pr…
dangotbanned Nov 3, 2024
7f6c188
docs: Use `*predicates` in "Faceted Line Chart with Cumulative Sum"
dangotbanned Nov 3, 2024
0d4ff86
Merge remote-tracking branch 'upstream/main' into transform-filter-pr…
dangotbanned Nov 4, 2024
be63d4e
refactor: Remove `_OrigFilterType`
dangotbanned Nov 4, 2024
7a0cc42
docs: Update `.transform_filter()` docstring
dangotbanned Nov 4, 2024
08a4207
docs: Minor corrections in examples
dangotbanned Nov 4, 2024
5fd4f5a
refactor(typing): Cast deprecated `filter` in one location
dangotbanned Nov 4, 2024
d640933
refactor: Remove `pred` assignment
dangotbanned Nov 4, 2024
2d57d6e
docs(typing): Update `_FieldEqualType`
dangotbanned Nov 4, 2024
01571dd
Merge branch 'main' into transform-filter-predicates
dangotbanned Nov 4, 2024
10f1d1d
Merge branch 'main' into transform-filter-predicates
dangotbanned Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 110 additions & 36 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@
NamedData,
ParameterName,
PointSelectionConfig,
Predicate,
PredicateComposition,
ProjectionType,
RepeatMapping,
Expand Down Expand Up @@ -543,12 +542,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
"""


_FieldEqualType: TypeAlias = Union[PrimitiveValue_T, Map, Parameter, SchemaBase]
"""Permitted types for equality checks on field values:
_FieldEqualType: TypeAlias = Union["IntoExpression", Parameter, SchemaBase]
"""
Permitted types for equality checks on field values.

Applies to the following context(s):

import altair as alt

- `datum.field == ...`
- `FieldEqualPredicate(equal=...)`
- `when(**constraints=...)`
alt.datum.field == ...
alt.FieldEqualPredicate(field="field", equal=...)
alt.when(field=...)
alt.when().then().when(field=...)
alt.Chart.transform_filter(field=...)
"""


Expand Down Expand Up @@ -2988,45 +2994,113 @@ def transform_extent(
"""
return self._add_transform(core.ExtentTransform(extent=extent, param=param))

# TODO: Update docstring
# # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
def transform_filter(
self,
filter: str
| Expr
| Expression
| Predicate
| Parameter
| PredicateComposition
| dict[str, Predicate | str | list | bool],
**kwargs: Any,
predicate: Optional[_PredicateType] = Undefined,
*more_predicates: _ComposablePredicateType,
empty: Optional[bool] = Undefined,
**constraints: _FieldEqualType,
) -> Self:
"""
Add a :class:`FilterTransform` to the schema.
Add a :class:`FilterTransform` to the spec.

The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments.

Parameters
----------
filter : a filter expression or :class:`PredicateComposition`
The `filter` property must be one of the predicate definitions:
(1) a string or alt.expr expression
(2) a range predicate
(3) a selection predicate
(4) a logical operand combining (1)-(3)
(5) a Selection object
predicate
A selection or test predicate. ``str`` input will be treated as a test operand.
*more_predicates
Additional predicates, restricted to types supporting ``&``.
empty
For selection parameters, the predicate of empty selections returns ``True`` by default.
Override this behavior, with ``empty=False``.

Returns
-------
self : Chart object
returns chart to allow for chaining
.. note::
When ``predicate`` is a ``Parameter`` that is used more than once,
``self.transform_filter(..., empty=...)`` provides granular control for each occurrence.
**constraints
Specify `Field Equal Predicate`_'s.
Shortcut for ``alt.datum.field_name == value``, see examples for usage.

Warns
-----
AltairDeprecationWarning
If called using ``filter`` as a keyword argument.

See Also
--------
alt.when : Uses a similar syntax for defining conditional values.

Notes
-----
- Directly inspired by the syntax used in `polars.DataFrame.filter`_.

.. _Field Equal Predicate:
https://vega.github.io/vega-lite/docs/predicate.html#equal-predicate
.. _polars.DataFrame.filter:
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.filter.html

Examples
--------
Setting up a common chart::

import altair as alt
from altair import datum
from vega_datasets import data

source = data.population.url
chart = (
alt.Chart(source)
.mark_line()
.encode(
x="age:O",
y="sum(people):Q",
color=alt.Color("year:O").legend(symbolType="square"),
)
)
chart

Singular predicates can be expressed via ``datum``::

chart.transform_filter(datum.year <= 1980)

We can also use selection parameters directly::

selection = alt.selection_point(encodings=["color"], bind="legend")
chart.transform_filter(selection).add_params(selection)

Or a field predicate::

between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960])
chart.transform_filter(between_1950_60) | chart.transform_filter(~between_1950_60)

Predicates can be composed together using logical operands::

chart.transform_filter(between_1950_60 | (datum.year == 1850))

Predicates passed as positional arguments will be reduced with ``&``::

chart.transform_filter(datum.year > 1980, datum.age != 90)

Using keyword-argument ``constraints`` can simplify compositions like::

verbose_composition = chart.transform_filter((datum.year == 2000) & (datum.sex == 1))
chart.transform_filter(year=2000, sex=1)
"""
if isinstance(filter, Parameter):
new_filter: dict[str, Any] = {"param": filter.name}
if "empty" in kwargs:
new_filter["empty"] = kwargs.pop("empty")
elif isinstance(filter.empty, bool):
new_filter["empty"] = filter.empty
filter = new_filter
return self._add_transform(core.FilterTransform(filter=filter, **kwargs))
if depr_filter := t.cast(Any, constraints.pop("filter", None)):
utils.deprecated_warn(
"Passing `filter` as a keyword is ambiguous.\n\n"
"Use a positional argument for `<5.5.0` behavior.\n"
"Or, `alt.datum['filter'] == ...` if referring to a column named 'filter'.",
version="5.5.0",
)
if utils.is_undefined(predicate):
predicate = depr_filter
else:
more_predicates = *more_predicates, depr_filter
cond = _parse_when(predicate, *more_predicates, empty=empty, **constraints)
return self._add_transform(core.FilterTransform(filter=cond.get("test", cond)))

def transform_flatten(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']

alt.Chart(source).transform_filter(
{'and': [
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
]}
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
).transform_window(
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
).mark_line().encode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']

alt.Chart(source).transform_filter(
{'and': [
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
]}
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
).transform_window(
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
).mark_line().encode(
Expand Down
61 changes: 60 additions & 1 deletion tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import sys
import tempfile
import warnings
from collections.abc import Mapping
from datetime import date, datetime
from importlib.metadata import version as importlib_version
Expand Down Expand Up @@ -85,7 +86,7 @@ def _make_chart_type(chart_type):


@pytest.fixture
def basic_chart():
def basic_chart() -> alt.Chart:
data = pd.DataFrame(
{
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
Expand Down Expand Up @@ -1247,6 +1248,64 @@ def test_predicate_composition() -> None:
assert actual_multi == expected_multi


def test_filter_transform_predicates(basic_chart) -> None:
lhs, rhs = alt.datum["b"] >= 30, alt.datum["b"] < 60
expected = [{"filter": lhs & rhs}]
actual = basic_chart.transform_filter(lhs, rhs).to_dict()["transform"]
assert actual == expected


def test_filter_transform_constraints(basic_chart) -> None:
lhs, rhs = alt.datum["a"] == "A", alt.datum["b"] == 30
expected = [{"filter": lhs & rhs}]
actual = basic_chart.transform_filter(a="A", b=30).to_dict()["transform"]
assert actual == expected


def test_filter_transform_predicates_constraints(basic_chart) -> None:
from functools import reduce
from operator import and_

predicates = (
alt.datum["a"] != "A",
alt.datum["a"] != "B",
alt.datum["a"] != "C",
alt.datum["b"] > 1,
alt.datum["b"] < 99,
)
constraints = {"b": 30, "a": "D"}
pred_constraints = *predicates, alt.datum["b"] == 30, alt.datum["a"] != "D"
expected = [{"filter": reduce(and_, pred_constraints)}]
actual = basic_chart.transform_filter(*predicates, **constraints).to_dict()[
"transform"
]
assert actual == expected


def test_filter_transform_errors(basic_chart) -> None:
NO_ARGS = r"At least one.+Undefined"
FILTER_KWARGS = r"ambiguous"

depr_filter = {"field": "year", "oneOf": [1955, 2000]}
expected = [{"filter": depr_filter}]

with pytest.raises(TypeError, match=NO_ARGS):
basic_chart.transform_filter()
with pytest.raises(TypeError, match=NO_ARGS):
basic_chart.transform_filter(empty=True)
with pytest.raises(TypeError, match=NO_ARGS):
basic_chart.transform_filter(empty=False)

with pytest.warns(alt.AltairDeprecationWarning, match=FILTER_KWARGS):
basic_chart.transform_filter(filter=depr_filter)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=alt.AltairDeprecationWarning)
actual = basic_chart.transform_filter(filter=depr_filter).to_dict()["transform"]

assert actual == expected


def test_resolve_methods():
chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
assert chart.resolve == alt.Resolve(
Expand Down