Add op (batch_norm) | feat(torchlib) #1761

Closed
16 changes: 0 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -1033,22 +1033,6 @@ def aten_bartlett_window(window_length: int) -> TensorType:
    raise NotImplementedError()


def aten_batch_norm(
    input: TensorType,
    weight: Optional[TensorType],
    bias: Optional[TensorType],
    running_mean: Optional[TensorType],
    running_var: Optional[TensorType],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
) -> TensorType:
    """batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"""

    raise NotImplementedError()


def aten_batch_norm_backward_elemt(
    grad_out: TensorType,
    input: TensorType,
55 changes: 55 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -269,6 +269,61 @@ def aten_avg_pool3d_backward(
    raise NotImplementedError()


@torch_op("aten::batch_norm", trace_only=True)
def aten_batch_norm(
    input: TFloat,
    weight: Optional[TFloat] = None,
    bias: Optional[TFloat] = None,
    running_mean: Optional[TFloat] = None,
    running_var: Optional[TFloat] = None,
    training: bool = False,
    momentum: float = 0.9,
    eps: float = 1e-05,
    cudnn_enabled: bool = False,
) -> TFloat:
    """batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"""

    # if weight is None:  # Set to 1.0 as default
    #     weight = op.Expand(
    #         op.CastLike(op.Constant(value_floats=[1.0]), input),
    #         op.Shape(input, start=1, end=2),
    #     )

    # if bias is None:  # Set to 0.0 as default
    #     bias = op.Expand(
    #         op.CastLike(op.Constant(value_floats=[0.0]), input),
    #         op.Shape(input, start=1, end=2),
    #     )

    # axes = list(range(len(input.shape)))
    # axes.pop(1)
    # axes = op.Constant(value_ints=axes)
    # if running_mean is None:  # Using input mean
    #     running_mean = op.Squeeze(op.ReduceMean(input, axes))

    # if running_var is None:  # Using input var
    #     mean = op.ReduceMean(input, axes)
    #     input_sub_mean = op.Sub(input, mean)
    #     sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
    #     running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))

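    # PyTorch and ONNX define momentum with opposite conventions:
    #   PyTorch: running = (1 - momentum) * running + momentum * batch_stat
    #   ONNX:    running = momentum * running + (1 - momentum) * batch_stat
    # so the PyTorch value is converted with 1 - momentum below.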
    out = op.BatchNormalization(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon=eps,
        momentum=1 - momentum,
        training_mode=training,
    )
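    # In training mode ONNX BatchNormalization produces three outputs
    # (Y, running_mean, running_var); in inference mode it produces one.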
    if training:
        res, _, _ = out
        return res
    return out


def aten_binary_cross_entropy(
    self: TensorType,
    target: TensorType,
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1054,6 +1054,7 @@ def _where_input_wrangler(
"new_zeros",
core_ops.aten_new_zeros,
),
TorchLibOpInfo("nn.functional.batch_norm", nn_ops.aten_batch_norm),
TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu),
TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted),
TorchLibOpInfo(
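For reference, the momentum conversion in aten_batch_norm can be sanity-checked outside the graph. Below is a minimal NumPy sketch (not part of this PR; the values are illustrative) showing that PyTorch's update rule with momentum m matches ONNX's with momentum 1 - m:

import numpy as np

running, batch_stat = 1.0, 5.0
torch_momentum = 0.1  # PyTorch's default for batch_norm

# PyTorch convention: new = (1 - m) * running + m * batch_stat
pytorch_update = (1 - torch_momentum) * running + torch_momentum * batch_stat

# ONNX convention: new = m * running + (1 - m) * batch_stat
onnx_momentum = 1 - torch_momentum
onnx_update = onnx_momentum * running + (1 - onnx_momentum) * batch_stat

assert np.isclose(pytorch_update, onnx_update)  # both equal 1.4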