From e87f21914fefbf5cea74de8f34fdfed2e8ed5494 Mon Sep 17 00:00:00 2001 From: "Egor.Kraev" Date: Mon, 13 May 2024 13:04:40 +0100 Subject: [PATCH] Fix bug in tree solver --- wise_pizza/slicer.py | 5 +++-- wise_pizza/solve/fitter.py | 6 ++++-- wise_pizza/solve/tree.py | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/wise_pizza/slicer.py b/wise_pizza/slicer.py index 10ac069..9f53c08 100644 --- a/wise_pizza/slicer.py +++ b/wise_pizza/slicer.py @@ -117,7 +117,7 @@ def fit( @param min_segments: Minimum number of segments to find @param max_segments: Maximum number of segments to find, defaults to min_segments @param min_depth: Minimum number of dimension to constrain in segment definition - @param max_depth: Maximum number of dimension to constrain in segment definition + @param max_depth: Maximum number of dimensions to constrain in segment definition; also max depth of tree in tree solver @param solver: Valid values are "lasso" (default), "tree" (for non-overlapping segments), "omp", or "lp" @param verbose: If set to a truish value, lots of debug info is printed to console @param force_dim: To add dim @@ -287,7 +287,8 @@ def fit( # assert wgt == wgts[i] s["orig_i"] = i s["coef"] = self.reg.coef_[i] - s["impact"] = np.abs(s["coef"]) * (np.abs(this_vec) * self.weights).sum() + # TODO: does not taking the abs of coef here break time series? 
+ s["impact"] = s["coef"] * (np.abs(this_vec) * self.weights).sum() s["avg_impact"] = s["impact"] / sum(self.weights) s["total"] = (self.totals * dummy).sum() s["seg_size"] = wgt diff --git a/wise_pizza/solve/fitter.py b/wise_pizza/solve/fitter.py index 4f44254..1904049 100644 --- a/wise_pizza/solve/fitter.py +++ b/wise_pizza/solve/fitter.py @@ -18,10 +18,12 @@ def fit_predict(self, X, y, sample_weight=None): return self.predict(X) def error(self, X, y, sample_weight=None): + # Error is chosen so that it's minimized by the weighted mean of y err = y - self.predict(X) + errsq = err**2 if sample_weight is not None: - err *= sample_weight - return np.nansum(err**2) + errsq *= sample_weight + return np.nansum(errsq) class AverageFitter(Fitter): diff --git a/wise_pizza/solve/tree.py b/wise_pizza/solve/tree.py index eda97c5..0a1504e 100644 --- a/wise_pizza/solve/tree.py +++ b/wise_pizza/solve/tree.py @@ -69,7 +69,7 @@ def __init__( dim_split: Optional[Dict[str, List]] = None, depth: int = 0, ): - self.df = df + self.df = df.copy() self.fitter = fitter self.dims = dims self._best_submodels = None @@ -85,7 +85,7 @@ def error(self): self.model = copy.deepcopy(self.fitter) self.model.fit( X=self.df[self.dims], - y=self.df["totals"], + y=self.df["__avg"], sample_weight=self.df["weights"], ) return self.model.error(