diff --git a/docs/source/tutorials/configure_data.rst b/docs/source/tutorials/configure_data.rst
index e12c94d43..40b61b835 100644
--- a/docs/source/tutorials/configure_data.rst
+++ b/docs/source/tutorials/configure_data.rst
@@ -211,7 +211,7 @@ If no data config template can meet the requirement, we can also define the `dat
    * - `post_process_data `_
      - post_process(default), text_classification_post_process, ner_post_process, text_generation_post_process
    * - `dataloader `_
-     - default_dataloader(default), skip_dataloader, no_auto_batch_dataloader
+     - default_dataloader(default), no_auto_batch_dataloader
 
 each component can be customized by the following fields:
 - ``name``: the name of the component.
diff --git a/olive/data/component/dataloader.py b/olive/data/component/dataloader.py
index 4d2582695..5c3d19677 100644
--- a/olive/data/component/dataloader.py
+++ b/olive/data/component/dataloader.py
@@ -8,11 +8,6 @@
 from olive.data.registry import Registry
 
 
-@Registry.register_dataloader()
-def skip_dataloader(dataset):
-    return dataset
-
-
 @Registry.register_default_dataloader()
 def default_dataloader(dataset, batch_size=1, **kwargs):
     return DataLoader(dataset, batch_size=batch_size, **kwargs)
diff --git a/olive/data/component/dataset.py b/olive/data/component/dataset.py
index 5e2701942..fa35e4e4f 100644
--- a/olive/data/component/dataset.py
+++ b/olive/data/component/dataset.py
@@ -94,41 +94,48 @@ def to_hf_dataset(self, label_name="label"):
 
 
 class DummyDataset(BaseDataset):
-    def __init__(self, input_shapes, input_names: Optional[List] = None, input_types: Optional[List] = None):
-        """Initialize the dummy dataset.
+    def __init__(
+        self,
+        input_shapes,
+        input_names: Optional[List] = None,
+        input_types: Optional[List] = None,
+        max_samples: Optional[int] = 32,
+    ):
+        """Initialize the dataset with dummy data.
 
-        if input_names is None, the dummy dataset will return a tuple of tensors
-        else the dummy dataset will return a dict of tensors
+        If input_names is not provided, the dataset returns a tuple of tensors;
+        otherwise it returns a dict of tensors keyed by input name.
         """
         # pylint: disable=super-init-not-called
-        self.input_shapes = input_shapes
-        self.input_names = input_names
-        self.input_types = input_types or ["float32"] * len(input_shapes)
+        if not input_types:
+            input_types = ["float32"] * len(input_shapes)
+        input_types = [resolve_torch_dtype(dtype_str) for dtype_str in input_types]
+
+        if input_names:
+            dummy_data = {}
+            for input_name, input_shape, input_type in zip(input_names, input_shapes, input_types):
+                dummy_data.update({input_name: torch.ones(input_shape, dtype=input_type)})
+            dummy_data = dummy_data if len(dummy_data) > 1 else dummy_data[input_names[0]]
+        else:
+            dummy_data = []
+            for shape, dtype in zip(input_shapes, input_types):
+                dummy_data.append(torch.ones(shape, dtype=dtype))
+            dummy_data = tuple(dummy_data) if len(dummy_data) > 1 else dummy_data[0]
+
+        self.max_samples = max_samples
+        self.dummy_data = dummy_data, torch.tensor([0])
 
     def __len__(self):
-        return 256
+        return self.max_samples
 
     def __getitem__(self, index):
         # From https://docs.python.org/3/reference/datamodel.html#object.__getitem__,
         # __getitem__ should raise IndexError when index is out of range
         # Otherwise, the enumerate function will enter infinite loop
-        if index < 0 or index >= len(self):
+        if index < 0 or index >= self.max_samples:
             raise IndexError("Index out of range")
-        input_types = [resolve_torch_dtype(dtype_str) for dtype_str in self.input_types]
-
-        if not self.input_names:
-            dummy_inputs = []
-            for shape, dtype in zip(self.input_shapes, input_types):
-                dummy_inputs.append(torch.ones(shape, dtype=dtype))
-            dummy_inputs = tuple(dummy_inputs) if len(dummy_inputs) > 1 else dummy_inputs[0]
-        else:
-            dummy_inputs = {}
-            for input_name, input_shape, input_type in zip(self.input_names, self.input_shapes, input_types):
-                dummy_inputs.update({input_name: torch.ones(input_shape, dtype=input_type)})
-            dummy_inputs = dummy_inputs if len(dummy_inputs) > 1 else dummy_inputs[self.input_names[0]]
-        label = 0
-        return dummy_inputs, label
+        return self.dummy_data
 
 
 class RawDataset(BaseDataset):
diff --git a/olive/data/component/load_dataset.py b/olive/data/component/load_dataset.py
index 136f53c0e..7aed5274a 100644
--- a/olive/data/component/load_dataset.py
+++ b/olive/data/component/load_dataset.py
@@ -35,8 +35,8 @@ def huggingface_dataset(data_dir, data_name=None, subset=None, split="validation
 
 
 @Registry.register_dataset()
-def dummy_dataset(data_dir, input_shapes, input_names=None, input_types=None):
-    return DummyDataset(input_shapes, input_names, input_types)
+def dummy_dataset(data_dir, input_shapes, input_names=None, input_types=None, max_samples=32):
+    return DummyDataset(input_shapes, input_names, input_types, max_samples)
 
 
 @Registry.register_dataset()
diff --git a/olive/data/container/dummy_data_container.py b/olive/data/container/dummy_data_container.py
index 9a5f0acb6..429e28c5b 100644
--- a/olive/data/container/dummy_data_container.py
+++ b/olive/data/container/dummy_data_container.py
@@ -18,17 +18,17 @@ class DummyDataContainer(DataContainer):
         dummy_data_config = DataConfig(
             name="dummy",
             type="DummyDataContainer",
-            "load_dataset_config"={
-                "params": {
+            load_dataset_config=DataComponentConfig(
+                params={
                     "input_names": metric.user_config.input_names,
                     "input_shapes": metric.user_config.input_shapes,
"input_types": metric.user_config.input_types, } - } + ), ) """ default_components_type: ClassVar[dict] = { DataComponentType.LOAD_DATASET.value: "dummy_dataset", - DataComponentType.DATALOADER.value: "skip_dataloader", + DataComponentType.DATALOADER.value: "no_auto_batch_dataloader", } diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 9197a951b..2f023a9df 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -271,7 +271,7 @@ def get_user_config(cls, framework: Framework, data_root: str, metric: Metric): dataloader = user_module.call_object( dataloader_func, data_dir, - metric.user_config.batch_size, + batch_size=metric.user_config.batch_size, model_framework=framework, **cls._get_func_kwargs(metric, "dataloader_func"), ) diff --git a/olive/passes/pytorch/qat_utils.py b/olive/passes/pytorch/qat_utils.py index 61ec78c36..f0b70f80c 100644 --- a/olive/passes/pytorch/qat_utils.py +++ b/olive/passes/pytorch/qat_utils.py @@ -71,7 +71,7 @@ def execute_local(self) -> PyTorchModelHandler: ptl_data_module = self.config.ptl_data_module() else: train_dataloader_func = self.config.train_dataloader_func( - self.config.train_data_dir, self.config.train_batch_size + self.config.train_data_dir, batch_size=self.config.train_batch_size ) ptl_module = DefaultPTLModule(model=quan_model, training_dataloader=train_dataloader_func) diff --git a/test/integ_test/utils.py b/test/integ_test/utils.py index 0e9910bc3..29655c881 100644 --- a/test/integ_test/utils.py +++ b/test/integ_test/utils.py @@ -24,12 +24,16 @@ def get_olive_workspace_config(): if client_id is None: raise Exception("Please set the environment variable MANAGED_IDENTITY_CLIENT_ID") + exclude_managed_identity_credential = ( + {"exclude_managed_identity_credential": True} if "EXCLUDE_MANAGED_IDENTITY_CREDENTIAL" in os.environ else {} + ) + return { "subscription_id": subscription_id, "resource_group": resource_group, "workspace_name": workspace_name, # pipeline agents have multiple managed identities, so we need to specify the client_id - "default_auth_params": {"managed_identity_client_id": client_id}, + "default_auth_params": {"managed_identity_client_id": client_id, **exclude_managed_identity_credential}, } diff --git a/test/unit_test/data_container/test_dataset.py b/test/unit_test/data_container/test_dataset.py index c2b47ae66..e755a4bc0 100644 --- a/test/unit_test/data_container/test_dataset.py +++ b/test/unit_test/data_container/test_dataset.py @@ -18,10 +18,10 @@ def get_dict_dataset(length=256, max_samples=None): return BaseDataset(data, ["original_label"], max_samples) -def get_dummy_dataset(): +def get_dummy_dataset(length=256): input_shapes = [[2], [3]] input_names = ["input_1", "input_2"] - return DummyDataset(input_shapes, input_names) + return DummyDataset(input_shapes, input_names, max_samples=length) def get_hf_dataset(): @@ -70,4 +70,4 @@ def test_dataset_to_hf_dataset(self, dataset_func, label_name): # assert shape of the first sample assert hf_dataset["input_1"][0].shape == (2,) assert hf_dataset["input_2"][0].shape == (3,) - assert hf_dataset[label_name][0].shape == () + assert hf_dataset[label_name][0].shape == ((1,) if isinstance(dataset, DummyDataset) else ()) diff --git a/test/unit_test/data_container/test_template.py b/test/unit_test/data_container/test_template.py index 726eacec1..c535792ee 100644 --- a/test/unit_test/data_container/test_template.py +++ b/test/unit_test/data_container/test_template.py @@ -52,10 +52,10 @@ def 
test_dummy_template(self, input_names): input_types=["int64", "int64", "int64"], ) dummy_inputs, _ = dataloader.to_data_container().get_first_batch() - if not input_names: - assert isinstance(dummy_inputs, tuple), "Failed to create dummy tuple input from dummy template." - else: + if input_names: assert isinstance(dummy_inputs, dict), "Failed to create dummy dict dataset from dummy template." + else: + assert isinstance(dummy_inputs, list), "Failed to create dummy list input from dummy template." def test_raw_data_template(self, tmpdir): input_names = ["float_input", "int_input"] diff --git a/test/unit_test/evaluator/test_olive_evaluator.py b/test/unit_test/evaluator/test_olive_evaluator.py index 4d2b61f76..6ec7a604b 100644 --- a/test/unit_test/evaluator/test_olive_evaluator.py +++ b/test/unit_test/evaluator/test_olive_evaluator.py @@ -289,9 +289,8 @@ def test_dataloader_func_kwargs(self, dataloader_func_kwargs): # setup dataloader_func = MagicMock(spec=FunctionType) data_dir = None - batch_size = 1 model_framework = "PyTorch" - user_config = {"dataloader_func": dataloader_func, "batch_size": batch_size, "data_dir": data_dir} + user_config = {"dataloader_func": dataloader_func, "data_dir": data_dir} if dataloader_func_kwargs: user_config["func_kwargs"] = {"dataloader_func": dataloader_func_kwargs} metric = get_latency_metric(LatencySubType.AVG, user_config=user_config) @@ -301,7 +300,7 @@ def test_dataloader_func_kwargs(self, dataloader_func_kwargs): # assert dataloader_func.assert_called_once_with( - data_dir, batch_size, model_framework=model_framework, **(dataloader_func_kwargs or {}) + data_dir, batch_size=1, model_framework=model_framework, **(dataloader_func_kwargs or {}) ) # this is enough to test the kwargs for `evaluate_func`, `metric_func` and `post_process_func` diff --git a/test/unit_test/passes/onnx/test_perf_tuning.py b/test/unit_test/passes/onnx/test_perf_tuning.py index 3ff6cef8a..0a6b7b94f 100644 --- a/test/unit_test/passes/onnx/test_perf_tuning.py +++ b/test/unit_test/passes/onnx/test_perf_tuning.py @@ -5,7 +5,7 @@ import logging import math import re -from test.unit_test.utils import create_dataloader, get_onnx_model +from test.unit_test.utils import create_dummy_dataloader, get_onnx_model from unittest.mock import MagicMock, PropertyMock, patch import psutil @@ -21,10 +21,10 @@ @pytest.mark.parametrize( "config", [ - {"input_names": ["input"], "input_shapes": [[1, 1]]}, + {"input_names": ["input"], "input_shapes": [(1, 1)]}, {}, - {"dataloader_func": create_dataloader}, - {"dataloader_func": create_dataloader, "dataloader_func_kwargs": {"dummy_kwarg": 1}}, + {"dataloader_func": create_dummy_dataloader}, + {"dataloader_func": create_dummy_dataloader, "dataloader_func_kwargs": {"dummy_kwarg": 1}}, ], ) def test_ort_perf_tuning_pass(config, tmp_path): @@ -42,7 +42,7 @@ def test_ort_perf_tuning_pass(config, tmp_path): "config", [ {}, - {"input_names": ["input"], "input_shapes": [[1, 1]]}, + {"input_names": ["input"], "input_shapes": [(1, 1)]}, {"providers_list": ["CPUExecutionProvider", "CUDAExecutionProvider"], "device": "gpu"}, ], ) diff --git a/test/unit_test/passes/pytorch/test_quantization_aware_training.py b/test/unit_test/passes/pytorch/test_quantization_aware_training.py index d42cb981c..fdf1b9128 100644 --- a/test/unit_test/passes/pytorch/test_quantization_aware_training.py +++ b/test/unit_test/passes/pytorch/test_quantization_aware_training.py @@ -2,18 +2,24 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from test.unit_test.utils import create_dataloader, get_pytorch_model +from test.unit_test.utils import create_dummy_dataloader, get_pytorch_model from olive.hardware.accelerator import AcceleratorSpec from olive.passes.olive_pass import FullPassConfig, create_pass_from_dict from olive.passes.pytorch.quantization_aware_training import QuantizationAwareTraining +# TODO(shaahji): Remove this once QuantizationAwareTraining pass supports DataConfig +def _create_dummy_dataloader(data_dir, **kwargs): + kwargs.pop("batch_size", None) + return create_dummy_dataloader(data_dir, max_samples=1, batch_size=1, **kwargs) + + def test_quantization_aware_training_pass_default(tmp_path): # setup input_model = get_pytorch_model() config = { - "train_dataloader_func": create_dataloader, + "train_dataloader_func": _create_dummy_dataloader, "checkpoint_path": str(tmp_path / "checkpoint"), } diff --git a/test/unit_test/systems/isolated_ort/test_isolated_ort_system.py b/test/unit_test/systems/isolated_ort/test_isolated_ort_system.py index 9d48e4adc..30f892398 100644 --- a/test/unit_test/systems/isolated_ort/test_isolated_ort_system.py +++ b/test/unit_test/systems/isolated_ort/test_isolated_ort_system.py @@ -116,7 +116,7 @@ def setup(self, tmp_path): python_path = shutil.which("python", path=python_environment_path) # install only onnxruntime - run_subprocess([python_path, "-m", "pip", "install", "onnxruntime"], env=self.system.environ) + run_subprocess([python_path, "-m", "pip", "install", "onnxruntime", "numpy<2"], env=self.system.environ) self.evaluator = IsolatedORTEvaluator(self.system.environ) self.onnx_evaluator = OnnxEvaluator() @@ -125,7 +125,7 @@ def setup(self, tmp_path): def test__inference(self): model = get_onnx_model_config().create_model() - metric = get_accuracy_metric(AccuracySubType.ACCURACY_SCORE, random_dataloader=False) + metric = get_accuracy_metric(AccuracySubType.ACCURACY_SCORE) metric = OliveEvaluator.generate_metric_user_config_with_model_io(metric, model) dataloader, _, post_func = OliveEvaluator.get_user_config(model.framework, None, metric) diff --git a/test/unit_test/test_data_root.py b/test/unit_test/test_data_root.py index 4071a3452..c861a2beb 100644 --- a/test/unit_test/test_data_root.py +++ b/test/unit_test/test_data_root.py @@ -3,13 +3,12 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from pathlib import Path -from test.unit_test.utils import get_pytorch_model_dummy_input, pytorch_model_loader +from test.unit_test.utils import create_dummy_dataloader, get_pytorch_model_dummy_input, pytorch_model_loader from unittest.mock import MagicMock, patch import pytest -import torch -from torch.utils.data import DataLoader, Dataset +from olive.data.component.load_dataset import dummy_dataset from olive.data.config import DataComponentConfig, DataConfig from olive.data.registry import Registry from olive.resource_path import create_resource_path @@ -18,31 +17,11 @@ # pylint: disable=redefined-outer-name -class DummyDataset(Dataset): - def __init__(self, size): - self.size = size - - def __getitem__(self, idx): - return torch.randn(1), torch.rand(10).argmax() - - def __len__(self): - return self.size - - -@Registry.register_dataset() -def dummy_dataset_dataroot(data_dir): - return DummyDataset(1) - - @Registry.register_post_process() def post_processing_func(output): return output.argmax(axis=1) -def create_dataloader(data_dir, batch_size, *args, **kwargs): - return DataLoader(DummyDataset(1)) - - def get_dataloader_config(): return { "input_model": { @@ -53,6 +32,16 @@ def get_dataloader_config(): "io_config": {"input_names": ["input"], "output_names": ["output"], "input_shapes": [(1, 1)]}, }, }, + "data_configs": [ + { + "name": "test_data_config", + "type": "DummyDataContainer", + "load_dataset_config": { + "type": "dummy_dataset", + "params": {"data_dir": "data", "input_shapes": [(1, 1)], "max_samples": 1}, + }, + } + ], "evaluators": { "common_evaluator": { "metrics": [ @@ -66,10 +55,8 @@ def get_dataloader_config(): "metric_config": {"num_classes": 10, "task": "multiclass"}, } ], + "data_config": "test_data_config", "user_config": { - "data_dir": "data", - "dataloader_func": create_dataloader, - "batch_size": 16, "post_processing_func": post_processing_func, }, } @@ -81,7 +68,7 @@ def get_dataloader_config(): "perf_tuning": { "type": "OrtPerfTuning", "config": { - "dataloader_func": create_dataloader, + "dataloader_func": create_dummy_dataloader, "batch_size": 16, "data_dir": "data", }, @@ -108,14 +95,15 @@ def get_data_config(): }, }, "data_configs": [ - DataConfig( - name="test_data_config", - load_dataset_config=DataComponentConfig( - type="dummy_dataset_dataroot", - params={"data_dir": "data"}, - ), - post_process_data_config=DataComponentConfig(type="post_processing_func"), - ) + { + "name": "test_data_config", + "type": "DummyDataContainer", + "load_dataset_config": { + "type": "dummy_dataset", + "params": {"data_dir": "data", "input_shapes": [(1, 1)], "max_samples": 1}, + }, + "post_process_data_config": {"type": "post_processing_func"}, + } ], "evaluators": { "common_evaluator": { @@ -144,10 +132,11 @@ def get_data_config(): # "data_config": "test_data_config" # This is just demo purpose to show how to use data_config in passes "data_config": DataConfig( - name="test_data_config", + name="test_data_config_inlined", + type="DummyDataContainer", load_dataset_config=DataComponentConfig( - type="dummy_dataset_dataroot", - params={"data_dir": "perfdata"}, + type="dummy_dataset", + params={"data_dir": "perfdata", "input_shapes": [(1, 1)], "max_samples": 1}, ), post_process_data_config=DataComponentConfig(type="post_processing_func"), ) @@ -227,11 +216,18 @@ def test_data_root_for_dataset(mock_get_local_path, data_config): config_obj = data_config data_root = config_obj.get("data_root") - mock = 
MagicMock(side_effect=dummy_dataset_dataroot) - Registry.register_dataset("dummy_dataset_dataroot")(mock) + mock = MagicMock(side_effect=dummy_dataset) + Registry.register_dataset("dummy_dataset")(mock) best = olive_run(config_obj) - mock.assert_called_with(data_dir=concat_data_dir(data_root, "data")) - - data_dir_expected = concat_data_dir(data_root, "perfdata") - mock.assert_any_call(data_dir=data_dir_expected) + mock.assert_called_with(data_dir=concat_data_dir(data_root, "data"), input_shapes=[(1, 1)], max_samples=1) + if data_root is None: + mock.assert_any_call( + data_dir=concat_data_dir(data_root, "perfdata"), + input_shapes=[(1, 1)], + input_names=None, + input_types=None, + max_samples=1, + ) + else: + mock.assert_any_call(data_dir=concat_data_dir(data_root, "perfdata"), input_shapes=[(1, 1)], max_samples=1) assert best is not None diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index 608792b56..ae950537e 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -9,9 +9,11 @@ import numpy as np import torch import torch.nn as nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader +from olive.common.config_utils import validate_config from olive.constants import Framework +from olive.data.component.dataset import DummyDataset from olive.data.config import DataComponentConfig, DataConfig from olive.data.registry import Registry from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType @@ -23,59 +25,39 @@ class DummyModel(nn.Module): - def __init__(self): + def __init__(self, batch_size=1): super().__init__() - self.fc1 = nn.Linear(1, 10) + self.fc1 = nn.Linear(batch_size, 10) def forward(self, x): - return torch.relu(self.fc1(x)) + return torch.sigmoid(self.fc1(x)) -class DummyDataset(Dataset): - def __init__(self, size): - self.size = size - - def __getitem__(self, idx): - return torch.randn(1), torch.rand(10) - - def __len__(self): - return self.size - - -class FixedDummyDataset(Dataset): - def __init__(self, size): - self.size = size - self.rng = np.random.default_rng(0) - self.data = torch.tensor(self.rng.random((size, 1))) - self.labels = torch.tensor(self.rng.random(1)) - - def __getitem__(self, idx): - return self.data[idx], self.labels[idx] - - def __len__(self): - return self.size +# TODO(shaahji): Remove this once perf_tuning pass supports DataConfig +def create_dummy_dataloader(data_dir, batch_size=1, max_samples=32, **kwargs): + return DataLoader(DummyDataset([(batch_size or 1, 1)], max_samples=max_samples), batch_size=None) def pytorch_model_loader(model_path): return DummyModel().eval() -def get_pytorch_model_config(): +def get_pytorch_model_config(batch_size=1): config = { "type": "PyTorchModel", "config": { "model_loader": pytorch_model_loader, - "io_config": {"input_names": ["input"], "output_names": ["output"], "input_shapes": [(1, 1)]}, + "io_config": {"input_names": ["input"], "output_names": ["output"], "input_shapes": [(batch_size, 1)]}, }, } return ModelConfig.parse_obj(config) -def get_pytorch_model(): +def get_pytorch_model(batch_size=1): return PyTorchModelHandler( model_loader=pytorch_model_loader, model_path=None, - io_config={"input_names": ["input"], "output_names": ["output"], "input_shapes": [(1, 1)]}, + io_config={"input_names": ["input"], "output_names": ["output"], "input_shapes": [(batch_size, 1)]}, ) @@ -98,8 +80,8 @@ def get_hf_model_with_past(): ) -def get_pytorch_model_dummy_input(model=None): - return torch.randn(1, 1) +def 
get_pytorch_model_dummy_input(model=None, batch_size=1): + return torch.randn(batch_size, 1) def create_onnx_model_file(): @@ -110,9 +92,9 @@ def create_onnx_model_file(): ) -def create_onnx_model_with_dynamic_axis(onnx_model_path): +def create_onnx_model_with_dynamic_axis(onnx_model_path, batch_size=1): pytorch_model = pytorch_model_loader(model_path=None) - dummy_input = get_pytorch_model_dummy_input(pytorch_model) + dummy_input = get_pytorch_model_dummy_input(pytorch_model, batch_size) torch.onnx.export( pytorch_model, dummy_input, @@ -162,23 +144,28 @@ def get_mock_openvino_model(): return olive_model -def create_dataloader(data_dir, batch_size, *args, **kwargs): - return DataLoader(DummyDataset(1)) - - -def create_fixed_dataloader(data_dir, batch_size, *args, **kwargs): - return DataLoader(FixedDummyDataset(1)) +def _get_dummy_data_config(name, input_shapes, max_samples=1): + data_config = DataConfig( + name=name, + type="DummyDataContainer", + load_dataset_config=DataComponentConfig( + params={ + "input_shapes": input_shapes, + "max_samples": max_samples, + } + ), + post_process_data_config=DataComponentConfig(type="text_classification_post_process"), + ) + return validate_config(data_config, DataConfig) def get_accuracy_metric( *acc_subtype, - random_dataloader=True, user_config=None, backend="torch_metrics", goal_type="threshold", goal_value=0.99, ): - accuracy_metric_config = {"dataloader_func": create_dataloader if random_dataloader else create_fixed_dataloader} accuracy_score_metric_config = {"task": "multiclass", "num_classes": 10} sub_types = [ { @@ -193,8 +180,9 @@ def get_accuracy_metric( name="accuracy", type=MetricType.ACCURACY, sub_types=sub_types, - user_config=user_config or accuracy_metric_config, + user_config=user_config, backend=backend, + data_config=_get_dummy_data_config("accuracy_metric_data_config", [[1, 1]]), ) @@ -233,24 +221,24 @@ def get_custom_metric_no_eval(): def get_latency_metric(*lat_subtype, user_config=None): - latency_metric_config = {"dataloader_func": create_dataloader} sub_types = [{"name": sub} for sub in lat_subtype] return Metric( name="latency", type=MetricType.LATENCY, sub_types=sub_types, - user_config=user_config or latency_metric_config, + user_config=user_config, + data_config=_get_dummy_data_config("latency_metric_data_config", [[1, 1]]), ) def get_throughput_metric(*lat_subtype, user_config=None): - metric_config = {"dataloader_func": create_dataloader} sub_types = [{"name": sub} for sub in lat_subtype] return Metric( name="throughput", type=MetricType.THROUGHPUT, sub_types=sub_types, - user_config=user_config or metric_config, + user_config=user_config, + data_config=_get_dummy_data_config("throughput_metric_data_config", [[1, 1]]), ) diff --git a/test/unit_test/workflows/test_workfow_run.py b/test/unit_test/workflows/test_workflow_run.py similarity index 91% rename from test/unit_test/workflows/test_workfow_run.py rename to test/unit_test/workflows/test_workflow_run.py index 6eee7bced..62549022b 100644 --- a/test/unit_test/workflows/test_workfow_run.py +++ b/test/unit_test/workflows/test_workflow_run.py @@ -1,4 +1,9 @@ -from test.unit_test.utils import create_dataloader, get_pytorch_model, get_pytorch_model_config, pytorch_model_loader +from test.unit_test.utils import ( + create_dummy_dataloader, + get_pytorch_model, + get_pytorch_model_config, + pytorch_model_loader, +) from unittest.mock import patch import pytest @@ -24,7 +29,7 @@ "name": "avg", }, ], - "user_config": {"dataloader_func": create_dataloader}, + "user_config": 
{"dataloader_func": create_dummy_dataloader}, } ] }