add loss input names to generate artifact
AdamLouly committed Feb 28, 2024
1 parent a93c31e commit ad53900
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions orttraining/orttraining/python/training/artifacts.py
@@ -48,6 +48,7 @@ def generate_artifacts(
     custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
     additional_output_names: Optional[List[str]] = None,
     nominal_checkpoint: bool = False,
+    loss_input_names: Optional[List[str]] = None,
 ) -> None:
     """Generates artifacts required for training with ORT training api.
@@ -77,7 +78,9 @@
             Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
             parameters. It can be used on the device to reduce overhead while constructing the training model
             as well as to reduce the size of the checkpoint packaged with the on-device application.
-
+        loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided,
+            only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to
+            the loss function.
     Raises:
         RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
         RuntimeError: If the optimizer provided is not one of the supported optimizers.
@@ -111,11 +114,17 @@
         logging.info("Custom loss block provided: %s", loss.__class__.__name__)
 
     class _TrainingBlock(onnxblock.TrainingBlock):
-        def __init__(self, _loss):
+        def __init__(self, _loss, _loss_input_names=None):
             super().__init__()
             self._loss = _loss
+            self._loss_input_names = _loss_input_names
 
         def build(self, *inputs_to_loss):
+
+            # If loss_input_names is passed, only pass the specified input names to the loss function.
+            if self._loss_input_names:
+                inputs_to_loss = self._loss_input_names
+
             if additional_output_names:
                 # If additional output names is not a list, raise an error
                 if not isinstance(additional_output_names, list):
@@ -132,7 +141,7 @@ def build(self, *inputs_to_loss):
 
             return self._loss(*inputs_to_loss)
 
-    training_block = _TrainingBlock(loss_block)
+    training_block = _TrainingBlock(loss_block, loss_input_names)
 
     if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)):
         raise RuntimeError(
@@ -157,9 +166,11 @@ def build(self, *inputs_to_loss):
         logging.info("Custom op library provided: %s", custom_op_library)
         custom_op_library_path = pathlib.Path(custom_op_library)
 
-    with onnxblock.base(model), onnxblock.custom_op_library(
-        custom_op_library_path
-    ) if custom_op_library is not None else contextlib.nullcontext():
+    with onnxblock.base(model), (
+        onnxblock.custom_op_library(custom_op_library_path)
+        if custom_op_library is not None
+        else contextlib.nullcontext()
+    ):
         _ = training_block(*[output.name for output in model.graph.output])
         training_model, eval_model = training_block.to_model_proto()
         model_params = training_block.parameters()
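For context, a minimal sketch of how the new parameter might be called; the model file name, the output name "logits", and the parameter names below are illustrative assumptions, not taken from this commit:

import onnx
from onnxruntime.training import artifacts

# Assumes a model whose graph outputs include "logits"; file and parameter
# names here are hypothetical.
model = onnx.load("model.onnx")

artifacts.generate_artifacts(
    model,
    requires_grad=["fc.weight", "fc.bias"],  # hypothetical trainable parameters
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
    # New in this commit: forward only "logits" to the loss block instead of
    # every graph output.
    loss_input_names=["logits"],
)

Without `loss_input_names`, every output of the base model's graph is forwarded to the loss block's `build` method; the new parameter lets models with multiple outputs select which ones feed the loss.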
