diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 5cfcde00fdada..559f4b0b9a31b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -817,18 +817,8 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): # layer_norm normalizes on the last D dimensions, # where D is the size of normalized_shape axis = -len(normalized_shape) - scalar_type = _type_utils.JitScalarType.from_value( - input, _type_utils.JitScalarType.FLOAT - ) - dtype = scalar_type.dtype() - if symbolic_helper._is_none(weight): - weight_value = torch.ones(normalized_shape, dtype=dtype) - weight = g.op("Constant", value_t=weight_value) - if symbolic_helper._is_none(bias): - bias_value = torch.zeros(normalized_shape, dtype=dtype) - bias = g.op("Constant", value_t=bias_value) - - out = g.op( + + res, new_running_mean, new_running_var = g.op( "LayerNormalization", input, weight, @@ -840,5 +830,4 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): overload_name_s="vec", ) - res, new_running_mean, new_running_var = out - return res, new_running_mean, new_running_var + return res