From 0c9cb7bef4b09b8dc686dbc568485bd3e0274e24 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Wed, 22 May 2024 16:00:04 +0200 Subject: [PATCH] fix test --- MARBLE/geometry.py | 4 ++-- install.sh | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/MARBLE/geometry.py b/MARBLE/geometry.py index a01e2def..dfbc3352 100644 --- a/MARBLE/geometry.py +++ b/MARBLE/geometry.py @@ -365,11 +365,11 @@ def fit_graph(x, graph_type="cknn", par=1, delta=1.0, metric="euclidean"): edge_index = utils.np2torch(edge_index, dtype="double") elif graph_type == "knn": - edge_index = knn_graph(x, k=par, metric=metric) + edge_index = knn_graph(x, k=par) edge_index = PyGu.add_self_loops(edge_index)[0] elif graph_type == "radius": - edge_index = radius_graph(x, r=par, metric=metric) + edge_index = radius_graph(x, r=par) edge_index = PyGu.add_self_loops(edge_index)[0] else: diff --git a/install.sh b/install.sh index df05cbe8..6bda6d0a 100755 --- a/install.sh +++ b/install.sh @@ -4,7 +4,6 @@ echo "Checking the PyTorch version" SED=$(which gsed || which sed) TORCH_VERSION=$(pip freeze | grep torch== | $SED -re "s/torch==([^+]+).*/\1/") -echo $TORCH_VERSION if [ -z "$TORCH_VERSION" ] then