diff --git a/luxonis_ml/tracker/tracker.py b/luxonis_ml/tracker/tracker.py index bda59d2f..b2c898b6 100644 --- a/luxonis_ml/tracker/tracker.py +++ b/luxonis_ml/tracker/tracker.py @@ -1,6 +1,8 @@ import glob +import logging import os from functools import wraps +from importlib.util import find_spec from pathlib import Path from typing import Any, Callable, Dict, Optional, Union @@ -9,6 +11,8 @@ from luxonis_ml.utils.filesystem import LuxonisFileSystem, PathType +logger = logging.getLogger(__name__) + class LuxonisTracker: def __init__( @@ -96,10 +100,11 @@ def __init__( raise Exception("Must specify wandb_entity when using wandb!") else: self.wandb_entity = wandb_entity - if self.is_mlflow and mlflow_tracking_uri is None: - raise Exception("Must specify mlflow_tracking_uri when using mlflow!") - else: - self.mlflow_tracking_uri = mlflow_tracking_uri + if self.is_mlflow: + if mlflow_tracking_uri is None: + raise Exception("Must specify mlflow_tracking_uri when using mlflow!") + else: + self.mlflow_tracking_uri = mlflow_tracking_uri if not (self.is_tensorboard or self.is_wandb or self.is_mlflow): raise Exception("At least one integration must be used!") @@ -157,7 +162,7 @@ def experiment(self) -> Dict[str, Any]: self._experiment = {} if self.is_tensorboard: - from torch.utils.tensorboard import SummaryWriter + from torch.utils.tensorboard.writer import SummaryWriter log_dir = f"{self.save_directory}/tensorboard_logs/{self.run_name}" self._experiment["tensorboard"] = SummaryWriter(log_dir=log_dir) @@ -182,6 +187,19 @@ def experiment(self) -> Dict[str, Any]: if self.is_mlflow: import mlflow + if find_spec("psutil") is not None: + mlflow.enable_system_metrics_logging() + if find_spec("pynvml") is None: + logger.warning( + "pynvml not found, GPU stats will not be monitored. " + "To enable GPU monitoring, install it using 'pip install pynvml'" + ) + else: + logger.warning( + "`psutil` not found. To enable system metric logging, " + "install it using 'pip install psutil'" + ) + self._experiment["mlflow"] = mlflow self.artifacts_dir = f"{self.save_directory}/{self.run_name}/artifacts"