[ORT Training] Some important updates of ONNX Runtime training APIs (#1335)

* update trainer

* update args

* update to main

* update to 4.33

* fix style

* make style

* fix when testing

* Update optimum/onnxruntime/trainer_seq2seq.py

Co-authored-by: fxmarty <[email protected]>

* deprecate ort inf

* deprecate ort inf for seq2seq

* update trainer and its args to main

* try CI permission

* update tests

* update examples

* withdraw CI change

---------

Co-authored-by: JingyaHuang <[email protected]>
Co-authored-by: fxmarty <[email protected]>
3 people authored Oct 18, 2023
1 parent 2f7d396 commit 85e6fff
Showing 20 changed files with 1,340 additions and 2,548 deletions.
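In short, the example scripts now drive `ORTTrainer`/`ORTTrainingArguments` the same way as the upstream `Trainer`: the `feature` constructor argument is gone and `evaluate()`/`predict()` no longer take an `inference_with_ort` flag. Before the per-file diffs below, here is a minimal sketch of the updated flow; the model, datasets, collator and metric function are placeholders standing in for the objects built by the example scripts.

from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments


def finetune(model, train_dataset, eval_dataset, data_collator, compute_metrics):
    # Placeholder arguments: in the real scripts these come from HfArgumentParser and the datasets library.
    args = ORTTrainingArguments(output_dir="ort_output", do_train=True, do_eval=True)
    trainer = ORTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        # The `feature=...` argument has been removed from the trainer constructor.
    )
    trainer.train()
    return trainer.evaluate()  # `inference_with_ort=...` is no longer accepted here.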
@@ -22,6 +22,7 @@ CMD nvidia-smi
ENV DEBIAN_FRONTEND noninteractive

# Versions
+# available options 3.8, 3.9, 3.10, 3.11
ARG PYTHON_VERSION=3.9
ARG TORCH_CUDA_VERSION=cu118
ARG TORCH_VERSION=2.0.0
@@ -34,7 +35,7 @@ SHELL ["/bin/bash", "-c"]
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
-bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
+bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y

@@ -33,7 +33,7 @@ ARG TORCHVISION_VERSION=0.14.1
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
-bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
+bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y

@@ -34,7 +34,7 @@ ARG TORCHVISION_VERSION=0.15.1
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
-bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
+bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y

@@ -34,7 +34,7 @@ SHELL ["/bin/bash", "-c"]
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
-bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
+bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y

@@ -16,6 +16,7 @@
import logging
import os
import sys
+import warnings
from dataclasses import dataclass, field
from typing import Optional

@@ -54,7 +55,7 @@
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.26.0")
check_min_version("4.34.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -141,12 +142,28 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
+token: str = field(
+default=None,
+metadata={
+"help": (
+"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
+)
+},
+)
use_auth_token: bool = field(
+default=None,
+metadata={
+"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
+},
+)
+trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
@@ -162,32 +179,24 @@ def collate_fn(examples):
return {"pixel_values": pixel_values, "labels": labels}


-@dataclass
-class InferenceArguments:
-"""
-Arguments for inference(evaluate, predict).
-"""
-
-inference_with_ort: bool = field(
-default=False,
-metadata={"help": "Whether use ONNX Runtime as backend for inference. Default set to false."},
-)


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

-parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments, InferenceArguments))
+parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
-model_args, data_args, training_args, inference_args = parser.parse_json_file(
-json_file=os.path.abspath(sys.argv[1])
-)
+model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
-model_args, data_args, training_args, inference_args = parser.parse_args_into_dataclasses()
+model_args, data_args, training_args = parser.parse_args_into_dataclasses()

+if model_args.use_auth_token is not None:
+warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
+if model_args.token is not None:
+raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -200,6 +209,10 @@ def main():
handlers=[logging.StreamHandler(sys.stdout)],
)

+if training_args.should_log:
+# The default of training_args.log_level is passive, so we set log level at info here to have that default.
+transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
@@ -209,7 +222,7 @@ def main():
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

@@ -238,7 +251,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
-use_auth_token=True if model_args.use_auth_token else None,
+token=model_args.token,
)
else:
data_files = {}
@@ -285,22 +298,25 @@ def compute_metrics(p):
finetuning_task="image-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
-use_auth_token=True if model_args.use_auth_token else None,
+token=model_args.token,
+trust_remote_code=model_args.trust_remote_code,
)
model = AutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
-use_auth_token=True if model_args.use_auth_token else None,
+token=model_args.token,
+trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
-use_auth_token=True if model_args.use_auth_token else None,
+token=model_args.token,
+trust_remote_code=model_args.trust_remote_code,
)

# Define torchvision transforms to be applied to each image.
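As the hunk above shows, the authentication arguments follow the same migration as upstream transformers: the deprecated `use_auth_token` flag is mapped onto `token`, and every `from_pretrained` call forwards `token` (and, where relevant, `trust_remote_code`). As a rough illustration with a placeholder checkpoint name, downstream code would now load a model like this rather than passing `use_auth_token=True`:

from transformers import AutoModelForImageClassification

# "some-org/some-model" is a placeholder model id; pass a string token for private
# repositories, or None to fall back to the token cached by `huggingface-cli login`.
model = AutoModelForImageClassification.from_pretrained(
    "some-org/some-model",
    token=None,
    trust_remote_code=False,
)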
@@ -367,7 +383,6 @@ def val_transforms(example_batch):
compute_metrics=compute_metrics,
tokenizer=image_processor,
data_collator=collate_fn,
feature="image-classification",
)

# Training
@@ -385,7 +400,7 @@ def val_transforms(example_batch):

# Evaluation
if training_args.do_eval:
-metrics = trainer.evaluate(inference_with_ort=inference_args.inference_with_ort)
+metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
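With `inference_with_ort` gone from `evaluate()` and `predict()`, running inference on ONNX Runtime now happens outside the trainer. One way to do this (a suggestion, not something this commit prescribes) is to export the fine-tuned checkpoint with the `ORTModel` classes from `optimum.onnxruntime`; the paths below are placeholders for the trainer's output directory.

from optimum.onnxruntime import ORTModelForImageClassification
from transformers import AutoImageProcessor, pipeline

# Export the saved PyTorch checkpoint to ONNX and load it with ONNX Runtime.
ort_model = ORTModelForImageClassification.from_pretrained("path/to/output_dir", export=True)
image_processor = AutoImageProcessor.from_pretrained("path/to/output_dir")
classifier = pipeline("image-classification", model=ort_model, image_processor=image_processor)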
