Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support &, |, ~ on all ...Predicate classes #3668

Merged
merged 10 commits into from
Nov 3, 2024
26 changes: 9 additions & 17 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,25 +410,25 @@ def to_dict(self) -> dict[str, str | dict[str, Any]]:
msg = f"Unrecognized parameter type: {self.param_type}"
raise ValueError(msg)

def __invert__(self) -> SelectionPredicateComposition | Any:
def __invert__(self) -> PredicateComposition | Any:
if self.param_type == "selection":
return SelectionPredicateComposition({"not": {"param": self.name}})
return core.PredicateComposition({"not": {"param": self.name}})
else:
return _expr_core.OperatorMixin.__invert__(self)

def __and__(self, other: Any) -> SelectionPredicateComposition | Any:
def __and__(self, other: Any) -> PredicateComposition | Any:
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"and": [{"param": self.name}, other]})
return core.PredicateComposition({"and": [{"param": self.name}, other]})
else:
return _expr_core.OperatorMixin.__and__(self, other)

def __or__(self, other: Any) -> SelectionPredicateComposition | Any:
def __or__(self, other: Any) -> PredicateComposition | Any:
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"or": [{"param": self.name}, other]})
return core.PredicateComposition({"or": [{"param": self.name}, other]})
else:
return _expr_core.OperatorMixin.__or__(self, other)

Expand Down Expand Up @@ -458,15 +458,7 @@ def __getitem__(self, field_name: str) -> GetItemExpression:


# Enables use of ~, &, | with compositions of selection objects.
class SelectionPredicateComposition(core.PredicateComposition):
def __invert__(self) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"not": self.to_dict()})

def __and__(self, other: SchemaBase) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"and": [self.to_dict(), other.to_dict()]})

def __or__(self, other: SchemaBase) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"or": [self.to_dict(), other.to_dict()]})
SelectionPredicateComposition = core.PredicateComposition


class ParameterExpression(_expr_core.OperatorMixin):
Expand Down Expand Up @@ -532,7 +524,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
"""Permitted types for `predicate`."""

_ComposablePredicateType: TypeAlias = Union[
_expr_core.OperatorMixin, SelectionPredicateComposition
_expr_core.OperatorMixin, core.PredicateComposition
]
"""Permitted types for `&` reduced predicates."""

Expand Down Expand Up @@ -764,7 +756,7 @@ def _validate_composables(
predicates: Iterable[Any], /
) -> Iterator[_ComposablePredicateType]:
for p in predicates:
if isinstance(p, (_expr_core.OperatorMixin, SelectionPredicateComposition)):
if isinstance(p, (_expr_core.OperatorMixin, core.PredicateComposition)):
yield p
else:
msg = (
Expand Down
9 changes: 9 additions & 0 deletions altair/vegalite/v5/schema/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16153,6 +16153,15 @@ class PredicateComposition(VegaLiteSchema):
def __init__(self, *args, **kwds):
super().__init__(*args, **kwds)

def __invert__(self) -> PredicateComposition:
    """Return a new composition wrapping this predicate's dict under ``"not"`` (``~``)."""
    negated = {"not": self.to_dict()}
    return PredicateComposition(negated)

def __and__(self, other: SchemaBase) -> PredicateComposition:
    """Return a new composition listing both operands' dicts under ``"and"`` (``&``)."""
    operands = [self.to_dict(), other.to_dict()]
    return PredicateComposition({"and": operands})

def __or__(self, other: SchemaBase) -> PredicateComposition:
    """Return a new composition listing both operands' dicts under ``"or"`` (``|``)."""
    operands = [self.to_dict(), other.to_dict()]
    return PredicateComposition({"or": operands})


class LogicalAndPredicate(PredicateComposition):
"""
Expand Down
46 changes: 28 additions & 18 deletions doc/user_guide/transform/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,45 @@ to select the data to be shown in the top chart:
Logical Operands
^^^^^^^^^^^^^^^^
At times it is useful to combine several types of predicates into a single
selection. This can be accomplished using the various logical operand classes:
selection. We can use ``&``, ``|`` and ``~`` for ``AND``, ``OR`` and ``NOT``
logical composition, respectively.

- :class:`~LogicalOrPredicate`
- :class:`~LogicalAndPredicate`
- :class:`~LogicalNotPredicate`
For example, here we wish to plot US population distributions for all data *except* the years *1950-1960*.

These are not yet part of the Altair interface
(see `Issue 695 <https://github.com/vega/altair/issues/695>`_)
but can be constructed explicitly; for example, here we plot US population
distributions for all data *except* the years 1950-1960,
by applying a ``LogicalNotPredicate`` schema to a ``FieldRangePredicate``:
First, we use a :class:`~FieldRangePredicate` to select *1950-1960*:

.. altair-plot::

:output: none

import altair as alt
from vega_datasets import data

pop = data.population.url

alt.Chart(pop).mark_line().encode(
x='age:O',
y='sum(people):Q',
color='year:O'
source = data.population.url
chart = alt.Chart(source).mark_line().encode(
x="age:O",
y="sum(people):Q",
color="year:O"
).properties(
width=600, height=200
).transform_filter(
{'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
)

between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960])

Then, we can *invert* this selection using ``~``:

.. altair-plot::

# NOT between 1950-1960
chart.transform_filter(~between_1950_60)

We can further refine our filter by *composing* multiple predicates together.
In this case, we add a second condition built with ``alt.datum``:

.. altair-plot::

chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70))
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved


Transform Options
^^^^^^^^^^^^^^^^^
The :meth:`~Chart.transform_filter` method is built on the :class:`~FilterTransform`
Expand Down
55 changes: 55 additions & 0 deletions tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import sys
import tempfile
from collections.abc import Mapping
from datetime import date, datetime
from importlib.metadata import version as importlib_version
from importlib.util import find_spec
Expand Down Expand Up @@ -1192,6 +1193,60 @@ def test_filter_transform_selection_predicates():
]


def test_predicate_composition() -> None:
    """Exercise ``&``, ``|`` and ``~`` on schema predicate classes."""
    entities = ["Drought", "Epidemic", "Earthquake", "Flood"]
    one_of = alt.FieldOneOfPredicate(field="Entity", oneOf=entities)
    year_range = alt.FieldRangePredicate(field="Year", range=[1900, 2000])

    # `&` wraps both operands in an {"and": [...]} composition.
    both = one_of & year_range
    want_and = {
        "and": [
            {"field": "Entity", "oneOf": entities},
            {"field": "Year", "range": [1900, 2000]},
        ]
    }
    assert isinstance(both, alt.PredicateComposition)
    got_and = both.to_dict()

    # NOTE: Extra guarantee that something hasn't overloaded `__eq__` or `to_dict`
    assert isinstance(got_and, Mapping)
    assert isinstance(got_and == want_and, bool)

    assert got_and == want_and

    # Passing both predicates to `alt.when` is compared against the `&` result.
    got_when = (
        alt.when(one_of, year_range).then(alt.value(0)).otherwise(alt.value(1))
    )
    assert got_when == {"condition": [{"test": both, "value": 0}], "value": 1}

    # `~` wraps the operand in a {"not": ...} composition.
    decade = alt.FieldRangePredicate(field="year", range=[1950, 1960])
    not_decade = ~decade
    want_not = {"not": {"field": "year", "range": [1950, 1960]}}
    assert isinstance(not_decade, alt.PredicateComposition)
    assert not_decade.to_dict() == want_not

    # `|` is compared against an explicitly constructed `LogicalOrPredicate`.
    want_or = alt.LogicalOrPredicate(**{"or": [decade, one_of]}).to_dict()
    got_or = (decade | one_of).to_dict()
    assert got_or == want_or

    # Nested composition mixing all three operators.
    dummy_param = alt.ParameterPredicate(param="dummy_1", empty=True)
    measure_eq = alt.FieldEqualPredicate(equal=999, field="measure")
    measure_gt = alt.FieldGTPredicate(gt=4, field="measure 2")
    negated_param = alt.LogicalNotPredicate(**{"not": dummy_param})
    eq_and_gt = alt.LogicalAndPredicate(**{"and": [measure_eq, measure_gt]})
    want_nested = alt.LogicalOrPredicate(
        **{"or": [negated_param, eq_and_gt]}
    ).to_dict()
    got_nested = (~dummy_param | (measure_eq & measure_gt)).to_dict()
    assert got_nested == want_nested


def test_resolve_methods():
chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
assert chart.resolve == alt.Resolve(
Expand Down
59 changes: 57 additions & 2 deletions tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypedDict, TypeVar
from urllib import request

if sys.version_info >= (3, 14):
from typing import TypedDict
else:
from typing_extensions import TypedDict

import vl_convert as vlc

sys.path.insert(0, str(Path.cwd()))
Expand Down Expand Up @@ -317,6 +322,18 @@ def encode(self, *args: Any, {method_args}) -> Self:
return copy
'''

# Source text of the `~`, `&`, `|` operator methods for `PredicateComposition`.
# It is supplied as the `method_code` keyword for that class through
# `CORE_OVERRIDES`, so the generated class body ends with these methods.
DUNDER_PREDICATE_COMPOSITION = """
def __invert__(self) -> PredicateComposition:
return PredicateComposition({"not": self.to_dict()})

def __and__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"and": [self.to_dict(), other.to_dict()]})

def __or__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"or": [self.to_dict(), other.to_dict()]})
"""


# NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar`
# Revisit if this starts to become more common
Expand Down Expand Up @@ -431,6 +448,37 @@ class {classname}({basename}):
)


class MethodSchemaGenerator(SchemaGenerator):
"""
Base template w/ an extra slot `{method_code}` after `{init_code}`.

Apart from that slot, the class template has the `{classname}`, `{basename}`,
`{docstring}` and `{schema!r}` fields. Whatever text is supplied for
`method_code` is placed in the generated class after the `__init__` code.
"""

schema_class_template = textwrap.dedent(
'''
class {classname}({basename}):
"""{docstring}"""
_schema = {schema!r}

{init_code}

{method_code}
'''
)


SchGen = TypeVar("SchGen", bound=SchemaGenerator)


# Shape of one `CORE_OVERRIDES` entry: which generator class to use for a
# schema definition, and the extra keyword arguments to construct it with.
class OverridesItem(TypedDict, Generic[SchGen]):
# Generator class instantiated in place of the default `SchemaGenerator`.
tp: type[SchGen]
# Extra keyword arguments unpacked into the `tp(...)` call.
kwds: dict[str, Any]


# Per-definition generator overrides, keyed by schema definition name and
# looked up in `generate_vegalite_schema_wrapper`. Names that are not listed
# here are generated with a plain `SchemaGenerator` and no extra keywords.
CORE_OVERRIDES: dict[str, OverridesItem[SchemaGenerator]] = {
"PredicateComposition": OverridesItem(
tp=MethodSchemaGenerator, kwds={"method_code": DUNDER_PREDICATE_COMPOSITION}
)
}


class FieldSchemaGenerator(SchemaGenerator):
schema_class_template = textwrap.dedent(
'''
Expand Down Expand Up @@ -656,13 +704,20 @@ def generate_vegalite_schema_wrapper(fp: Path, /) -> str:
defschema = {"$ref": "#/definitions/" + name}
defschema_repr = {"$ref": "#/definitions/" + name}
name = get_valid_identifier(name)
definitions[name] = SchemaGenerator(
if overrides := CORE_OVERRIDES.get(name):
tp = overrides["tp"]
kwds = overrides["kwds"]
else:
tp = SchemaGenerator
kwds = {}
definitions[name] = tp(
name,
schema=defschema,
schemarepr=defschema_repr,
rootschema=rootschema,
basename=basename,
rootschemarepr=CodeSnippet(f"{basename}._rootschema"),
**kwds,
)
for name, schema in definitions.items():
graph[name] = []
Expand Down
14 changes: 9 additions & 5 deletions tools/schemapi/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,20 @@ def schema_class(self) -> str:
basename = self.basename
else:
basename = ", ".join(self.basename)
docstring = self.docstring(indent=4)
init_code = self.init_code(indent=4)
if type(self).haspropsetters:
method_code = self.overload_code(indent=4)
else:
method_code = self.kwargs.pop("method_code", None)
return self.schema_class_template.format(
classname=self.classname,
basename=basename,
schema=schemarepr,
rootschema=rootschemarepr,
docstring=self.docstring(indent=4),
init_code=self.init_code(indent=4),
method_code=(
self.overload_code(indent=4) if type(self).haspropsetters else None
),
docstring=docstring,
init_code=init_code,
method_code=method_code,
**self.kwargs,
)

Expand Down