diff --git a/examples/gnn_node.py b/examples/gnn_node.py index bba7cf14..8be164bf 100644 --- a/examples/gnn_node.py +++ b/examples/gnn_node.py @@ -155,7 +155,7 @@ def test(loader: NeighborLoader) -> np.ndarray: model.eval() pred_list = [] - for batch in loader: + for batch in tqdm(loader): batch = batch.to(device) pred = model( batch, @@ -195,8 +195,8 @@ def test(loader: NeighborLoader) -> np.ndarray: val_metrics = task.evaluate(val_pred, task.get_table("val")) print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}") - if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or ( - not higher_is_better and val_metrics[tune_metric] < best_val_metric + if (higher_is_better and val_metrics[tune_metric] >= best_val_metric) or ( + not higher_is_better and val_metrics[tune_metric] <= best_val_metric ): best_val_metric = val_metrics[tune_metric] state_dict = copy.deepcopy(model.state_dict())