GroupBy: Avoid guessing variable types
janezd committed Oct 3, 2024
1 parent 60a2f61 commit f025959
Showing 4 changed files with 178 additions and 77 deletions.
53 changes: 39 additions & 14 deletions Orange/data/aggregate.py
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple, Union, Type

import pandas as pd

@@ -39,15 +39,20 @@ def __init__(self, table: Table, by: List[Variable]):
df = table_to_frame(table, include_metas=True)
# observed=True keeps only groups with at least one instance
self.group_by = df.groupby([a.name for a in by], observed=True)
self.by = tuple(by)

# lru_cache that caches on the object level
self.compute_aggregation = lru_cache()(self._compute_aggregation)

AggDescType = Union[str,
Callable,
Tuple[str, Union[str, Callable]],
Tuple[str, Union[str, Callable], Union[Type[Variable], bool]]
]

def aggregate(
self,
aggregations: Dict[
Variable, List[Union[str, Callable, Tuple[str, Union[str, Callable]]]]
],
aggregations: Dict[Variable, List[AggDescType]],
callback: Callable = dummy_callback,
) -> Table:
"""
@@ -57,12 +62,16 @@ def aggregate(
----------
aggregations
The dictionary that defines aggregations that need to be computed
for variables. We support two formats:
for variables. We support three formats:
- {variable name: [agg function 1, agg function 2]}
- {variable name: [(agg name 1, agg function 1), (agg name 2, agg function 2)]}
- {variable name: [(agg name 1, agg function 1, output_variable_type1), ...]}
Where agg name is the aggregation name used in the output column name.
The aggregation function can be either a function or a string that
defines an aggregation in Pandas (e.g. mean).
output_variable_type can be a type for a new variable, True to copy
the input variable, or False to create a new variable of the same type
as the input.
callback
Callback function to report the progress
@@ -75,29 +84,45 @@
count = 0

result_agg = []
output_variables = []
for col, aggs in aggregations.items():
for agg in aggs:
res = self._compute_aggregation(col, agg)
res, var = self._compute_aggregation(col, agg)
result_agg.append(res)
output_variables.append(var)
count += 1
callback(count / num_aggs * 0.8)

agg_table = self._aggregations_to_table(result_agg)
agg_table = self._aggregations_to_table(result_agg, output_variables)
callback(1)
return agg_table

def _compute_aggregation(
self, col: Variable, agg: Union[str, Callable, Tuple[str, Union[str, Callable]]]
) -> pd.Series:
self, col: Variable, agg: AggDescType) -> Tuple[pd.Series, Variable]:
# use named aggregation to avoid issues with same column names when reset_index
if isinstance(agg, tuple):
name, agg = agg
name, agg, var_type, *_ = (*agg, None)
else:
name = agg if isinstance(agg, str) else agg.__name__
var_type = None
col_name = f"{col.name} - {name}"
return self.group_by[col.name].agg(**{col_name: agg})

def _aggregations_to_table(self, aggregations: List[pd.Series]) -> Table:
agg_col = self.group_by[col.name].agg(**{col_name: agg})
match var_type:
case True:
var = col.copy(name=col_name)
case False:
var = type(col).make(name=col_name)
case None:
var = None
case var_type:
assert issubclass(var_type, Variable)
var = var_type.make(name=col_name)
return agg_col, var

def _aggregations_to_table(
self,
aggregations: List[pd.Series],
output_variables: List[Union[Variable, None]]) -> Table:
"""Concatenate aggregation series and convert back to Table"""
if aggregations:
df = pd.concat(aggregations, axis=1)
@@ -107,7 +132,7 @@ def _aggregations_to_table(self, aggregations: List[pd.Series]) -> Table:
df = df.drop(columns=df.columns)
gb_attributes = df.index.names
df = df.reset_index() # move group by var that are in index to columns
table = table_from_frame(df)
table = table_from_frame(df, variables=(*self.by, *output_variables))

# group by variables should be last two columns in metas in the output
metas = table.domain.metas
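
To make the new descriptor format concrete, here is a minimal sketch (not part of the commit) of calling OrangeTableGroupBy.aggregate directly with the accepted descriptor forms; the iris dataset and the chosen columns are assumptions for illustration only.

# Hedged sketch: exercises the (name, function, output_variable_type)
# descriptors accepted by aggregate() after this change.
from Orange.data import Table, ContinuousVariable
from Orange.data.aggregate import OrangeTableGroupBy

table = Table("iris")
sepal_length = table.domain["sepal length"]

gb = OrangeTableGroupBy(table, by=[table.domain["iris"]])
result = gb.aggregate({
    sepal_length: [
        "mean",                                # plain aggregation, output type guessed as before
        ("max", "max", True),                  # True: copy the input variable
        ("span", lambda s: s.max() - s.min(),
         ContinuousVariable),                  # explicit output variable type
    ]
})
print(result.domain)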
84 changes: 54 additions & 30 deletions Orange/data/pandas_compat.py
@@ -1,5 +1,6 @@
"""Pandas DataFrame↔Table conversion helpers"""
from unittest.mock import patch
from itertools import zip_longest

import numpy as np
from scipy import sparse as sp
@@ -249,7 +250,14 @@ def to_categorical(s, _):
return np.asarray(x)


def vars_from_df(df, role=None, force_nominal=False):
def to_numeric(s, _):
return np.asarray(pd.to_numeric(s))


def vars_from_df(df, role=None, force_nominal=False, variables=None):
if variables is not None:
assert len(variables) == len(df.columns)

if role is None and hasattr(df, 'orange_role'):
role = df.orange_role
df = _reset_index(df)
@@ -258,39 +266,53 @@ def vars_from_df(df, role=None, force_nominal=False):
exprs = [], [], []
vars_ = [], [], []

for column in df.columns:
def _convert_string(s, _):
return np.asarray(
# cast to object so that fillna can replace the nan values;
# replacing nan with StringVariable.Unknown ensures all values are strings
s.astype(object).fillna(StringVariable.Unknown).astype(str),
dtype=object
)

conversions = {
DiscreteVariable: to_categorical,
ContinuousVariable: to_numeric,
TimeVariable: _convert_datetime,
StringVariable: _convert_string
}

for column, var in zip_longest(df.columns, variables or [], fillvalue=None):
s = df[column]
_role = Role.Attribute if role is None else role
if hasattr(df, 'orange_variables') and column in df.orange_variables:
if var is not None:
if not var.is_primitive():
_role = Role.Meta
expr = conversions[type(var)]
elif hasattr(df, 'orange_variables') and column in df.orange_variables:
original_var = df.orange_variables[column]
var = original_var.copy(compute_value=None)
expr = None
elif _is_datetime(s):
var = TimeVariable(str(column))
expr = _convert_datetime
elif _is_discrete(s, force_nominal):
discrete = s.astype("category").cat
var = DiscreteVariable(
str(column), discrete.categories.astype(str).tolist()
)
expr = to_categorical
elif is_numeric_dtype(s):
var = ContinuousVariable(
# set number of decimals to 0 if int else keeps default behaviour
str(column), number_of_decimals=(0 if is_integer_dtype(s) else None)
)
expr = None
else:
if role is not None and role != Role.Meta:
raise ValueError("String variable must be in metas.")
_role = Role.Meta
var = StringVariable(str(column))
expr = lambda s, _: np.asarray(
# cast to object so that fillna can replace the nan values;
# replacing nan with StringVariable.Unknown ensures all values are strings
s.astype(object).fillna(StringVariable.Unknown).astype(str),
dtype=object
)
if _is_datetime(s):
var = TimeVariable(str(column))
elif _is_discrete(s, force_nominal):
discrete = s.astype("category").cat
var = DiscreteVariable(
str(column), discrete.categories.astype(str).tolist()
)
elif is_numeric_dtype(s):
var = ContinuousVariable(
# set number of decimals to 0 if int else keeps default behaviour
str(column), number_of_decimals=(0 if is_integer_dtype(s) else None)
)
else:
if role is not None and role != Role.Meta:
raise ValueError("String variable must be in metas.")
_role = Role.Meta
var = StringVariable(str(column))
expr = conversions[type(var)]


cols[_role].append(column)
exprs[_role].append(expr)
@@ -324,8 +346,10 @@ def vars_from_df(df, role=None, force_nominal=False):
return xym, Domain(*vars_)


def table_from_frame(df, *, force_nominal=False):
XYM, domain = vars_from_df(df, force_nominal=force_nominal)
def table_from_frame(df, *, force_nominal=False, variables=None):
XYM, domain = vars_from_df(df,
force_nominal=force_nominal,
variables=variables)

if hasattr(df, 'orange_weights') and hasattr(df, 'orange_attributes'):
W = [df.orange_weights[i] for i in df.index if i in df.orange_weights]
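
The new variables argument can also be used with table_from_frame on its own to skip type guessing entirely; the sketch below is illustrative only, with an assumed DataFrame and assumed variable definitions. Variables are matched positionally to df.columns, so their number and order must agree with the frame.

# Hedged sketch of table_from_frame with explicit variables.
import pandas as pd
from Orange.data import ContinuousVariable, DiscreteVariable, StringVariable
from Orange.data.pandas_compat import table_from_frame

df = pd.DataFrame({
    "height": [1.70, 1.60, 1.85],
    "group": ["a", "b", "a"],
    "note": ["x", "y", "z"],
})
variables = (
    ContinuousVariable("height"),
    DiscreteVariable("group", values=("a", "b")),  # forced to discrete instead of guessed
    StringVariable("note"),                        # non-primitive, placed in metas
)
table = table_from_frame(df, variables=variables)
print(table.domain)   # expected roughly: [height, group] {note}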
90 changes: 71 additions & 19 deletions Orange/widgets/data/owgroupby.py
@@ -1,8 +1,8 @@
from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Set
from typing import \
Any, Dict, List, Optional, Set, Union, NamedTuple, Callable, Type

import pandas as pd
from numpy import nan
@@ -44,7 +44,14 @@
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.widget import OWWidget

Aggregation = namedtuple("Aggregation", ["function", "types"])

class Aggregation(NamedTuple):
function: Union[str, Callable]
types: Set[Type[Variable]]
# Gives the type of the result,
# or True to copy the original variable,
# or False to create a new variable of the same type as the input
result_type: Union[Type[Variable], bool]


def concatenate(x):
@@ -90,43 +97,88 @@ def span(s):


AGGREGATIONS = {
"Mean": Aggregation("mean", {ContinuousVariable, TimeVariable}),
"Median": Aggregation("median", {ContinuousVariable, TimeVariable}),
"Q1": Aggregation(lambda s: s.quantile(0.25), {ContinuousVariable, TimeVariable}),
"Q3": Aggregation(lambda s: s.quantile(0.75), {ContinuousVariable, TimeVariable}),
"Min. value": Aggregation("min", {ContinuousVariable, TimeVariable}),
"Max. value": Aggregation("max", {ContinuousVariable, TimeVariable}),
"Mean": Aggregation(
"mean",
{ContinuousVariable, TimeVariable},
False),
"Median": Aggregation(
"median",
{ContinuousVariable, TimeVariable},
True),
"Q1": Aggregation(
lambda s: s.quantile(0.25),
{ContinuousVariable, TimeVariable},
True),
"Q3": Aggregation(
lambda s: s.quantile(0.75),
{ContinuousVariable, TimeVariable},
True),
"Min. value": Aggregation(
"min",
{ContinuousVariable, TimeVariable},
True),
"Max. value": Aggregation(
"max",
{ContinuousVariable, TimeVariable},
True),
"Mode": Aggregation(
lambda x: pd.Series.mode(x).get(0, nan),
{ContinuousVariable, DiscreteVariable, TimeVariable}
{ContinuousVariable, DiscreteVariable, TimeVariable},
True
),
"Standard deviation": Aggregation(
std,
{ContinuousVariable, TimeVariable},
ContinuousVariable
),
"Standard deviation": Aggregation(std, {ContinuousVariable, TimeVariable}),
"Variance": Aggregation(var, {ContinuousVariable, TimeVariable}),
"Sum": Aggregation("sum", {ContinuousVariable}),
"Variance": Aggregation(
var,
{ContinuousVariable, TimeVariable},
ContinuousVariable
),
"Sum": Aggregation(
"sum",
{ContinuousVariable},
True),
"Concatenate": Aggregation(
concatenate,
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
StringVariable
),
"Span": Aggregation(
span,
{ContinuousVariable, TimeVariable},
ContinuousVariable
),
"Span": Aggregation(span, {ContinuousVariable, TimeVariable}),
"First value": Aggregation(
"first", {ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable}
"first",
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
True
),
"Last value": Aggregation(
"last", {ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable}
"last",
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
True
),
"Random value": Aggregation(
lambda x: x.sample(1, random_state=0),
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
True
),
"Count defined": Aggregation(
"count", {ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable}
"count",
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
ContinuousVariable
),
"Count": Aggregation(
"size", {ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable}
"size",
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
ContinuousVariable
),
"Proportion defined": Aggregation(
lambda x: x.count() / x.size,
{ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable},
ContinuousVariable
),
}
# list of ordered aggregation names is required on several locations so we
@@ -166,7 +218,7 @@ def progress(part):

aggregations = {
var: [
(agg, AGGREGATIONS[agg].function)
(agg, AGGREGATIONS[agg].function, AGGREGATIONS[agg].result_type)
for agg in sorted(aggs, key=AGGREGATIONS_ORD.index)
]
for var, aggs in aggregations.items()
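
For reference, a sketch of the mapping the widget now hands to OrangeTableGroupBy.aggregate, assembled from the function and result_type fields of AGGREGATIONS; the heart_disease dataset and the selected aggregations are assumptions, and importing owgroupby assumes the widget (Qt) dependencies are installed.

# Hedged sketch (not widget code): build (name, function, result_type)
# triples from AGGREGATIONS, as the widget does, and run the aggregation.
from Orange.data import Table
from Orange.data.aggregate import OrangeTableGroupBy
from Orange.widgets.data.owgroupby import AGGREGATIONS

table = Table("heart_disease")
age = table.domain["age"]
selected = ["Mean", "Span", "Count defined"]   # hypothetical user selection

aggregations = {
    age: [
        (name, AGGREGATIONS[name].function, AGGREGATIONS[name].result_type)
        for name in selected
    ]
}
result = OrangeTableGroupBy(table, by=[table.domain["gender"]]).aggregate(aggregations)
print(result.domain)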