From da7be5c26d57697e0506b3262a128526f54caa91 Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Tue, 27 Feb 2024 21:42:13 +0000 Subject: [PATCH] add loss input names option --- .../orttraining/python/training/artifacts.py | 21 ++++++++++++++----- .../python/training/onnxblock/blocks.py | 1 - .../python/training/onnxblock/loss/loss.py | 8 +++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 7a4eb251bc5bc..d5e0e6b78364c 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -48,6 +48,7 @@ def generate_artifacts( custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, additional_output_names: Optional[List[str]] = None, nominal_checkpoint: bool = False, + loss_input_names: Optional[List[str]] = None, ) -> None: """Generates artifacts required for training with ORT training api. @@ -77,7 +78,9 @@ def generate_artifacts( Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model parameters. It can be used on the device to reduce overhead while constructing the training model as well as to reduce the size of the checkpoint packaged with the on-device application. - + loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided, + only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to + the loss function. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` RuntimeError: If the optimizer provided is not one of the supported optimizers. @@ -111,11 +114,17 @@ def generate_artifacts( logging.info("Custom loss block provided: %s", loss.__class__.__name__) class _TrainingBlock(onnxblock.TrainingBlock): - def __init__(self, _loss): + def __init__(self, _loss, _loss_input_names=None): super().__init__() self._loss = _loss + self._loss_input_names = _loss_input_names def build(self, *inputs_to_loss): + + # If loss_input_names is passed, only pass the specified input names to the loss function. + if self._loss_input_names: + inputs_to_loss = self._loss_input_names + if additional_output_names: # If additional output names is not a list, raise an error if not isinstance(additional_output_names, list): @@ -157,9 +166,11 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library_path - ) if custom_op_library is not None else contextlib.nullcontext(): + with onnxblock.base(model), ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index 43b6cbdf3b877..ec54c0916d572 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -53,7 +53,6 @@ def __call__(self, *args, **kwargs): # Check if the error is specifically due to exceeding the maximum protobuf size. if "exceeds maximum protobuf size of 2GB" in str(e): logging.info("Handling large model that exceeds the maximum protobuf size of 2GB.") - pass else: # If the error is for any other reason, re-raise it to not silently ignore important issues. raise diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index 54f40c47ce4cc..e719301e13f48 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -33,7 +33,7 @@ def __init__(self, reduction: str = "mean"): self._sub = blocks.Sub() self._square = blocks.Pow(2.0) - def build(self, loss_input_name: str, target_name: str = "target", *args): + def build(self, loss_input_name: str, target_name: str = "target"): """Adds an MSELoss subgraph on top of the base_model. Args: @@ -72,7 +72,7 @@ def __init__(self, weight=None, reduction: str = "mean", ignore_index: Optional[ self._reduction = reduction self._ignore_index = ignore_index - def build(self, scores_input_name: str, labels_name: str = "labels", *args): + def build(self, scores_input_name: str, labels_name: str = "labels"): """Adds a CrossEntropyLoss subgraph on top of an onnx model. Args: @@ -149,7 +149,7 @@ def __init__(self, weight=None, reduction: str = "mean", pos_weight=None): self._mul = blocks.Mul() self._neg = blocks.Neg() - def build(self, loss_input_name: str, target_name: str = "target", *args): + def build(self, loss_input_name: str, target_name: str = "target"): """Adds a BCEWithLogitsLoss subgraph on top of an onnx model. Creates a block that measures the binary cross entropy with logits between @@ -229,7 +229,7 @@ def __init__(self, reduction: str = "mean"): self._abs = blocks.Abs() self._sub = blocks.Sub() - def build(self, loss_input_name: str, target_name: Optional[str] = "target", *args): + def build(self, loss_input_name: str, target_name: Optional[str] = "target"): """Adds an L1 loss subgraph on top of the base_model. Args: