Skip to content

Commit

Permalink
override default layer_norm export num outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Prathik Rao committed Feb 13, 2024
1 parent f74221a commit 4419d11
Showing 1 changed file with 34 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,37 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d")

@register_symbolic("layer_norm")
@parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
# normalized_shape: input shape from an expected input of size
# axis: The first normalization dimension.
# layer_norm normalizes on the last D dimensions,
# where D is the size of normalized_shape
axis = -len(normalized_shape)
scalar_type = _type_utils.JitScalarType.from_value(
input, _type_utils.JitScalarType.FLOAT
)
dtype = scalar_type.dtype()
if symbolic_helper._is_none(weight):
weight_value = torch.ones(normalized_shape, dtype=dtype)
weight = g.op("Constant", value_t=weight_value)
if symbolic_helper._is_none(bias):
bias_value = torch.zeros(normalized_shape, dtype=dtype)
bias = g.op("Constant", value_t=bias_value)

out = g.op(
"LayerNormalization",
input,
weight,
bias,
epsilon_f=eps,
axis_i=axis,
outputs=3, # force all 3 outputs to be exported in training mode
operator_s="layer_norm",
overload_name_s="vec",
)

res, new_running_mean, new_running_var = out
return res, new_running_mean, new_running_var

0 comments on commit 4419d11

Please sign in to comment.