diff --git a/tools/utils/inference.py b/tools/utils/inference.py index 50f154be..831ba9e9 100644 --- a/tools/utils/inference.py +++ b/tools/utils/inference.py @@ -74,7 +74,7 @@ def __init__(self, model: List or AnyStr or Tuple): ) inter = tf.lite.Interpreter net = inter(model) - self._input_shape = list(net.get_input_details()[0]['shape'][1:]) + self._input_shape = tuple(net.get_input_details()[0]['shape'][1:]) net.allocate_tensors() self.engine = 'tf' else: @@ -280,7 +280,8 @@ def __init__( def init(self, cfg): self.evaluator: Evaluator = self.runner.build_evaluator(self.cfg.get('val_evaluator')) - self.evaluator.dataset_meta = self.dataloader.dataset.METAINFO + if hasattr(self.dataloader, 'dataset'): + self.evaluator.dataset_meta = self.dataloader.dataset.METAINFO if hasattr(cfg.model, 'data_preprocessor'): self.data_preprocess = MODELS.build(cfg.model.data_preprocessor)