diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py
index 47494145e78..63c03191802 100644
--- a/nncf/common/tensor_statistics/aggregator.py
+++ b/nncf/common/tensor_statistics/aggregator.py
@@ -23,6 +23,7 @@
+from nncf.common.logging import nncf_logger
 from nncf.data.dataset import DataItem
 from nncf.data.dataset import Dataset
 from nncf.data.dataset import ModelInput
 
 TensorType = TypeVar("TensorType")
 TModel = TypeVar("TModel")
@@ -70,7 +71,9 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
         engine = factory.EngineFactory.create(model_with_outputs)
         iterations_number = self._get_iterations_number()
-        empty_statistics = True
+
+        processed_samples = 0
+
         for input_data in track(  # type: ignore
             islice(self.dataset.get_inference_data(), iterations_number),
             total=iterations_number,
             description="Statistics collection",
@@ -79,9 +82,16 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
             outputs = engine.infer(input_data)
             processed_outputs = self._process_outputs(outputs)
             self._register_statistics(processed_outputs, merged_statistics)
-            empty_statistics = False
-        if empty_statistics:
+            processed_samples += 1
+
+        if processed_samples == 0:
             raise nncf.ValidationError(EMPTY_DATASET_ERROR)
+
+        if iterations_number is not None and processed_samples < iterations_number:
+            nncf_logger.warning(
+                f"Dataset contains only {processed_samples} samples, "
+                f"fewer than the requested subset size of {iterations_number}."
+            )
 
     def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
         """
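
For reviewers, a minimal, self-contained sketch of the new control flow (plain Python, not NNCF code; `collect`, `samples`, and `requested_subset_size` are hypothetical stand-ins, and `ValueError`/`print` stand in for `nncf.ValidationError` and `nncf_logger.warning`):

```python
from itertools import islice
from typing import Iterable, Optional


def collect(samples: Iterable, requested_subset_size: Optional[int]) -> None:
    # Count the samples actually drawn; islice stops early if the
    # iterable runs out before the requested subset size is reached.
    processed_samples = 0
    for _ in islice(samples, requested_subset_size):
        processed_samples += 1  # one statistics-registration step per sample

    if processed_samples == 0:
        raise ValueError("empty dataset")  # stands in for nncf.ValidationError

    # requested_subset_size may be None ("use the whole dataset"),
    # in which case no shortfall warning applies.
    if requested_subset_size is not None and processed_samples < requested_subset_size:
        print(
            f"warning: dataset yielded only {processed_samples} of the "
            f"requested {requested_subset_size} samples"
        )


collect(iter(range(3)), 300)  # warns: 3 < 300
```

This mirrors the patch: the boolean `empty_statistics` flag is replaced by a counter, which both preserves the empty-dataset error and enables the new warning when the dataset is smaller than the requested subset size.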