From e7078ca96e4aa98ea9af8a782fdc1c378ad802d4 Mon Sep 17 00:00:00 2001 From: sofiane Date: Tue, 11 Jul 2023 09:38:54 +0200 Subject: [PATCH] ENH: Change in example to make it work --- .../plot_tutorial_multilabel_classification.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/multilabel_classification/1-quickstart/plot_tutorial_multilabel_classification.py b/examples/multilabel_classification/1-quickstart/plot_tutorial_multilabel_classification.py index ff1cd417a..f7f0e3f29 100644 --- a/examples/multilabel_classification/1-quickstart/plot_tutorial_multilabel_classification.py +++ b/examples/multilabel_classification/1-quickstart/plot_tutorial_multilabel_classification.py @@ -118,17 +118,19 @@ } clf = MultiOutputClassifier(GaussianNB()).fit(X_train, y_train) -mapie = MapieMultiLabelClassifier(estimator=clf) -mapie.fit(X_cal, y_cal) alpha = np.arange(0.01, 1, 0.01) y_pss, recalls, thresholds, r_hats, r_hat_pluss = {}, {}, {}, {}, {} y_test_repeat = np.repeat(y_test[:, :, np.newaxis], len(alpha), 2) for i, (name, (method, bound)) in enumerate(method_params.items()): + mapie = MapieMultiLabelClassifier( + estimator=clf, method=method, metric_control="recall" + ) + mapie.fit(X_cal, y_cal) + _, y_pss[name] = mapie.predict( - X_test, method=method, - alpha=alpha, bound=bound, delta=.1 + X_test, alpha=alpha, bound=bound, delta=.1 ) recalls[name] = ( (y_test_repeat * y_pss[name]).sum(axis=1) / @@ -138,6 +140,7 @@ r_hats[name] = mapie.r_hat r_hat_pluss[name] = mapie.r_hat_plus + ############################################################################## # 3. Results # ----------