From d5606cd7ee394ba9444ef509021720ebe63c9856 Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Thu, 29 Feb 2024 13:40:56 -0800 Subject: [PATCH] Introducing customizable input names for loss in generate_artifacts. (#19705) # loss function extra inputs. Currently, the loss functions in onnxblock expect exactly two inputs in their build method. Occasionally, models may pass additional inputs, causing the build function to fail. To solve this issue, we can let users pass a list of loss input names to be used in the loss function. --- .../orttraining/python/training/artifacts.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 7a4eb251bc5bc..4e76174d8255e 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -48,6 +48,7 @@ def generate_artifacts( custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, additional_output_names: Optional[List[str]] = None, nominal_checkpoint: bool = False, + loss_input_names: Optional[List[str]] = None, ) -> None: """Generates artifacts required for training with ORT training api. @@ -77,7 +78,9 @@ def generate_artifacts( Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model parameters. It can be used on the device to reduce overhead while constructing the training model as well as to reduce the size of the checkpoint packaged with the on-device application. - + loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided, + only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to + the loss function. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` RuntimeError: If the optimizer provided is not one of the supported optimizers. @@ -111,11 +114,16 @@ def generate_artifacts( logging.info("Custom loss block provided: %s", loss.__class__.__name__) class _TrainingBlock(onnxblock.TrainingBlock): - def __init__(self, _loss): + def __init__(self, _loss, _loss_input_names=None): super().__init__() self._loss = _loss + self._loss_input_names = _loss_input_names def build(self, *inputs_to_loss): + # If loss_input_names is passed, only pass the specified input names to the loss function. + if self._loss_input_names: + inputs_to_loss = self._loss_input_names + if additional_output_names: # If additional output names is not a list, raise an error if not isinstance(additional_output_names, list): @@ -132,7 +140,7 @@ def build(self, *inputs_to_loss): return self._loss(*inputs_to_loss) - training_block = _TrainingBlock(loss_block) + training_block = _TrainingBlock(loss_block, loss_input_names) if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)): raise RuntimeError( @@ -157,9 +165,11 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library_path - ) if custom_op_library is not None else contextlib.nullcontext(): + with onnxblock.base(model), ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters()