Skip to content

Commit

Permalink
Merge pull request #53 from transferwise/deprecate_min_segments
Browse files Browse the repository at this point in the history
Deprecate min_segments parameter
  • Loading branch information
AlxdrPolyakov authored May 14, 2024
2 parents 1f1b216 + 7e846f8 commit 0f4b7ed
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 30 deletions.
8 changes: 4 additions & 4 deletions wise_pizza/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def explain_changes_in_average(
total_name: str,
size_name: str,
min_segments: Optional[int] = None,
max_segments: int = 5,
max_segments: int = None,
min_depth: int = 1,
max_depth: int = 2,
solver: str = "lasso",
Expand Down Expand Up @@ -124,7 +124,7 @@ def explain_changes_in_totals(
total_name: str,
size_name: str,
min_segments: Optional[int] = None,
max_segments: int = 5,
max_segments: int = None,
min_depth: int = 1,
max_depth: int = 2,
solver: str = "lasso",
Expand Down Expand Up @@ -271,7 +271,7 @@ def explain_levels(
total_name: str,
size_name: Optional[str] = None,
min_segments: int = None,
max_segments: int = 10,
max_segments: int = None,
min_depth: int = 1,
max_depth: int = 2,
solver="lasso",
Expand Down Expand Up @@ -353,7 +353,7 @@ def explain_timeseries(
time_name: str,
size_name: Optional[str] = None,
min_segments: int = None,
max_segments: int = 5,
max_segments: int = None,
min_depth: int = 1,
max_depth: int = 2,
solver: str = "omp",
Expand Down
5 changes: 3 additions & 2 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pandas as pd
from scipy.sparse import csc_matrix, diags

from wise_pizza.solve.find_alpha import clean_up_min_max, find_alpha
from wise_pizza.solve.find_alpha import find_alpha
from wise_pizza.utils import clean_up_min_max
from wise_pizza.make_matrix import sparse_dummy_matrix
from wise_pizza.cluster import make_clusters
from wise_pizza.preselect import HeuristicSelector
Expand Down Expand Up @@ -98,7 +99,7 @@ def fit(
weights: pd.Series = None,
time_col: pd.Series = None,
time_basis: pd.DataFrame = None,
min_segments: int = 10,
min_segments: int = None,
max_segments: int = None,
min_depth: int = 1,
max_depth: int = 3,
Expand Down
17 changes: 1 addition & 16 deletions wise_pizza/solve/find_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.linalg import svd

from wise_pizza.solve.solver import solve_lasso, solve_lp, solve_omp
from wise_pizza.utils import clean_up_min_max


def find_alpha(
Expand Down Expand Up @@ -219,19 +220,3 @@ def print_errors(a: np.ndarray):
# fit_intercept=not use_proj
)
return reg, nonzeros


def clean_up_min_max(min_nonzeros: int = None, max_nonzeros: int = None):
    """Fill in whichever of the two bounds is missing and return both.

    At least one bound must be supplied. A missing bound is set equal to
    the one that was given, so passing a single value yields ``(v, v)``.

    Returns:
        A ``(min_nonzeros, max_nonzeros)`` tuple.

    Raises:
        AssertionError: if both bounds are ``None``, or if the resolved
            lower bound exceeds the upper bound.
    """
    assert not (min_nonzeros is None and max_nonzeros is None)

    lower, upper = min_nonzeros, max_nonzeros
    if lower is None and upper is None:
        # Only reachable when assertions are disabled (python -O).
        lower = upper = 5
    elif upper is None:
        upper = lower
    elif lower is None:
        lower = upper

    assert lower <= upper
    return lower, upper
52 changes: 44 additions & 8 deletions wise_pizza/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List, Optional

import numpy as np
Expand Down Expand Up @@ -120,7 +121,9 @@ def rel_error(x, y):

if return_multiple:
sd_size = SegmentData(
combined.rename(columns={"Change in totals": "Change from segment size"}),
combined.rename(
columns={"Change in totals": "Change from segment size"}
),
dimensions=dims,
segment_total="Change from segment size",
segment_size=weights,
Expand All @@ -137,7 +140,9 @@ def rel_error(x, y):
combined["Change from"] = "Segment size"
c2["Change from"] = "Segment average"

df = pd.concat([combined, c2])[dims + [weights, "Change in totals", "Change from"]]
df = pd.concat([combined, c2])[
dims + [weights, "Change in totals", "Change from"]
]
df_change_in_totals = np.array(df["Change in totals"], dtype=np.longdouble)
combined_dtotals = np.array(combined["dtotals"], dtype=np.longdouble)
df_change_in_totals_sum = np.nansum(df_change_in_totals)
Expand All @@ -160,7 +165,9 @@ def rel_error(x, y):
combined[weights] = 1.0 # combined[totals + "_x"]
combined[weights] = np.maximum(1.0, combined[weights])
cols = (
dims + ["Change in totals", totals + "_x", totals + "_y"] + [c for c in combined.columns if "baseline" in c]
dims
+ ["Change in totals", totals + "_x", totals + "_y"]
+ [c for c in combined.columns if "baseline" in c]
)

return SegmentData(
Expand Down Expand Up @@ -205,16 +212,26 @@ def prepare_df(

# replace NaN values in categorical columns with the column name + "_unknown"
object_columns = list(new_df[dims].select_dtypes("object").columns)
new_df[object_columns] = new_df[object_columns].fillna(new_df[object_columns].apply(lambda x: x.name + "_unknown"))
new_df[object_columns] = new_df[object_columns].fillna(
new_df[object_columns].apply(lambda x: x.name + "_unknown")
)
new_df[object_columns] = new_df[object_columns].astype(str)

# Groupby all relevant dims to decrease the dataframe size, if possible
group_dims = dims if time_name is None else dims + [time_name]

if size_name is not None:
new_df = new_df.groupby(by=group_dims, observed=True)[[total_name, size_name]].sum().reset_index()
new_df = (
new_df.groupby(by=group_dims, observed=True)[[total_name, size_name]]
.sum()
.reset_index()
)
else:
new_df = new_df.groupby(by=group_dims, observed=True)[[total_name]].sum().reset_index()
new_df = (
new_df.groupby(by=group_dims, observed=True)[[total_name]]
.sum()
.reset_index()
)

return new_df

Expand Down Expand Up @@ -280,5 +297,24 @@ def prepare_df(
# new_df[object_columns] = new_df[object_columns].astype(str)
#
# return new_df
def almost_equals(x1, x2, eps: float = 1e-6) -> bool:
    """Return True when x1 and x2 are equal up to a relative tolerance.

    The summed absolute difference is divided by the mean absolute value
    of ``x1 + x2`` and the ratio is compared with ``eps``.

    Args:
        x1: First value or array-like.
        x2: Second value or array-like, compatible with ``x1``.
        eps: Maximum allowed relative difference.

    Returns:
        A plain ``bool``.
    """
    total_diff = np.sum(np.abs(x1 - x2))
    scale = np.mean(np.abs(x1 + x2))
    if scale == 0:
        # A zero scale would give 0/0 = NaN (e.g. two all-zero inputs),
        # which compares False; equality is then decided by the
        # difference alone.
        return bool(total_diff == 0)
    return bool(total_diff / scale < eps)
def almost_equals(x1, x2, eps: float = 1e-6) -> bool:
    """Check whether two values/arrays agree within a relative tolerance.

    Computes ``sum(|x1 - x2|) / mean(|x1 + x2|)`` and compares it with
    ``eps``.

    Args:
        x1: First value or array-like.
        x2: Second value or array-like, compatible with ``x1``.
        eps: Maximum allowed relative difference.

    Returns:
        A plain ``bool``.
    """
    abs_diff_sum = np.sum(np.abs(x1 - x2))
    mean_magnitude = np.mean(np.abs(x1 + x2))
    if mean_magnitude == 0:
        # Avoid 0/0 -> NaN (which would report identical all-zero inputs as
        # different): with no scale to normalise by, require an exact match.
        return bool(abs_diff_sum == 0)
    return bool(abs_diff_sum / mean_magnitude < eps)


def clean_up_min_max(min_nonzeros: int = None, max_nonzeros: int = None):
    """Resolve the lower and upper bounds, filling in any that are missing.

    Supplying ``min_nonzeros`` logs a deprecation warning. Resolution rules:

    * both ``None``: both bounds default to 5;
    * only one given: the other is set equal to it;
    * both given: returned unchanged.

    Args:
        min_nonzeros: Deprecated lower bound, or ``None``.
        max_nonzeros: Upper bound, or ``None``.

    Returns:
        A ``(min_nonzeros, max_nonzeros)`` tuple.

    Raises:
        AssertionError: if the resolved lower bound exceeds the upper bound.
    """
    if min_nonzeros is not None:
        # The message names the caller-facing parameters (min_segments /
        # max_segments) rather than this helper's internal argument names.
        logging.warning(
            "min_segments parameter is deprecated, please use max_segments instead."
        )
    if max_nonzeros is None:
        if min_nonzeros is None:
            max_nonzeros = 5
            min_nonzeros = 5
        else:
            max_nonzeros = min_nonzeros
    else:
        if min_nonzeros is None:
            min_nonzeros = max_nonzeros

    assert (
        min_nonzeros <= max_nonzeros
    ), "min_nonzeros must not exceed max_nonzeros"
    return min_nonzeros, max_nonzeros

0 comments on commit 0f4b7ed

Please sign in to comment.