diff --git a/tests/ttnn/unit_tests/operations/test_rms_norm.py b/tests/ttnn/unit_tests/operations/test_rms_norm.py index 16d35f1ab7c..fe7327a7106 100644 --- a/tests/ttnn/unit_tests/operations/test_rms_norm.py +++ b/tests/ttnn/unit_tests/operations/test_rms_norm.py @@ -22,7 +22,7 @@ def test_rms_norm(device, batch_size, h, w): torch_output_tensor = ttnn.rms_norm.golden_function(torch_input_tensor, torch_weight) input_tensor = ttnn.from_torch(torch_input_tensor, device=device, layout=ttnn.TILE_LAYOUT) - weight = ttnn.from_torch(torch_weight, device=device) + weight = ttnn.from_torch(torch_weight, device=device, layout=ttnn.TILE_LAYOUT) output_tensor = ttnn.rms_norm(input_tensor, weight) output_tensor = ttnn.from_device(output_tensor) output_tensor = ttnn.to_torch(output_tensor) diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp index 6d1828229cb..4b774d1a36e 100644 --- a/ttnn/cpp/pybind11/operations/normalization.hpp +++ b/ttnn/cpp/pybind11/operations/normalization.hpp @@ -32,6 +32,14 @@ void py_module(py::module& module) { Compute layer_norm over :attr:`input_tensor`. )doc"); + module.def("rms_norm", &rms_norm, + py::arg("input_tensor"), + py::arg("weight"), + py::kw_only(), + py::arg("epsilon") = 1e-6, + R"doc( +Compute rms_norm over :attr:`input_tensor`.
+ )doc"); } } // namespace normalization diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp index 0bd652dbf13..7764d02d207 100644 --- a/ttnn/cpp/ttnn/operations/normalization.hpp +++ b/ttnn/cpp/ttnn/operations/normalization.hpp @@ -18,7 +18,6 @@ inline ttnn::Tensor layer_norm( std::optional& bias, std::optional& residual_input_tensor, const MemoryConfig& memory_config, - //std::optional& program_config std::optional& program_config ) { @@ -32,6 +31,15 @@ inline ttnn::Tensor layer_norm( } } +inline ttnn::Tensor rms_norm( + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight, + float epsilon = 1e-6 +) { + const MemoryConfig & dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM}; + return tt::operations::primary::rmsnorm(input_tensor, epsilon, std::optional(weight), std::nullopt, dram_memory_config); +} + } // namespace normalization } // namespace operations } // namespace ttnn diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py index 689137b9413..e649f9c3ad1 100644 --- a/ttnn/ttnn/operations/normalization.py +++ b/ttnn/ttnn/operations/normalization.py @@ -127,12 +127,12 @@ def _golden_function(input_tensor: ttnn.Tensor, weight=None, *, epsilon=1e-12, * return weight * input_tensor -@ttnn.register_operation( - name="ttnn.rms_norm", - validate_input_tensors=_rms_norm_validate_input_tensors, - golden_function=_golden_function, +rms_norm = ttnn.register_operation(name="ttnn.rms_norm", is_cpp_function=True, golden_function=_golden_function)( + ttnn._ttnn.operations.normalization.rms_norm ) -def rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-6) -> ttnn.Tensor: + + +def _rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-6) -> ttnn.Tensor: r""" rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: 
float = 1e-6) -> ttnn.Tensor