diff --git a/src/carte_estimator.py b/src/carte_estimator.py index a394f73..c3454f2 100644 --- a/src/carte_estimator.py +++ b/src/carte_estimator.py @@ -1265,9 +1265,8 @@ def _get_predict_prob(self, X): idx_ = np.where(np.array(self.source_list_) == source_name)[0] model_list = [self.model_list_[idx] for idx in idx_] out += [self._generate_output(X, model_list=model_list, weights=None)] - out = np.array(out).squeeze().transpose() - - out = np.average(out, weights=self.weights_, axis=1) + out = np.array(out).squeeze() + out = np.average(out, weights=self.weights_, axis=0) # Transform according to loss if self.loss == "binary_crossentropy":