Skip to content

Commit

Permalink
Fix bug in tree solver
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed May 13, 2024
1 parent f87bb99 commit e87f219
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def fit(
@param min_segments: Minimum number of segments to find
@param max_segments: Maximum number of segments to find, defaults to min_segments
@param min_depth: Minimum number of dimension to constrain in segment definition
@param max_depth: Maximum number of dimension to constrain in segment definition
@param max_depth: Maximum number of dimensions to constrain in segment definition; also max depth pf tree in tree solver
@param solver: Valid values are "lasso" (default), "tree" (for non-overlapping segments), "omp", or "lp"
@param verbose: If set to a truish value, lots of debug info is printed to console
@param force_dim: To add dim
Expand Down Expand Up @@ -287,7 +287,8 @@ def fit(
# assert wgt == wgts[i]
s["orig_i"] = i
s["coef"] = self.reg.coef_[i]
s["impact"] = np.abs(s["coef"]) * (np.abs(this_vec) * self.weights).sum()
# TODO: does not taking the abs of coef here break time series?
s["impact"] = s["coef"] * (np.abs(this_vec) * self.weights).sum()
s["avg_impact"] = s["impact"] / sum(self.weights)
s["total"] = (self.totals * dummy).sum()
s["seg_size"] = wgt
Expand Down
6 changes: 4 additions & 2 deletions wise_pizza/solve/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def fit_predict(self, X, y, sample_weight=None):
return self.predict(X)

def error(self, X, y, sample_weight=None):
# Error is chosen so that it's minimized by the weighted mean of y
err = y - self.predict(X)
errsq = err**2
if sample_weight is not None:
err *= sample_weight
return np.nansum(err**2)
errsq *= sample_weight
return np.nansum(errsq)


class AverageFitter(Fitter):
Expand Down
4 changes: 2 additions & 2 deletions wise_pizza/solve/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
dim_split: Optional[Dict[str, List]] = None,
depth: int = 0,
):
self.df = df
self.df = df.copy()
self.fitter = fitter
self.dims = dims
self._best_submodels = None
Expand All @@ -85,7 +85,7 @@ def error(self):
self.model = copy.deepcopy(self.fitter)
self.model.fit(
X=self.df[self.dims],
y=self.df["totals"],
y=self.df["__avg"],
sample_weight=self.df["weights"],
)
return self.model.error(
Expand Down

0 comments on commit e87f219

Please sign in to comment.