Skip to content

Commit

Permalink
feat: Support &, |, ~ on all ...Predicate classes (#3668)
Browse files Browse the repository at this point in the history
Co-authored-by: Mattijn van Hoek <[email protected]>
  • Loading branch information
dangotbanned and mattijn authored Nov 3, 2024
1 parent c5d3bdf commit 64b2d33
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 42 deletions.
26 changes: 9 additions & 17 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,25 +410,25 @@ def to_dict(self) -> dict[str, str | dict[str, Any]]:
msg = f"Unrecognized parameter type: {self.param_type}"
raise ValueError(msg)

def __invert__(self) -> SelectionPredicateComposition | Any:
def __invert__(self) -> PredicateComposition | Any:
if self.param_type == "selection":
return SelectionPredicateComposition({"not": {"param": self.name}})
return core.PredicateComposition({"not": {"param": self.name}})
else:
return _expr_core.OperatorMixin.__invert__(self)

def __and__(self, other: Any) -> SelectionPredicateComposition | Any:
def __and__(self, other: Any) -> PredicateComposition | Any:
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"and": [{"param": self.name}, other]})
return core.PredicateComposition({"and": [{"param": self.name}, other]})
else:
return _expr_core.OperatorMixin.__and__(self, other)

def __or__(self, other: Any) -> SelectionPredicateComposition | Any:
def __or__(self, other: Any) -> PredicateComposition | Any:
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"or": [{"param": self.name}, other]})
return core.PredicateComposition({"or": [{"param": self.name}, other]})
else:
return _expr_core.OperatorMixin.__or__(self, other)

Expand Down Expand Up @@ -458,15 +458,7 @@ def __getitem__(self, field_name: str) -> GetItemExpression:


# Enables use of ~, &, | with compositions of selection objects.
class SelectionPredicateComposition(core.PredicateComposition):
def __invert__(self) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"not": self.to_dict()})

def __and__(self, other: SchemaBase) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"and": [self.to_dict(), other.to_dict()]})

def __or__(self, other: SchemaBase) -> SelectionPredicateComposition:
return SelectionPredicateComposition({"or": [self.to_dict(), other.to_dict()]})
SelectionPredicateComposition = core.PredicateComposition


class ParameterExpression(_expr_core.OperatorMixin):
Expand Down Expand Up @@ -532,7 +524,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
"""Permitted types for `predicate`."""

_ComposablePredicateType: TypeAlias = Union[
_expr_core.OperatorMixin, SelectionPredicateComposition
_expr_core.OperatorMixin, core.PredicateComposition
]
"""Permitted types for `&` reduced predicates."""

Expand Down Expand Up @@ -764,7 +756,7 @@ def _validate_composables(
predicates: Iterable[Any], /
) -> Iterator[_ComposablePredicateType]:
for p in predicates:
if isinstance(p, (_expr_core.OperatorMixin, SelectionPredicateComposition)):
if isinstance(p, (_expr_core.OperatorMixin, core.PredicateComposition)):
yield p
else:
msg = (
Expand Down
9 changes: 9 additions & 0 deletions altair/vegalite/v5/schema/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16153,6 +16153,15 @@ class PredicateComposition(VegaLiteSchema):
def __init__(self, *args, **kwds):
super().__init__(*args, **kwds)

def __invert__(self) -> PredicateComposition:
return PredicateComposition({"not": self.to_dict()})

def __and__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"and": [self.to_dict(), other.to_dict()]})

def __or__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"or": [self.to_dict(), other.to_dict()]})


class LogicalAndPredicate(PredicateComposition):
"""
Expand Down
46 changes: 28 additions & 18 deletions doc/user_guide/transform/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,45 @@ to select the data to be shown in the top chart:
Logical Operands
^^^^^^^^^^^^^^^^
At times it is useful to combine several types of predicates into a single
selection. This can be accomplished using the various logical operand classes:
selection. We can use ``&``, ``|`` and ``~`` for respectively
``AND``, ``OR`` and ``NOT`` logical composition operands.

- :class:`~LogicalOrPredicate`
- :class:`~LogicalAndPredicate`
- :class:`~LogicalNotPredicate`
For example, here we wish to plot US population distributions for all data *except* the years *1950-1960*.

These are not yet part of the Altair interface
(see `Issue 695 <https://github.com/vega/altair/issues/695>`_)
but can be constructed explicitly; for example, here we plot US population
distributions for all data *except* the years 1950-1960,
by applying a ``LogicalNotPredicate`` schema to a ``FieldRangePredicate``:
First, we use a :class:`~FieldRangePredicate` to select *1950-1960*:

.. altair-plot::

:output: none

import altair as alt
from vega_datasets import data

pop = data.population.url

alt.Chart(pop).mark_line().encode(
x='age:O',
y='sum(people):Q',
color='year:O'
source = data.population.url
chart = alt.Chart(source).mark_line().encode(
x="age:O",
y="sum(people):Q",
color="year:O"
).properties(
width=600, height=200
).transform_filter(
{'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
)

between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960])

Then, we can *invert* this selection using ``~``:

.. altair-plot::

# NOT between 1950-1960
chart.transform_filter(~between_1950_60)

We can further refine our filter by *composing* multiple predicates together.
In this case, using ``alt.datum``:

.. altair-plot::

chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70))


Transform Options
^^^^^^^^^^^^^^^^^
The :meth:`~Chart.transform_filter` method is built on the :class:`~FilterTransform`
Expand Down
55 changes: 55 additions & 0 deletions tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import sys
import tempfile
from collections.abc import Mapping
from datetime import date, datetime
from importlib.metadata import version as importlib_version
from importlib.util import find_spec
Expand Down Expand Up @@ -1192,6 +1193,60 @@ def test_filter_transform_selection_predicates():
]


def test_predicate_composition() -> None:
columns = ["Drought", "Epidemic", "Earthquake", "Flood"]
field_one_of = alt.FieldOneOfPredicate(field="Entity", oneOf=columns)
field_range = alt.FieldRangePredicate(field="Year", range=[1900, 2000])
fields_and = field_one_of & field_range
expected_and = {
"and": [
{"field": "Entity", "oneOf": columns},
{"field": "Year", "range": [1900, 2000]},
]
}
assert isinstance(fields_and, alt.PredicateComposition)
actual_and = fields_and.to_dict()

# NOTE: Extra guarantee that something hasn't overloaded `__eq__` or `to_dict`
assert isinstance(actual_and, Mapping)
assert isinstance(actual_and == expected_and, bool)

assert actual_and == expected_and

actual_when = (
alt.when(field_one_of, field_range).then(alt.value(0)).otherwise(alt.value(1))
)
expected_when = {"condition": [{"test": fields_and, "value": 0}], "value": 1}
assert actual_when == expected_when

field_range = alt.FieldRangePredicate(field="year", range=[1950, 1960])
field_range_not = ~field_range
expected_not = {"not": {"field": "year", "range": [1950, 1960]}}
assert isinstance(field_range_not, alt.PredicateComposition)
actual_not = field_range_not.to_dict()
assert actual_not == expected_not

expected_or = alt.LogicalOrPredicate(
**{"or": [field_range, field_one_of]}
).to_dict()
actual_or = (field_range | field_one_of).to_dict()
assert actual_or == expected_or

param_pred = alt.ParameterPredicate(param="dummy_1", empty=True)
field_eq = alt.FieldEqualPredicate(equal=999, field="measure")
field_gt = alt.FieldGTPredicate(gt=4, field="measure 2")
expected_multi = alt.LogicalOrPredicate(
**{
"or": [
alt.LogicalNotPredicate(**{"not": param_pred}),
alt.LogicalAndPredicate(**{"and": [field_eq, field_gt]}),
]
}
).to_dict()
actual_multi = (~param_pred | (field_eq & field_gt)).to_dict()
assert actual_multi == expected_multi


def test_resolve_methods():
chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
assert chart.resolve == alt.Resolve(
Expand Down
59 changes: 57 additions & 2 deletions tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypedDict, TypeVar
from urllib import request

if sys.version_info >= (3, 14):
from typing import TypedDict
else:
from typing_extensions import TypedDict

import vl_convert as vlc

sys.path.insert(0, str(Path.cwd()))
Expand Down Expand Up @@ -317,6 +322,18 @@ def encode(self, *args: Any, {method_args}) -> Self:
return copy
'''

# Enables use of ~, &, | with compositions of selection objects.
DUNDER_PREDICATE_COMPOSITION = """
def __invert__(self) -> PredicateComposition:
return PredicateComposition({"not": self.to_dict()})
def __and__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"and": [self.to_dict(), other.to_dict()]})
def __or__(self, other: SchemaBase) -> PredicateComposition:
return PredicateComposition({"or": [self.to_dict(), other.to_dict()]})
"""


# NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar`
# Revisit if this starts to become more common
Expand Down Expand Up @@ -431,6 +448,37 @@ class {classname}({basename}):
)


class MethodSchemaGenerator(SchemaGenerator):
"""Base template w/ an extra slot `{method_code}` after `{init_code}`."""

schema_class_template = textwrap.dedent(
'''
class {classname}({basename}):
"""{docstring}"""
_schema = {schema!r}
{init_code}
{method_code}
'''
)


SchGen = TypeVar("SchGen", bound=SchemaGenerator)


class OverridesItem(TypedDict, Generic[SchGen]):
tp: type[SchGen]
kwds: dict[str, Any]


CORE_OVERRIDES: dict[str, OverridesItem[SchemaGenerator]] = {
"PredicateComposition": OverridesItem(
tp=MethodSchemaGenerator, kwds={"method_code": DUNDER_PREDICATE_COMPOSITION}
)
}


class FieldSchemaGenerator(SchemaGenerator):
schema_class_template = textwrap.dedent(
'''
Expand Down Expand Up @@ -656,13 +704,20 @@ def generate_vegalite_schema_wrapper(fp: Path, /) -> str:
defschema = {"$ref": "#/definitions/" + name}
defschema_repr = {"$ref": "#/definitions/" + name}
name = get_valid_identifier(name)
definitions[name] = SchemaGenerator(
if overrides := CORE_OVERRIDES.get(name):
tp = overrides["tp"]
kwds = overrides["kwds"]
else:
tp = SchemaGenerator
kwds = {}
definitions[name] = tp(
name,
schema=defschema,
schemarepr=defschema_repr,
rootschema=rootschema,
basename=basename,
rootschemarepr=CodeSnippet(f"{basename}._rootschema"),
**kwds,
)
for name, schema in definitions.items():
graph[name] = []
Expand Down
14 changes: 9 additions & 5 deletions tools/schemapi/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,20 @@ def schema_class(self) -> str:
basename = self.basename
else:
basename = ", ".join(self.basename)
docstring = self.docstring(indent=4)
init_code = self.init_code(indent=4)
if type(self).haspropsetters:
method_code = self.overload_code(indent=4)
else:
method_code = self.kwargs.pop("method_code", None)
return self.schema_class_template.format(
classname=self.classname,
basename=basename,
schema=schemarepr,
rootschema=rootschemarepr,
docstring=self.docstring(indent=4),
init_code=self.init_code(indent=4),
method_code=(
self.overload_code(indent=4) if type(self).haspropsetters else None
),
docstring=docstring,
init_code=init_code,
method_code=method_code,
**self.kwargs,
)

Expand Down

0 comments on commit 64b2d33

Please sign in to comment.