Skip to content

Commit

Permalink
Use DummyDataLoader instead of declaring new types for tests (#1209)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaahji authored Jul 8, 2024
1 parent 59ee35b commit 7e6f512
Show file tree
Hide file tree
Showing 17 changed files with 150 additions and 150 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/configure_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ If no data config template can meet the requirement, we can also define the `dat
* - `post_process_data <https://github.com/microsoft/Olive/blob/main/olive/data/component/post_process_data.py>`_
- post_process(default), text_classification_post_process, ner_post_process, text_generation_post_process
* - `dataloader <https://github.com/microsoft/Olive/blob/main/olive/data/component/dataloader.py>`_
- default_dataloader(default), skip_dataloader, no_auto_batch_dataloader
- default_dataloader(default), no_auto_batch_dataloader

each component can be customized by the following fields:
- ``name``: the name of the component.
Expand Down
5 changes: 0 additions & 5 deletions olive/data/component/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
from olive.data.registry import Registry


@Registry.register_dataloader()
def skip_dataloader(dataset):
return dataset


@Registry.register_default_dataloader()
def default_dataloader(dataset, batch_size=1, **kwargs):
return DataLoader(dataset, batch_size=batch_size, **kwargs)
Expand Down
53 changes: 30 additions & 23 deletions olive/data/component/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,41 +94,48 @@ def to_hf_dataset(self, label_name="label"):


class DummyDataset(BaseDataset):
def __init__(self, input_shapes, input_names: Optional[List] = None, input_types: Optional[List] = None):
"""Initialize the dummy dataset.
def __init__(
self,
input_shapes,
input_names: Optional[List] = None,
input_types: Optional[List] = None,
max_samples: Optional[int] = 32,
):
"""Initialize the dataset with dummy data.
if input_names is None, the dummy dataset will return a tuple of tensors
else the dummy dataset will return a dict of tensors
if input_names is not provided, the dataset will return a tuple of tensors
else the dataset will return a dict of tensors
"""
# pylint: disable=super-init-not-called
self.input_shapes = input_shapes
self.input_names = input_names
self.input_types = input_types or ["float32"] * len(input_shapes)
if not input_types:
input_types = ["float32"] * len(input_shapes)
input_types = [resolve_torch_dtype(dtype_str) for dtype_str in input_types]

if input_names:
dummy_data = {}
for input_name, input_shape, input_type in zip(input_names, input_shapes, input_types):
dummy_data.update({input_name: torch.ones(input_shape, dtype=input_type)})
dummy_data = dummy_data if len(dummy_data) > 1 else dummy_data[input_names[0]]
else:
dummy_data = []
for shape, dtype in zip(input_shapes, input_types):
dummy_data.append(torch.ones(shape, dtype=dtype))
dummy_data = tuple(dummy_data) if len(dummy_data) > 1 else dummy_data[0]

self.max_samples = max_samples
self.dummy_data = dummy_data, torch.tensor([0])

def __len__(self):
return 256
return self.max_samples

def __getitem__(self, index):
# From https://docs.python.org/3/reference/datamodel.html#object.__getitem__,
# __getitem__ should raise IndexError when index is out of range
# Otherwise, the enumerate function will enter infinite loop
if index < 0 or index >= len(self):
if index < 0 or index >= self.max_samples:
raise IndexError("Index out of range")

input_types = [resolve_torch_dtype(dtype_str) for dtype_str in self.input_types]

if not self.input_names:
dummy_inputs = []
for shape, dtype in zip(self.input_shapes, input_types):
dummy_inputs.append(torch.ones(shape, dtype=dtype))
dummy_inputs = tuple(dummy_inputs) if len(dummy_inputs) > 1 else dummy_inputs[0]
else:
dummy_inputs = {}
for input_name, input_shape, input_type in zip(self.input_names, self.input_shapes, input_types):
dummy_inputs.update({input_name: torch.ones(input_shape, dtype=input_type)})
dummy_inputs = dummy_inputs if len(dummy_inputs) > 1 else dummy_inputs[self.input_names[0]]
label = 0
return dummy_inputs, label
return self.dummy_data


class RawDataset(BaseDataset):
Expand Down
4 changes: 2 additions & 2 deletions olive/data/component/load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def huggingface_dataset(data_dir, data_name=None, subset=None, split="validation


@Registry.register_dataset()
def dummy_dataset(data_dir, input_shapes, input_names=None, input_types=None):
return DummyDataset(input_shapes, input_names, input_types)
def dummy_dataset(data_dir, input_shapes, input_names=None, input_types=None, max_samples=32):
return DummyDataset(input_shapes, input_names, input_types, max_samples)


@Registry.register_dataset()
Expand Down
8 changes: 4 additions & 4 deletions olive/data/container/dummy_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ class DummyDataContainer(DataContainer):
dummy_data_config = DataConfig(
name="dummy",
type="DummyDataContainer",
"load_dataset_config"={
"params": {
load_dataset_config=DataComponentConfig(
params={
"input_names": metric.user_config.input_names,
"input_shapes": metric.user_config.input_shapes,
"input_types": metric.user_config.input_types,
}
}
),
)
"""

default_components_type: ClassVar[dict] = {
DataComponentType.LOAD_DATASET.value: "dummy_dataset",
DataComponentType.DATALOADER.value: "skip_dataloader",
DataComponentType.DATALOADER.value: "no_auto_batch_dataloader",
}
2 changes: 1 addition & 1 deletion olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def get_user_config(cls, framework: Framework, data_root: str, metric: Metric):
dataloader = user_module.call_object(
dataloader_func,
data_dir,
metric.user_config.batch_size,
batch_size=metric.user_config.batch_size,
model_framework=framework,
**cls._get_func_kwargs(metric, "dataloader_func"),
)
Expand Down
2 changes: 1 addition & 1 deletion olive/passes/pytorch/qat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def execute_local(self) -> PyTorchModelHandler:
ptl_data_module = self.config.ptl_data_module()
else:
train_dataloader_func = self.config.train_dataloader_func(
self.config.train_data_dir, self.config.train_batch_size
self.config.train_data_dir, batch_size=self.config.train_batch_size
)
ptl_module = DefaultPTLModule(model=quan_model, training_dataloader=train_dataloader_func)

Expand Down
6 changes: 5 additions & 1 deletion test/integ_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ def get_olive_workspace_config():
if client_id is None:
raise Exception("Please set the environment variable MANAGED_IDENTITY_CLIENT_ID")

exclude_managed_identity_credential = (
{"exclude_managed_identity_credential": True} if "EXCLUDE_MANAGED_IDENTITY_CREDENTIAL" in os.environ else {}
)

return {
"subscription_id": subscription_id,
"resource_group": resource_group,
"workspace_name": workspace_name,
# pipeline agents have multiple managed identities, so we need to specify the client_id
"default_auth_params": {"managed_identity_client_id": client_id},
"default_auth_params": {"managed_identity_client_id": client_id, **exclude_managed_identity_credential},
}


Expand Down
6 changes: 3 additions & 3 deletions test/unit_test/data_container/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def get_dict_dataset(length=256, max_samples=None):
return BaseDataset(data, ["original_label"], max_samples)


def get_dummy_dataset():
def get_dummy_dataset(length=256):
input_shapes = [[2], [3]]
input_names = ["input_1", "input_2"]
return DummyDataset(input_shapes, input_names)
return DummyDataset(input_shapes, input_names, max_samples=length)


def get_hf_dataset():
Expand Down Expand Up @@ -70,4 +70,4 @@ def test_dataset_to_hf_dataset(self, dataset_func, label_name):
# assert shape of the first sample
assert hf_dataset["input_1"][0].shape == (2,)
assert hf_dataset["input_2"][0].shape == (3,)
assert hf_dataset[label_name][0].shape == ()
assert hf_dataset[label_name][0].shape == ((1,) if isinstance(dataset, DummyDataset) else ())
6 changes: 3 additions & 3 deletions test/unit_test/data_container/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def test_dummy_template(self, input_names):
input_types=["int64", "int64", "int64"],
)
dummy_inputs, _ = dataloader.to_data_container().get_first_batch()
if not input_names:
assert isinstance(dummy_inputs, tuple), "Failed to create dummy tuple input from dummy template."
else:
if input_names:
assert isinstance(dummy_inputs, dict), "Failed to create dummy dict dataset from dummy template."
else:
assert isinstance(dummy_inputs, list), "Failed to create dummy list input from dummy template."

def test_raw_data_template(self, tmpdir):
input_names = ["float_input", "int_input"]
Expand Down
5 changes: 2 additions & 3 deletions test/unit_test/evaluator/test_olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,8 @@ def test_dataloader_func_kwargs(self, dataloader_func_kwargs):
# setup
dataloader_func = MagicMock(spec=FunctionType)
data_dir = None
batch_size = 1
model_framework = "PyTorch"
user_config = {"dataloader_func": dataloader_func, "batch_size": batch_size, "data_dir": data_dir}
user_config = {"dataloader_func": dataloader_func, "data_dir": data_dir}
if dataloader_func_kwargs:
user_config["func_kwargs"] = {"dataloader_func": dataloader_func_kwargs}
metric = get_latency_metric(LatencySubType.AVG, user_config=user_config)
Expand All @@ -301,7 +300,7 @@ def test_dataloader_func_kwargs(self, dataloader_func_kwargs):

# assert
dataloader_func.assert_called_once_with(
data_dir, batch_size, model_framework=model_framework, **(dataloader_func_kwargs or {})
data_dir, batch_size=1, model_framework=model_framework, **(dataloader_func_kwargs or {})
)

# this is enough to test the kwargs for `evaluate_func`, `metric_func` and `post_process_func`
Expand Down
10 changes: 5 additions & 5 deletions test/unit_test/passes/onnx/test_perf_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import math
import re
from test.unit_test.utils import create_dataloader, get_onnx_model
from test.unit_test.utils import create_dummy_dataloader, get_onnx_model
from unittest.mock import MagicMock, PropertyMock, patch

import psutil
Expand All @@ -21,10 +21,10 @@
@pytest.mark.parametrize(
"config",
[
{"input_names": ["input"], "input_shapes": [[1, 1]]},
{"input_names": ["input"], "input_shapes": [(1, 1)]},
{},
{"dataloader_func": create_dataloader},
{"dataloader_func": create_dataloader, "dataloader_func_kwargs": {"dummy_kwarg": 1}},
{"dataloader_func": create_dummy_dataloader},
{"dataloader_func": create_dummy_dataloader, "dataloader_func_kwargs": {"dummy_kwarg": 1}},
],
)
def test_ort_perf_tuning_pass(config, tmp_path):
Expand All @@ -42,7 +42,7 @@ def test_ort_perf_tuning_pass(config, tmp_path):
"config",
[
{},
{"input_names": ["input"], "input_shapes": [[1, 1]]},
{"input_names": ["input"], "input_shapes": [(1, 1)]},
{"providers_list": ["CPUExecutionProvider", "CUDAExecutionProvider"], "device": "gpu"},
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from test.unit_test.utils import create_dataloader, get_pytorch_model
from test.unit_test.utils import create_dummy_dataloader, get_pytorch_model

from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import FullPassConfig, create_pass_from_dict
from olive.passes.pytorch.quantization_aware_training import QuantizationAwareTraining


# TODO(shaahji): Remove this once QuantizationAwareTraining pass supports DataConfig
def _create_dummy_dataloader(data_dir, **kwargs):
kwargs.pop("batch_size", None)
return create_dummy_dataloader(data_dir, max_samples=1, batch_size=1, **kwargs)


def test_quantization_aware_training_pass_default(tmp_path):
# setup
input_model = get_pytorch_model()
config = {
"train_dataloader_func": create_dataloader,
"train_dataloader_func": _create_dummy_dataloader,
"checkpoint_path": str(tmp_path / "checkpoint"),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def setup(self, tmp_path):

python_path = shutil.which("python", path=python_environment_path)
# install only onnxruntime
run_subprocess([python_path, "-m", "pip", "install", "onnxruntime"], env=self.system.environ)
run_subprocess([python_path, "-m", "pip", "install", "onnxruntime", "numpy<2"], env=self.system.environ)

self.evaluator = IsolatedORTEvaluator(self.system.environ)
self.onnx_evaluator = OnnxEvaluator()
Expand All @@ -125,7 +125,7 @@ def setup(self, tmp_path):

def test__inference(self):
model = get_onnx_model_config().create_model()
metric = get_accuracy_metric(AccuracySubType.ACCURACY_SCORE, random_dataloader=False)
metric = get_accuracy_metric(AccuracySubType.ACCURACY_SCORE)
metric = OliveEvaluator.generate_metric_user_config_with_model_io(metric, model)
dataloader, _, post_func = OliveEvaluator.get_user_config(model.framework, None, metric)

Expand Down
Loading

0 comments on commit 7e6f512

Please sign in to comment.