diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 7a4eb251bc5bc..4e76174d8255e 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -48,6 +48,7 @@ def generate_artifacts(
     custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
     additional_output_names: Optional[List[str]] = None,
     nominal_checkpoint: bool = False,
+    loss_input_names: Optional[List[str]] = None,
 ) -> None:
     """Generates artifacts required for training with ORT training api.
 
@@ -77,7 +78,9 @@ def generate_artifacts(
             Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
             parameters. It can be used on the device to reduce overhead while constructing the training model
             as well as to reduce the size of the checkpoint packaged with the on-device application.
-
+        loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided,
+            only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to
+            the loss function.
     Raises:
         RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
         RuntimeError: If the optimizer provided is not one of the supported optimizers.
@@ -111,11 +114,16 @@ def generate_artifacts(
         logging.info("Custom loss block provided: %s", loss.__class__.__name__)
 
         class _TrainingBlock(onnxblock.TrainingBlock):
-            def __init__(self, _loss):
+            def __init__(self, _loss, _loss_input_names=None):
                 super().__init__()
                 self._loss = _loss
+                self._loss_input_names = _loss_input_names
 
             def build(self, *inputs_to_loss):
+                # If loss_input_names is passed, only pass the specified input names to the loss function.
+                if self._loss_input_names:
+                    inputs_to_loss = self._loss_input_names
+
                 if additional_output_names:
                     # If additional output names is not a list, raise an error
                     if not isinstance(additional_output_names, list):
@@ -132,7 +140,7 @@ def build(self, *inputs_to_loss):
 
                 return self._loss(*inputs_to_loss)
 
-        training_block = _TrainingBlock(loss_block)
+        training_block = _TrainingBlock(loss_block, loss_input_names)
 
     if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)):
         raise RuntimeError(
@@ -157,9 +165,11 @@ def build(self, *inputs_to_loss):
         logging.info("Custom op library provided: %s", custom_op_library)
         custom_op_library_path = pathlib.Path(custom_op_library)
 
-    with onnxblock.base(model), onnxblock.custom_op_library(
-        custom_op_library_path
-    ) if custom_op_library is not None else contextlib.nullcontext():
+    with onnxblock.base(model), (
+        onnxblock.custom_op_library(custom_op_library_path)
+        if custom_op_library is not None
+        else contextlib.nullcontext()
+    ):
         _ = training_block(*[output.name for output in model.graph.output])
         training_model, eval_model = training_block.to_model_proto()
         model_params = training_block.parameters()
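
For reference, a minimal sketch of how the new `loss_input_names` parameter could be exercised with a custom loss block. The model file, graph output names (`logits`, `extra_metric`), and parameter names below are illustrative assumptions, not part of this change; only the `loss_input_names` keyword itself is introduced by this diff.

```python
import onnx
from onnxruntime.training import artifacts, onnxblock


# Illustrative custom loss block wrapping the built-in MSELoss.
class CustomMSELoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._mse = onnxblock.loss.MSELoss()

    def build(self, logits_name):
        # Receives only the input names forwarded by _TrainingBlock.build.
        return self._mse(logits_name)


# Assumed: a forward-only ONNX model with two graph outputs,
# "logits" and "extra_metric".
model = onnx.load("model.onnx")

# Without loss_input_names, both graph outputs would be forwarded to
# CustomMSELoss.build, which accepts a single name. Restricting the loss
# inputs to "logits" avoids that mismatch while "extra_metric" remains a
# plain graph output.
artifacts.generate_artifacts(
    model,
    loss=CustomMSELoss(),
    optimizer=artifacts.OptimType.AdamW,
    requires_grad=["fc.weight", "fc.bias"],  # illustrative parameter names
    artifact_directory="training_artifacts",
    loss_input_names=["logits"],
)
```

Note that when `loss_input_names` is provided, `_TrainingBlock.build` replaces the full list of graph output names with the given names before calling the loss, so the custom block's `build` signature only needs to match the names listed there.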