Skip to content

Commit

Permalink
Introducing customizable input names for loss in generate_artifacts (microsoft#19705)

Browse files Browse the repository at this point in the history

# Loss function extra inputs
Currently, the loss functions in onnxblock expect exactly two inputs in
their build method.
Occasionally, models may pass additional inputs, causing the build
function to fail.
To solve this issue, we let users pass a list of loss input names to
be used by the loss function.
  • Loading branch information
AdamLouly authored and Zhenze Wang committed Mar 7, 2024
1 parent d0ab588 commit 3bc4579
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def generate_artifacts(
custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
additional_output_names: Optional[List[str]] = None,
nominal_checkpoint: bool = False,
loss_input_names: Optional[List[str]] = None,
) -> None:
"""Generates artifacts required for training with ORT training api.
Expand Down Expand Up @@ -77,7 +78,9 @@ def generate_artifacts(
Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
parameters. It can be used on the device to reduce overhead while constructing the training model
as well as to reduce the size of the checkpoint packaged with the on-device application.
loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided,
only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to
the loss function.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
RuntimeError: If the optimizer provided is not one of the supported optimizers.
Expand Down Expand Up @@ -111,11 +114,16 @@ def generate_artifacts(
logging.info("Custom loss block provided: %s", loss.__class__.__name__)

class _TrainingBlock(onnxblock.TrainingBlock):
def __init__(self, _loss):
def __init__(self, _loss, _loss_input_names=None):
super().__init__()
self._loss = _loss
self._loss_input_names = _loss_input_names

def build(self, *inputs_to_loss):
# If loss_input_names is passed, only pass the specified input names to the loss function.
if self._loss_input_names:
inputs_to_loss = self._loss_input_names

if additional_output_names:
# If additional output names is not a list, raise an error
if not isinstance(additional_output_names, list):
Expand All @@ -132,7 +140,7 @@ def build(self, *inputs_to_loss):

return self._loss(*inputs_to_loss)

training_block = _TrainingBlock(loss_block)
training_block = _TrainingBlock(loss_block, loss_input_names)

if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)):
raise RuntimeError(
Expand All @@ -157,9 +165,11 @@ def build(self, *inputs_to_loss):
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

with onnxblock.base(model), onnxblock.custom_op_library(
custom_op_library_path
) if custom_op_library is not None else contextlib.nullcontext():
with onnxblock.base(model), (
onnxblock.custom_op_library(custom_op_library_path)
if custom_op_library is not None
else contextlib.nullcontext()
):
_ = training_block(*[output.name for output in model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()
Expand Down

0 comments on commit 3bc4579

Please sign in to comment.