Skip to content

Commit

Permalink
#8111: Port ttnn.rms_norm to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
xanderchin committed May 7, 2024
1 parent cc1ab0d commit 7e60195
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_rms_norm(device, batch_size, h, w):
torch_output_tensor = ttnn.rms_norm.golden_function(torch_input_tensor, torch_weight)

input_tensor = ttnn.from_torch(torch_input_tensor, device=device, layout=ttnn.TILE_LAYOUT)
weight = ttnn.from_torch(torch_weight, device=device)
weight = ttnn.from_torch(torch_weight, device=device, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.rms_norm(input_tensor, weight)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
Expand Down
8 changes: 8 additions & 0 deletions ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ void py_module(py::module& module) {
Compute layer_norm over :attr:`input_tensor`.
)doc");

// Python binding for rms_norm(input_tensor, weight, *, epsilon=1e-6).
// NOTE(review): the bound default must match the C++ signature of rms_norm
// (epsilon = 1e-6) and the documented Python API; the previous value of
// 1e-12 silently gave Python callers a different epsilon than C++ callers.
module.def("rms_norm", &rms_norm,
    py::arg("input_tensor"),   // tensor to normalize
    py::arg("weight"),         // per-element scale applied after normalization
    py::kw_only(),
    py::arg("epsilon") = 1e-6, // numerical-stability term; keep in sync with rms_norm's default
    R"doc(
Compute rms_norm over :attr:`input_tensor`.
)doc");
}

} // namespace normalization
Expand Down
10 changes: 9 additions & 1 deletion ttnn/cpp/ttnn/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ inline ttnn::Tensor layer_norm(
std::optional<const ttnn::Tensor>& bias,
std::optional<const ttnn::Tensor>& residual_input_tensor,
const MemoryConfig& memory_config,
//std::optional<const LayerNormShardedMultiCoreProgramConfig>& program_config
std::optional<const LayerNormProgramConfig>& program_config
) {

Expand All @@ -32,6 +31,15 @@ inline ttnn::Tensor layer_norm(
}
}

// Apply RMS normalization to `input_tensor`, scaled element-wise by `weight`.
//
// Thin wrapper over tt::operations::primary::rmsnorm: no bias tensor is
// supplied (std::nullopt), and the output is always produced with an
// interleaved layout in DRAM.
//
// epsilon: numerical-stability term added inside the normalization.
inline ttnn::Tensor rms_norm(
    const ttnn::Tensor& input_tensor,
    const ttnn::Tensor& weight,
    float epsilon = 1e-6
) {
    // Fixed output placement: interleaved buffers in DRAM.
    const tt::tt_metal::MemoryConfig output_memory_config{
        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
        .buffer_type = tt::tt_metal::BufferType::DRAM};
    const std::optional<const ttnn::Tensor> gamma{weight};
    return tt::operations::primary::rmsnorm(
        input_tensor, epsilon, gamma, std::nullopt, output_memory_config);
}

} // namespace normalization
} // namespace operations
} // namespace ttnn
10 changes: 5 additions & 5 deletions ttnn/ttnn/operations/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def _golden_function(input_tensor: ttnn.Tensor, weight=None, *, epsilon=1e-12, *
return weight * input_tensor


@ttnn.register_operation(
name="ttnn.rms_norm",
validate_input_tensors=_rms_norm_validate_input_tensors,
golden_function=_golden_function,
# Expose the C++ implementation (ttnn._ttnn.operations.normalization.rms_norm)
# as ttnn.rms_norm, registering `_golden_function` as its torch reference
# implementation for validation against golden outputs.
rms_norm = ttnn.register_operation(name="ttnn.rms_norm", is_cpp_function=True, golden_function=_golden_function)(
    ttnn._ttnn.operations.normalization.rms_norm
)
def rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-6) -> ttnn.Tensor:


def _rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-6) -> ttnn.Tensor:
r"""
rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-6) -> ttnn.Tensor
Expand Down

0 comments on commit 7e60195

Please sign in to comment.