Skip to content

Commit

Permalink
add loss input names option
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamLouly committed Feb 27, 2024
1 parent 9384e08 commit da7be5c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
21 changes: 16 additions & 5 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def generate_artifacts(
custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
additional_output_names: Optional[List[str]] = None,
nominal_checkpoint: bool = False,
loss_input_names: Optional[List[str]] = None,
) -> None:
"""Generates artifacts required for training with ORT training api.
Expand Down Expand Up @@ -77,7 +78,9 @@ def generate_artifacts(
Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
parameters. It can be used on the device to reduce overhead while constructing the training model
as well as to reduce the size of the checkpoint packaged with the on-device application.
loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided,
only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to
the loss function.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
RuntimeError: If the optimizer provided is not one of the supported optimizers.
Expand Down Expand Up @@ -111,11 +114,17 @@ def generate_artifacts(
logging.info("Custom loss block provided: %s", loss.__class__.__name__)

class _TrainingBlock(onnxblock.TrainingBlock):
def __init__(self, _loss):
def __init__(self, _loss, _loss_input_names=None):
    """Wrap a user-provided loss block for training-graph construction.

    Args:
        _loss: The loss block (an ``onnxblock.Block``) whose subgraph is
            appended on top of the base model's outputs.
        _loss_input_names: Optional list of graph output names to feed to
            the loss block. When ``None``, all graph outputs are forwarded
            to the loss function (see ``build``).
    """
    super().__init__()
    # Keep both the block and the (possibly None) input-name filter for
    # later use when build() wires the loss subgraph.
    self._loss_input_names = _loss_input_names
    self._loss = _loss

def build(self, *inputs_to_loss):

# If loss_input_names is passed, only pass the specified input names to the loss function.
if self._loss_input_names:
inputs_to_loss = self._loss_input_names

if additional_output_names:
# If additional output names is not a list, raise an error
if not isinstance(additional_output_names, list):
Expand Down Expand Up @@ -157,9 +166,11 @@ def build(self, *inputs_to_loss):
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library_path = pathlib.Path(custom_op_library)

with onnxblock.base(model), onnxblock.custom_op_library(
custom_op_library_path
) if custom_op_library is not None else contextlib.nullcontext():
with onnxblock.base(model), (
onnxblock.custom_op_library(custom_op_library_path)
if custom_op_library is not None
else contextlib.nullcontext()
):
_ = training_block(*[output.name for output in model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __call__(self, *args, **kwargs):
# Check if the error is specifically due to exceeding the maximum protobuf size.
if "exceeds maximum protobuf size of 2GB" in str(e):
logging.info("Handling large model that exceeds the maximum protobuf size of 2GB.")
pass
else:
# If the error is for any other reason, re-raise it to not silently ignore important issues.
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, reduction: str = "mean"):
self._sub = blocks.Sub()
self._square = blocks.Pow(2.0)

def build(self, loss_input_name: str, target_name: str = "target", *args):
def build(self, loss_input_name: str, target_name: str = "target"):
"""Adds an MSELoss subgraph on top of the base_model.
Args:
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, weight=None, reduction: str = "mean", ignore_index: Optional[
self._reduction = reduction
self._ignore_index = ignore_index

def build(self, scores_input_name: str, labels_name: str = "labels", *args):
def build(self, scores_input_name: str, labels_name: str = "labels"):
"""Adds a CrossEntropyLoss subgraph on top of an onnx model.
Args:
Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(self, weight=None, reduction: str = "mean", pos_weight=None):
self._mul = blocks.Mul()
self._neg = blocks.Neg()

def build(self, loss_input_name: str, target_name: str = "target", *args):
def build(self, loss_input_name: str, target_name: str = "target"):
"""Adds a BCEWithLogitsLoss subgraph on top of an onnx model.
Creates a block that measures the binary cross entropy with logits between
Expand Down Expand Up @@ -229,7 +229,7 @@ def __init__(self, reduction: str = "mean"):
self._abs = blocks.Abs()
self._sub = blocks.Sub()

def build(self, loss_input_name: str, target_name: Optional[str] = "target", *args):
def build(self, loss_input_name: str, target_name: Optional[str] = "target"):
"""Adds an L1 loss subgraph on top of the base_model.
Args:
Expand Down

0 comments on commit da7be5c

Please sign in to comment.