Skip to content

Commit

Permalink
lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Jan 6, 2024
1 parent 79708a1 commit a41aa27
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, reduction: str = "mean"):

reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}

self._reduce = reduction_blocks[reduction](keepdims = False)
self._reduce = reduction_blocks[reduction](keepdims=False)
self._sub = blocks.Sub()
self._square = blocks.Pow(2.0)

Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self, weight=None, reduction: str = "mean", pos_weight=None):
reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}

self._weight = weight
self._reduce = reduction_blocks[reduction](keepdims = False)
self._reduce = reduction_blocks[reduction](keepdims=False)
self._pos_weight = pos_weight

self._sigmoid = blocks.Sigmoid()
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, reduction: str = "mean"):
raise RuntimeError(f"Reduction {reduction} not supported.")

reduction_blocks = {"mean": blocks.ReduceMean, "sum": blocks.ReduceSum, "none": blocks.PassThrough}
self._reduce = reduction_blocks[reduction](keepdims = False)
self._reduce = reduction_blocks[reduction](keepdims=False)
self._abs = blocks.Abs()
self._sub = blocks.Sub()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1035,11 +1035,8 @@ def __init__(self):
self._loss2 = onnxblock.loss.BCEWithLogitsLoss()
self._add = onnxblock.blocks.Add()


def build(self, input1, input2):
return self._add(
self._loss1(input1, target_name="target1"),
self._loss2(input2, target_name="target2"))
return self._add(self._loss1(input1, target_name="target1"), self._loss2(input2, target_name="target2"))

model = ModelWithTwoOutputs()
onnx_model = _get_onnx_model(model, (torch.randn(20, 100, 35, 45), torch.randn(40, 100, 70)))
Expand Down

0 comments on commit a41aa27

Please sign in to comment.