diff --git a/tests/metrics/test_compute_metrics.py b/tests/metrics/test_compute_metrics.py index 2c7c346..ed92616 100644 --- a/tests/metrics/test_compute_metrics.py +++ b/tests/metrics/test_compute_metrics.py @@ -146,6 +146,14 @@ def test_arguments(self): evaluate_range=('a', 'b', 'c'), device=self.device) + with pytest.raises(ValueError): + compute_metrics.evaluate(metric=metric, + log_dir=self.log_dir, + netG=self.netG, + dataset=self.dataset, + evaluate_range=(100, 100, 100, 100), + device=self.device) + with pytest.raises(ValueError): compute_metrics.evaluate(metric=metric, log_dir=self.log_dir, diff --git a/torch_mimicry/metrics/compute_metrics.py b/torch_mimicry/metrics/compute_metrics.py index 4a0b6c7..2e9682b 100644 --- a/torch_mimicry/metrics/compute_metrics.py +++ b/torch_mimicry/metrics/compute_metrics.py @@ -48,7 +48,7 @@ def evaluate(metric, if evaluate_range: if (type(evaluate_range) != tuple or not all(map(lambda x: type(x) == int, evaluate_range)) - or not len(x) == 3): + or not len(evaluate_range) == 3): raise ValueError( "evaluate_range must be a tuple of ints (start, end, step).")