Commit

layernorm fix
prathikr committed Feb 13, 2024
1 parent 4419d11 commit 21054b2
Showing 1 changed file with 3 additions and 14 deletions.
@@ -817,18 +817,8 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
     # layer_norm normalizes on the last D dimensions,
     # where D is the size of normalized_shape
     axis = -len(normalized_shape)
-    scalar_type = _type_utils.JitScalarType.from_value(
-        input, _type_utils.JitScalarType.FLOAT
-    )
-    dtype = scalar_type.dtype()
-    if symbolic_helper._is_none(weight):
-        weight_value = torch.ones(normalized_shape, dtype=dtype)
-        weight = g.op("Constant", value_t=weight_value)
-    if symbolic_helper._is_none(bias):
-        bias_value = torch.zeros(normalized_shape, dtype=dtype)
-        bias = g.op("Constant", value_t=bias_value)
-
-    out = g.op(
+
+    res, new_running_mean, new_running_var = g.op(
         "LayerNormalization",
         input,
         weight,
@@ -840,5 +830,4 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
         overload_name_s="vec",
     )
 
-    res, new_running_mean, new_running_var = out
-    return res, new_running_mean, new_running_var
+    return res

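For reference, a minimal export sketch (not part of this commit) that lowers `aten::layer_norm` through a `layer_norm` symbolic like the one modified above. The module shape, output file name, and opset version are illustrative assumptions, not taken from the diff.

```python
# Minimal sketch (illustrative, not from this commit): export a LayerNorm module
# so the ONNX exporter routes through a layer_norm symbolic like the one above.
import torch

# layer_norm normalizes over the last len(normalized_shape) dimensions, so a
# normalized_shape of (768,) corresponds to axis = -len(normalized_shape) = -1.
model = torch.nn.LayerNorm(normalized_shape=768).eval()
example_input = torch.randn(2, 16, 768)

torch.onnx.export(
    model,
    example_input,
    "layer_norm_example.onnx",  # hypothetical output path
    input_names=["input"],
    output_names=["output"],
    opset_version=17,  # assumption: opset 17+ defines a native LayerNormalization op
)
```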