Skip to content

Commit

Permalink
fixed cross_entropy carte_multitable
Browse files Browse the repository at this point in the history
  • Loading branch information
myungkim930 committed Sep 2, 2024
1 parent d65d9b6 commit 74c0b26
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/carte_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,9 +1265,8 @@ def _get_predict_prob(self, X):
idx_ = np.where(np.array(self.source_list_) == source_name)[0]
model_list = [self.model_list_[idx] for idx in idx_]
out += [self._generate_output(X, model_list=model_list, weights=None)]
out = np.array(out).squeeze().transpose()

out = np.average(out, weights=self.weights_, axis=1)
out = np.array(out).squeeze()
out = np.average(out, weights=self.weights_, axis=0)

# Transform according to loss
if self.loss == "binary_crossentropy":
Expand Down

0 comments on commit 74c0b26

Please sign in to comment.