diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index d4868cf1b..14d47064c 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -410,25 +410,25 @@ def to_dict(self) -> dict[str, str | dict[str, Any]]: msg = f"Unrecognized parameter type: {self.param_type}" raise ValueError(msg) - def __invert__(self) -> SelectionPredicateComposition | Any: + def __invert__(self) -> PredicateComposition | Any: if self.param_type == "selection": - return SelectionPredicateComposition({"not": {"param": self.name}}) + return core.PredicateComposition({"not": {"param": self.name}}) else: return _expr_core.OperatorMixin.__invert__(self) - def __and__(self, other: Any) -> SelectionPredicateComposition | Any: + def __and__(self, other: Any) -> PredicateComposition | Any: if self.param_type == "selection": if isinstance(other, Parameter): other = {"param": other.name} - return SelectionPredicateComposition({"and": [{"param": self.name}, other]}) + return core.PredicateComposition({"and": [{"param": self.name}, other]}) else: return _expr_core.OperatorMixin.__and__(self, other) - def __or__(self, other: Any) -> SelectionPredicateComposition | Any: + def __or__(self, other: Any) -> PredicateComposition | Any: if self.param_type == "selection": if isinstance(other, Parameter): other = {"param": other.name} - return SelectionPredicateComposition({"or": [{"param": self.name}, other]}) + return core.PredicateComposition({"or": [{"param": self.name}, other]}) else: return _expr_core.OperatorMixin.__or__(self, other) @@ -458,15 +458,7 @@ def __getitem__(self, field_name: str) -> GetItemExpression: # Enables use of ~, &, | with compositions of selection objects. -class SelectionPredicateComposition(core.PredicateComposition): - def __invert__(self) -> SelectionPredicateComposition: - return SelectionPredicateComposition({"not": self.to_dict()}) - - def __and__(self, other: SchemaBase) -> SelectionPredicateComposition: - return SelectionPredicateComposition({"and": [self.to_dict(), other.to_dict()]}) - - def __or__(self, other: SchemaBase) -> SelectionPredicateComposition: - return SelectionPredicateComposition({"or": [self.to_dict(), other.to_dict()]}) +SelectionPredicateComposition = core.PredicateComposition class ParameterExpression(_expr_core.OperatorMixin): @@ -532,7 +524,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: """Permitted types for `predicate`.""" _ComposablePredicateType: TypeAlias = Union[ - _expr_core.OperatorMixin, SelectionPredicateComposition + _expr_core.OperatorMixin, core.PredicateComposition ] """Permitted types for `&` reduced predicates.""" @@ -764,7 +756,7 @@ def _validate_composables( predicates: Iterable[Any], / ) -> Iterator[_ComposablePredicateType]: for p in predicates: - if isinstance(p, (_expr_core.OperatorMixin, SelectionPredicateComposition)): + if isinstance(p, (_expr_core.OperatorMixin, core.PredicateComposition)): yield p else: msg = ( diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 17ba53d10..51ae8ab5d 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -16153,6 +16153,15 @@ class PredicateComposition(VegaLiteSchema): def __init__(self, *args, **kwds): super().__init__(*args, **kwds) + def __invert__(self) -> PredicateComposition: + return PredicateComposition({"not": self.to_dict()}) + + def __and__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"and": [self.to_dict(), other.to_dict()]}) + + def __or__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"or": [self.to_dict(), other.to_dict()]}) + class LogicalAndPredicate(PredicateComposition): """ diff --git a/doc/user_guide/transform/filter.rst b/doc/user_guide/transform/filter.rst index fb4c7420f..39c268210 100644 --- a/doc/user_guide/transform/filter.rst +++ b/doc/user_guide/transform/filter.rst @@ -157,35 +157,45 @@ to select the data to be shown in the top chart: Logical Operands ^^^^^^^^^^^^^^^^ At times it is useful to combine several types of predicates into a single -selection. This can be accomplished using the various logical operand classes: +selection. We can use ``&``, ``|`` and ``~`` for respectively +``AND``, ``OR`` and ``NOT`` logical composition operands. -- :class:`~LogicalOrPredicate` -- :class:`~LogicalAndPredicate` -- :class:`~LogicalNotPredicate` +For example, here we wish to plot US population distributions for all data *except* the years *1950-1960*. -These are not yet part of the Altair interface -(see `Issue 695 `_) -but can be constructed explicitly; for example, here we plot US population -distributions for all data *except* the years 1950-1960, -by applying a ``LogicalNotPredicate`` schema to a ``FieldRangePredicate``: +First, we use a :class:`~FieldRangePredicate` to select *1950-1960*: .. altair-plot:: - + :output: none + import altair as alt from vega_datasets import data - pop = data.population.url - - alt.Chart(pop).mark_line().encode( - x='age:O', - y='sum(people):Q', - color='year:O' + source = data.population.url + chart = alt.Chart(source).mark_line().encode( + x="age:O", + y="sum(people):Q", + color="year:O" ).properties( width=600, height=200 - ).transform_filter( - {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} ) + between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960]) + +Then, we can *invert* this selection using ``~``: + +.. altair-plot:: + + # NOT between 1950-1960 + chart.transform_filter(~between_1950_60) + +We can further refine our filter by *composing* multiple predicates together. +In this case, using ``alt.datum``: + +.. altair-plot:: + + chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70)) + + Transform Options ^^^^^^^^^^^^^^^^^ The :meth:`~Chart.transform_filter` method is built on the :class:`~FilterTransform` diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 7c4b6a151..98b30a9b6 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -10,6 +10,7 @@ import re import sys import tempfile +from collections.abc import Mapping from datetime import date, datetime from importlib.metadata import version as importlib_version from importlib.util import find_spec @@ -1192,6 +1193,60 @@ def test_filter_transform_selection_predicates(): ] +def test_predicate_composition() -> None: + columns = ["Drought", "Epidemic", "Earthquake", "Flood"] + field_one_of = alt.FieldOneOfPredicate(field="Entity", oneOf=columns) + field_range = alt.FieldRangePredicate(field="Year", range=[1900, 2000]) + fields_and = field_one_of & field_range + expected_and = { + "and": [ + {"field": "Entity", "oneOf": columns}, + {"field": "Year", "range": [1900, 2000]}, + ] + } + assert isinstance(fields_and, alt.PredicateComposition) + actual_and = fields_and.to_dict() + + # NOTE: Extra guarantee that something hasn't overloaded `__eq__` or `to_dict` + assert isinstance(actual_and, Mapping) + assert isinstance(actual_and == expected_and, bool) + + assert actual_and == expected_and + + actual_when = ( + alt.when(field_one_of, field_range).then(alt.value(0)).otherwise(alt.value(1)) + ) + expected_when = {"condition": [{"test": fields_and, "value": 0}], "value": 1} + assert actual_when == expected_when + + field_range = alt.FieldRangePredicate(field="year", range=[1950, 1960]) + field_range_not = ~field_range + expected_not = {"not": {"field": "year", "range": [1950, 1960]}} + assert isinstance(field_range_not, alt.PredicateComposition) + actual_not = field_range_not.to_dict() + assert actual_not == expected_not + + expected_or = alt.LogicalOrPredicate( + **{"or": [field_range, field_one_of]} + ).to_dict() + actual_or = (field_range | field_one_of).to_dict() + assert actual_or == expected_or + + param_pred = alt.ParameterPredicate(param="dummy_1", empty=True) + field_eq = alt.FieldEqualPredicate(equal=999, field="measure") + field_gt = alt.FieldGTPredicate(gt=4, field="measure 2") + expected_multi = alt.LogicalOrPredicate( + **{ + "or": [ + alt.LogicalNotPredicate(**{"not": param_pred}), + alt.LogicalAndPredicate(**{"and": [field_eq, field_gt]}), + ] + } + ).to_dict() + actual_multi = (~param_pred | (field_eq & field_gt)).to_dict() + assert actual_multi == expected_multi + + def test_resolve_methods(): chart = alt.LayerChart().resolve_axis(x="shared", y="independent") assert chart.resolve == alt.Resolve( diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index e7307078a..70c3980e4 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -12,9 +12,14 @@ from itertools import chain from operator import attrgetter from pathlib import Path -from typing import TYPE_CHECKING, Any, Final, Literal +from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypedDict, TypeVar from urllib import request +if sys.version_info >= (3, 14): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import vl_convert as vlc sys.path.insert(0, str(Path.cwd())) @@ -317,6 +322,18 @@ def encode(self, *args: Any, {method_args}) -> Self: return copy ''' +# Enables use of ~, &, | with compositions of selection objects. +DUNDER_PREDICATE_COMPOSITION = """ + def __invert__(self) -> PredicateComposition: + return PredicateComposition({"not": self.to_dict()}) + + def __and__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"and": [self.to_dict(), other.to_dict()]}) + + def __or__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"or": [self.to_dict(), other.to_dict()]}) +""" + # NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar` # Revisit if this starts to become more common @@ -431,6 +448,37 @@ class {classname}({basename}): ) +class MethodSchemaGenerator(SchemaGenerator): + """Base template w/ an extra slot `{method_code}` after `{init_code}`.""" + + schema_class_template = textwrap.dedent( + ''' + class {classname}({basename}): + """{docstring}""" + _schema = {schema!r} + + {init_code} + + {method_code} + ''' + ) + + +SchGen = TypeVar("SchGen", bound=SchemaGenerator) + + +class OverridesItem(TypedDict, Generic[SchGen]): + tp: type[SchGen] + kwds: dict[str, Any] + + +CORE_OVERRIDES: dict[str, OverridesItem[SchemaGenerator]] = { + "PredicateComposition": OverridesItem( + tp=MethodSchemaGenerator, kwds={"method_code": DUNDER_PREDICATE_COMPOSITION} + ) +} + + class FieldSchemaGenerator(SchemaGenerator): schema_class_template = textwrap.dedent( ''' @@ -656,13 +704,20 @@ def generate_vegalite_schema_wrapper(fp: Path, /) -> str: defschema = {"$ref": "#/definitions/" + name} defschema_repr = {"$ref": "#/definitions/" + name} name = get_valid_identifier(name) - definitions[name] = SchemaGenerator( + if overrides := CORE_OVERRIDES.get(name): + tp = overrides["tp"] + kwds = overrides["kwds"] + else: + tp = SchemaGenerator + kwds = {} + definitions[name] = tp( name, schema=defschema, schemarepr=defschema_repr, rootschema=rootschema, basename=basename, rootschemarepr=CodeSnippet(f"{basename}._rootschema"), + **kwds, ) for name, schema in definitions.items(): graph[name] = [] diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index 47d96dcd7..37d512087 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -260,16 +260,20 @@ def schema_class(self) -> str: basename = self.basename else: basename = ", ".join(self.basename) + docstring = self.docstring(indent=4) + init_code = self.init_code(indent=4) + if type(self).haspropsetters: + method_code = self.overload_code(indent=4) + else: + method_code = self.kwargs.pop("method_code", None) return self.schema_class_template.format( classname=self.classname, basename=basename, schema=schemarepr, rootschema=rootschemarepr, - docstring=self.docstring(indent=4), - init_code=self.init_code(indent=4), - method_code=( - self.overload_code(indent=4) if type(self).haspropsetters else None - ), + docstring=docstring, + init_code=init_code, + method_code=method_code, **self.kwargs, )