From 9033a39bed8a702849f10dd08f02cb95e507085b Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Tue, 6 Feb 2024 16:12:06 +0100 Subject: [PATCH] Flake8 --- myria3d/callbacks/comet_callbacks.py | 2 +- myria3d/models/model.py | 6 ++---- myria3d/predict.py | 2 +- tests/runif.py | 3 +-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/myria3d/callbacks/comet_callbacks.py b/myria3d/callbacks/comet_callbacks.py index 4446b83b..84a82d0c 100755 --- a/myria3d/callbacks/comet_callbacks.py +++ b/myria3d/callbacks/comet_callbacks.py @@ -70,4 +70,4 @@ def setup(self, trainer, pl_module, stage): if logger: log_path = os.getcwd() log.info(f"----------------\n LOGS DIR is {log_path}\n ----------------") - logger.experiment.log_parameter("experiment_logs_dirpath", log_path) \ No newline at end of file + logger.experiment.log_parameter("experiment_logs_dirpath", log_path) diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 8e49b319..67c2752d 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -1,15 +1,13 @@ -from typing import Optional - import torch from pytorch_lightning import LightningModule from torch import nn from torch_geometric.data import Batch from torch_geometric.nn import knn_interpolate -from myria3d.metrics.iou import iou +from torchmetrics.classification import MulticlassJaccardIndex +from myria3d.metrics.iou import iou from myria3d.models.modules.pyg_randla_net import PyGRandLANet from myria3d.utils import utils -from torchmetrics.classification import MulticlassJaccardIndex log = utils.get_logger(__name__) diff --git a/myria3d/predict.py b/myria3d/predict.py index 50c03f62..7c50219e 100644 --- a/myria3d/predict.py +++ b/myria3d/predict.py @@ -5,7 +5,7 @@ import hydra import torch from omegaconf import DictConfig -from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning import LightningDataModule from tqdm import tqdm from myria3d.models.model import Model diff --git a/tests/runif.py b/tests/runif.py index ec5da504..7a2ac5f6 100644 --- a/tests/runif.py +++ b/tests/runif.py @@ -1,5 +1,4 @@ import pytest -import torch from lightning.pytorch.accelerators import find_usable_cuda_devices """ @@ -39,7 +38,7 @@ def __new__( try: find_usable_cuda_devices(min_gpus) conditions.append(False) - except (ValueError, RuntimeError) as _: + except (ValueError, RuntimeError): conditions.append(True) reasons.append(f"GPUs>={min_gpus}")