Skip to content

Commit

Permalink
Offline tooling for training to use reduction with keepdims=False (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jan 11, 2024
1 parent 4694edc commit 58bf836
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
12 changes: 7 additions & 5 deletions orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def build(self, pow_input_name):
class _UnaryOp(Block):
"""Base class for all nodes that take in a single argument."""

def __init__(self, op_name):
def __init__(self, op_name, **attributes):
super().__init__()
self._op_name = op_name
self._attributes = attributes

def build(self, input_name):
# get the model to manipulate
Expand All @@ -165,6 +166,7 @@ def build(self, input_name):
node_input_names,
node_output_names,
_graph_utils.generate_graph_name(self._op_name),
**self._attributes,
)
onnx_model.graph.node.append(node)

Expand All @@ -174,15 +176,15 @@ def build(self, input_name):
class ReduceMean(_UnaryOp):
"""Adds ReduceMean node to the onnx model."""

def __init__(self):
super().__init__("ReduceMean")
def __init__(self, keepdims=True):
super().__init__("ReduceMean", keepdims=keepdims)


class ReduceSum(_UnaryOp):
"""Adds ReduceSum node to the onnx model."""

def __init__(self):
super().__init__("ReduceSum")
def __init__(self, keepdims=True):
super().__init__("ReduceSum", keepdims=keepdims)


class Sigmoid(_UnaryOp):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, reduction: str = "mean"):

reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}

self._reduce = reduction_blocks[reduction]()
self._reduce = reduction_blocks[reduction](keepdims=False)
self._sub = blocks.Sub()
self._square = blocks.Pow(2.0)

Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self, weight=None, reduction: str = "mean", pos_weight=None):
reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}

self._weight = weight
self._reduce = reduction_blocks[reduction]()
self._reduce = reduction_blocks[reduction](keepdims=False)
self._pos_weight = pos_weight

self._sigmoid = blocks.Sigmoid()
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, reduction: str = "mean"):
raise RuntimeError(f"Reduction {reduction} not supported.")

reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}
self._reduce = reduction_blocks[reduction]()
self._reduce = reduction_blocks[reduction](keepdims=False)
self._abs = blocks.Abs()
self._sub = blocks.Sub()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,33 @@ def test_save_ort_format():
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {eval_opsets['']}.")
if base_opsets[""] != optimizer_opsets[""]:
raise AssertionError(f"Opsets mismatch {base_opsets['']} != {optimizer_opsets['']}.")


def test_custom_loss_function():
    """Build training artifacts for a model whose loss is a user-defined onnxblock.Block.

    The custom loss consumes two model outputs of different ranks (4-D and 3-D),
    computes MSE and BCE-with-logits losses on them respectively, and sums the two.
    The test passes if artifact generation completes without raising.
    """

    class ModelWithTwoOutputs(torch.nn.Module):
        # Fixed random tensors of different ranks so the two graph outputs
        # exercise reductions over differently-shaped inputs.
        def __init__(self):
            super().__init__()
            self.a = torch.randn(20, 100, 35, 45)
            self.b = torch.randn(40, 100, 70)

        def forward(self, x, y):
            return self.a + x, self.b + y

    class CustomLossBlock(onnxblock.Block):
        def __init__(self):
            # Fix: initialize the Block base class, matching the convention used by
            # the other Block subclasses in this change (e.g. _UnaryOp.__init__).
            super().__init__()
            self._loss1 = onnxblock.loss.MSELoss()
            self._loss2 = onnxblock.loss.BCEWithLogitsLoss()
            self._add = onnxblock.blocks.Add()

        def build(self, input1, input2):
            # Sum of the two per-output losses becomes the single training loss.
            return self._add(self._loss1(input1, target_name="target1"), self._loss2(input2, target_name="target2"))

    model = ModelWithTwoOutputs()
    # Sample inputs must match the ranks the loss blocks expect.
    onnx_model = _get_onnx_model(model, (torch.randn(20, 100, 35, 45), torch.randn(40, 100, 70)))

    with tempfile.TemporaryDirectory() as temp_dir:
        artifacts.generate_artifacts(onnx_model, loss=CustomLossBlock(), artifact_directory=temp_dir)

0 comments on commit 58bf836

Please sign in to comment.