Skip to content

Commit

Permalink
UPD: Refacto notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurNardone committed Jul 11, 2023
1 parent 59db4ff commit 2c2244f
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 993 deletions.
38 changes: 24 additions & 14 deletions mapie/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,8 @@ def _get_true_label_cumsum_proba(
y_pred_proba_sorted = np.take_along_axis(y_pred_proba, index_sorted, axis=1)
y_true_sorted = np.take_along_axis(y_true, index_sorted, axis=1)
if self.method == "ssaps":
j_values = np.arange(1, len(self.classes_) + 1)
penalized_residuals = 1 - np.power(residuals[:, np.newaxis], j_values)
y_pred_proba_penalized_sorted = np.multiply(
penalized_residuals, y_pred_proba_sorted
)
y_pred_proba_sorted_cumsum = np.cumsum(
y_pred_proba_penalized_sorted, axis=1
y_pred_proba_sorted_cumsum = self._get_penalized_ssaps(
y_pred_proba_sorted, residuals
)
else:
y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_sorted, axis=1)
Expand Down Expand Up @@ -684,6 +679,24 @@ def _get_true_label_position(self, y_pred_proba: NDArray, y: NDArray) -> NDArray

return position

def _get_penalized_ssaps(self, y_pred_proba_sorted, residuals, alpha_dim=False):
if alpha_dim:
n_alphas = y_pred_proba_sorted.shape[-1]
y_pred_proba_sorted = y_pred_proba_sorted[:, :, 0]
j_values = np.arange(1, len(self.classes_) + 1)
penalized_residuals = 1 - np.power(residuals[:, np.newaxis], j_values - 1)
penalized_residuals[:, 0] = 1.0
y_pred_proba_penalized_sorted = np.multiply(
penalized_residuals, y_pred_proba_sorted
)
y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_penalized_sorted, axis=1)
if alpha_dim:
y_pred_proba_sorted_cumsum = y_pred_proba_sorted_cumsum[:, :, np.newaxis]
y_pred_proba_sorted_cumsum = np.repeat(
y_pred_proba_sorted_cumsum, repeats=n_alphas, axis=2
)
return y_pred_proba_sorted_cumsum

def _get_last_included_proba(
self,
y_pred_proba: NDArray,
Expand Down Expand Up @@ -727,13 +740,10 @@ def _get_last_included_proba(
y_pred_proba_sorted = np.take_along_axis(y_pred_proba, index_sorted, axis=1)

if self.method == "ssaps":
j_values = np.arange(1, len(self.classes_) + 1)
penalized_residuals = 1 - np.power(residuals[:, np.newaxis], j_values)
y_pred_proba_penalized_sorted = np.multiply(
penalized_residuals[:, :, np.newaxis], y_pred_proba_sorted
)
y_pred_proba_sorted_cumsum = np.cumsum(
y_pred_proba_penalized_sorted, axis=1
y_pred_proba_sorted_cumsum = self._get_penalized_ssaps(
y_pred_proba_sorted,
residuals,
alpha_dim=True,
)
else:
# get sorted cumulated score
Expand Down
1,294 changes: 315 additions & 979 deletions notebooks/ImageNet-crf.ipynb

Large diffs are not rendered by default.

0 comments on commit 2c2244f

Please sign in to comment.