Remove datasets as required dependency (#2087)
* remove datasets required dependency

* install datasets when needed

* add datasets installed when needed

* style

* add require dataset

* divide datasets tests

* import datasets only when needed
echarlaix authored Nov 21, 2024
1 parent a7a807c commit d2a5a6a
Showing 17 changed files with 123 additions and 35 deletions.
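
The change follows one pattern throughout: `datasets` is no longer imported at module level; an `is_datasets_available()` check (added in `optimum/utils/import_utils.py`) guards the import, and code paths that genuinely need the library raise an explicit `ImportError` built from `DATASETS_IMPORT_ERROR`. A minimal sketch of the pattern, with an illustrative helper name that is not part of the commit:

from optimum.utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available

if is_datasets_available():
    # Only imported when the package is present, so importing this module
    # never requires datasets to be installed.
    from datasets import load_dataset


def get_calibration_samples(name: str, split: str = "train"):
    # Hypothetical helper: fail with an actionable message only when a
    # datasets-dependent code path is actually executed.
    if not is_datasets_available():
        raise ImportError(DATASETS_IMPORT_ERROR.format("get_calibration_samples"))
    return load_dataset(name, split=split)
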
2 changes: 1 addition & 1 deletion .github/workflows/dev_test_benckmark.yml
@@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
pip install wheel
pip install .[tests,onnxruntime,benchmark]
pip install .[tests,onnxruntime,benchmark] datasets
pip install -U git+https://github.com/huggingface/evaluate
pip install -U git+https://github.com/huggingface/diffusers
pip install -U git+https://github.com/huggingface/transformers
2 changes: 1 addition & 1 deletion .github/workflows/test_benckmark.yml
@@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
pip install wheel
pip install .[tests,onnxruntime,benchmark]
pip install .[tests,onnxruntime,benchmark] datasets
- name: Test with unittest
run: |
python -m unittest discover --start-directory tests/benchmark --pattern 'test_*.py'
11 changes: 10 additions & 1 deletion .github/workflows/test_utils.yml
@@ -37,4 +37,13 @@ jobs:
- name: Test with pytest
working-directory: tests
run: |
python -m pytest -s -vvvv utils
pytest utils -s -n auto -m "not datasets_test" --durations=0
- name: Install datasets
run: |
pip install datasets
- name: Tests needing datasets
working-directory: tests
run: |
pytest utils -s -n auto -m "datasets_test" --durations=0
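
The utils workflow now runs the suite twice: first without `datasets` installed, deselecting marked tests with `-m "not datasets_test"`, then again with only the marked tests after `pip install datasets`. For the selection to work, tests that need the library must carry that pytest marker; a sketch of how such a test could be tagged (the marker name comes from the workflow above, the decorator placement is an assumption about the test suite):

import pytest

@pytest.mark.datasets_test  # selected by: pytest -m "datasets_test"
def test_helper_that_needs_datasets():
    # Skips rather than fails if the optional dependency is missing.
    datasets = pytest.importorskip("datasets")
    assert hasattr(datasets, "load_dataset")
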
16 changes: 15 additions & 1 deletion optimum/gptq/data.py
@@ -18,7 +18,12 @@

import numpy as np
import torch
from datasets import load_dataset

from optimum.utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available


if is_datasets_available():
from datasets import load_dataset


"""
@@ -113,6 +118,9 @@ def pad_block(block, pads):


def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if not is_datasets_available():
raise ImportError(DATASETS_IMPORT_ERROR.format("get_wikitext2"))

if split == "train":
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
elif split == "validation":
@@ -132,6 +140,9 @@ def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "trai


def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if not is_datasets_available():
raise ImportError(DATASETS_IMPORT_ERROR.format("get_c4"))

if split == "train":
data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
elif split == "validation":
@@ -157,6 +168,9 @@ def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):


def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if not is_datasets_available():
raise ImportError(DATASETS_IMPORT_ERROR.format("get_c4_new"))

if split == "train":
data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
elif split == "validation":
2 changes: 1 addition & 1 deletion optimum/gptq/quantizer.py
@@ -88,7 +88,7 @@ def __init__(
dataset (`Union[List[str], str, Any]`, defaults to `None`):
The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data
(e.g. [{ "input_ids": [ 1, 100, 15, ... ],"attention_mask": [ 1, 1, 1, ... ]},...])
or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new'].
group_size (int, defaults to 128):
The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
damp_percent (`float`, defaults to `0.1`):
15 changes: 10 additions & 5 deletions optimum/onnxruntime/configuration.py
@@ -18,9 +18,8 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from packaging.version import Version, parse

from onnxruntime import __version__ as ort_version
@@ -33,6 +32,10 @@
from ..utils import logging


if TYPE_CHECKING:
from datasets import Dataset


logger = logging.get_logger(__name__)

# This value is used to indicate ORT which axis it should use to quantize an operator "per-channel"
@@ -117,7 +120,9 @@ def create_calibrator(

class AutoCalibrationConfig:
@staticmethod
def minmax(dataset: Dataset, moving_average: bool = False, averaging_constant: float = 0.01) -> CalibrationConfig:
def minmax(
dataset: "Dataset", moving_average: bool = False, averaging_constant: float = 0.01
) -> CalibrationConfig:
"""
Args:
dataset (`Dataset`):
@@ -151,7 +156,7 @@ def minmax(dataset: Dataset, moving_average: bool = False, averaging_constant: f

@staticmethod
def entropy(
dataset: Dataset,
dataset: "Dataset",
num_bins: int = 128,
num_quantized_bins: int = 128,
) -> CalibrationConfig:
@@ -188,7 +193,7 @@ def entropy(
)

@staticmethod
def percentiles(dataset: Dataset, num_bins: int = 2048, percentile: float = 99.999) -> CalibrationConfig:
def percentiles(dataset: "Dataset", num_bins: int = 2048, percentile: float = 99.999) -> CalibrationConfig:
"""
Args:
dataset (`Dataset`):
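
In `configuration.py` (and similarly in `model.py`, `calibrator.py`, and `runs_base.py` below) the `Dataset` import moves under `typing.TYPE_CHECKING` and the annotations become strings, so type checkers still resolve them while the module imports cleanly without `datasets`. A condensed sketch of the idiom (function name illustrative):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never at runtime.
    from datasets import Dataset


def summarize(dataset: "Dataset") -> int:
    # The quoted annotation is not evaluated at import time, so no NameError
    # occurs when datasets is absent.
    return len(dataset)
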
9 changes: 6 additions & 3 deletions optimum/onnxruntime/model.py
@@ -14,17 +14,20 @@

import logging
import os
from typing import Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

import numpy as np
from datasets import Dataset
from transformers import EvalPrediction
from transformers.trainer_pt_utils import nested_concat
from transformers.trainer_utils import EvalLoopOutput

from onnxruntime import InferenceSession


if TYPE_CHECKING:
from datasets import Dataset


logger = logging.getLogger(__name__)


@@ -59,7 +62,7 @@ def __init__(
self.session = InferenceSession(str(model_path), providers=[execution_provider])
self.onnx_input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}

def evaluation_loop(self, dataset: Dataset):
def evaluation_loop(self, dataset: "Dataset"):
"""
Run evaluation and returns metrics and predictions.
17 changes: 11 additions & 6 deletions optimum/onnxruntime/quantization.py
@@ -21,14 +21,14 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import onnx
from datasets import Dataset, load_dataset
from packaging.version import Version, parse
from transformers import AutoConfig

from onnxruntime import __version__ as ort_version
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.qdq_quantizer import QDQQuantizer
from optimum.utils.import_utils import requires_backends

from ..quantization_base import OptimumQuantizer
from ..utils.save_utils import maybe_save_preprocessors
@@ -40,6 +40,7 @@


if TYPE_CHECKING:
from datasets import Dataset
from transformers import PretrainedConfig

LOGGER = logging.getLogger(__name__)
@@ -48,7 +49,7 @@
class ORTCalibrationDataReader(CalibrationDataReader):
__slots__ = ["batch_size", "dataset", "_dataset_iter"]

def __init__(self, dataset: Dataset, batch_size: int = 1):
def __init__(self, dataset: "Dataset", batch_size: int = 1):
if dataset is None:
raise ValueError("Provided dataset is None.")

@@ -158,7 +159,7 @@ def from_pretrained(

def fit(
self,
dataset: Dataset,
dataset: "Dataset",
calibration_config: CalibrationConfig,
onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
operators_to_quantize: Optional[List[str]] = None,
@@ -212,7 +213,7 @@ def fit(

def partial_fit(
self,
dataset: Dataset,
dataset: "Dataset",
calibration_config: CalibrationConfig,
onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
operators_to_quantize: Optional[List[str]] = None,
@@ -428,7 +429,7 @@ def get_calibration_dataset(
seed: int = 2016,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
) -> Dataset:
) -> "Dataset":
"""
Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
@@ -474,6 +475,10 @@ def get_calibration_dataset(
"provided."
)

requires_backends(self, ["datasets"])

from datasets import load_dataset

calib_dataset = load_dataset(
dataset_name,
name=dataset_config_name,
@@ -492,7 +497,7 @@ def get_calibration_dataset(

return self.clean_calibration_dataset(processed_calib_dataset)

def clean_calibration_dataset(self, dataset: Dataset) -> Dataset:
def clean_calibration_dataset(self, dataset: "Dataset") -> "Dataset":
model = onnx.load(self.onnx_model_path)
model_inputs = {input.name for input in model.graph.input}
ignored_columns = list(set(dataset.column_names) - model_inputs)
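
`get_calibration_dataset` now calls `requires_backends(self, ["datasets"])` before doing any work and imports `load_dataset` inside the method, so the missing dependency is reported only when a calibration dataset is actually requested. Roughly, from the caller's side (the model path and dataset arguments are illustrative; only `dataset_name` and `dataset_config_name` are taken from the diff):

from optimum.onnxruntime import ORTQuantizer

quantizer = ORTQuantizer.from_pretrained("path/to/exported-onnx-model")
# Raises ImportError with the DATASETS_IMPORT_ERROR message if datasets is not
# installed; otherwise load_dataset is imported here, at call time.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue", dataset_config_name="sst2"
)
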
10 changes: 6 additions & 4 deletions optimum/onnxruntime/runs/calibrator.py
@@ -1,6 +1,4 @@
from typing import Dict, List

from datasets import Dataset
from typing import TYPE_CHECKING, Dict, List

from ...runs_base import Calibrator
from .. import ORTQuantizer
@@ -9,10 +7,14 @@
from ..preprocessors.passes import ExcludeGeLUNodes, ExcludeLayerNormNodes, ExcludeNodeAfter, ExcludeNodeFollowedBy


if TYPE_CHECKING:
from datasets import Dataset


class OnnxRuntimeCalibrator(Calibrator):
def __init__(
self,
calibration_dataset: Dataset,
calibration_dataset: "Dataset",
quantizer: ORTQuantizer,
model_path: str,
qconfig: QuantizationConfig,
8 changes: 5 additions & 3 deletions optimum/runs_base.py
@@ -2,13 +2,12 @@
import subprocess
from contextlib import contextmanager
from time import perf_counter_ns
from typing import Set
from typing import TYPE_CHECKING, Set

import numpy as np
import optuna
import torch
import transformers
from datasets import Dataset
from tqdm import trange

from . import version as optimum_version
@@ -21,6 +20,9 @@
from .utils.runs import RunConfig, cpu_info_command


if TYPE_CHECKING:
from datasets import Dataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -34,7 +36,7 @@ def get_autoclass_name(task):

class Calibrator:
def __init__(
self, calibration_dataset: Dataset, quantizer, model_path, qconfig, calibration_params, node_exclusion
self, calibration_dataset: "Dataset", quantizer, model_path, qconfig, calibration_params, node_exclusion
):
self.calibration_dataset = calibration_dataset
self.quantizer = quantizer
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -35,6 +35,7 @@
check_if_transformers_greater,
is_accelerate_available,
is_auto_gptq_available,
is_datasets_available,
is_diffusers_available,
is_onnx_available,
is_onnxruntime_available,
12 changes: 12 additions & 0 deletions optimum/utils/import_utils.py
@@ -69,6 +69,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_auto_gptq_available = _is_package_available("auto_gptq")
_timm_available = _is_package_available("timm")
_sentence_transformers_available = _is_package_available("sentence_transformers")
_datasets_available = _is_package_available("datasets")

torch_version = None
if is_torch_available():
@@ -131,6 +132,10 @@ def is_sentence_transformers_available():
return _sentence_transformers_available


def is_datasets_available():
return _datasets_available


def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
@@ -230,6 +235,12 @@ def require_numpy_strictly_lower(package_version: str, message: str):
-U transformers`. Please note that you may need to restart your runtime after installation.
"""

DATASETS_IMPORT_ERROR = """
{0} requires the datasets library but it was not found in your environment. You can install it with pip:
`pip install datasets`. Please note that you may need to restart your runtime after installation.
"""


BACKENDS_MAPPING = OrderedDict(
[
("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
@@ -245,6 +256,7 @@ def require_numpy_strictly_lower(package_version: str, message: str):
"transformers_434",
(lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")),
),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
]
)

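
`requires_backends`, used in `quantization.py` above, ties these pieces together: each entry in `BACKENDS_MAPPING` pairs an availability check with an error template, and the helper raises the formatted message when the check fails. A simplified sketch of that mechanism, not the exact optimum implementation:

def requires_backends(obj, backends):
    # Simplified stand-in for optimum.utils.import_utils.requires_backends.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    failed = [
        template.format(name)
        for is_available, template in (BACKENDS_MAPPING[backend] for backend in backends)
        if not is_available()
    ]
    if failed:
        raise ImportError("".join(failed))
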