diff --git a/compiler/luci-interpreter/src/kernels/RmsNorm.cpp b/compiler/luci-interpreter/src/kernels/RmsNorm.cpp index de0d0d858e5..18c791b475e 100644 --- a/compiler/luci-interpreter/src/kernels/RmsNorm.cpp +++ b/compiler/luci-interpreter/src/kernels/RmsNorm.cpp @@ -26,9 +26,9 @@ namespace luci_interpreter namespace kernels { -RmsNorm::RmsNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, +RmsNorm::RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams ¶ms) - : KernelWithParams({input, gamma, beta}, {output}, params) + : KernelWithParams({input, gamma}, {output}, params) { } @@ -38,13 +38,9 @@ void RmsNorm::configure() LUCI_INTERPRETER_CHECK(num_dims == 3 || num_dims == 4); LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type()); LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type()); - LUCI_INTERPRETER_CHECK(beta()->element_type() == input()->element_type()); LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1); - LUCI_INTERPRETER_CHECK(beta()->shape().num_dims() == 1); LUCI_INTERPRETER_CHECK((gamma()->shape().dim(0) == input()->shape().dim(num_dims - 1)) || (gamma()->shape().dim(0) == 1)); - LUCI_INTERPRETER_CHECK((beta()->shape().dim(0) == input()->shape().dim(num_dims - 1)) || - (beta()->shape().dim(0) == 1)); output()->resize(input()->shape()); } @@ -70,9 +66,6 @@ void RmsNorm::evalFloat() const const float *gamma_data = getTensorData(gamma()); auto gamma_shape = getTensorShape(gamma()); bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1; - const float *beta_data = getTensorData(beta()); - auto beta_shape = getTensorShape(beta()); - bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1; float *output_data = getTensorData(output()); if (input_shape.DimensionsCount() == 4) @@ -99,11 +92,9 @@ void RmsNorm::evalFloat() const for (int32_t channel = 0; channel < channels; channel++) { double gamma = single_gamma ? gamma_data[0] : gamma_data[channel]; - double beta = single_beta ? beta_data[0] : beta_data[channel]; output_data[tflite::Offset(output_shape, batch, height, width, channel)] = - (gamma * - (input_data[tflite::Offset(input_shape, batch, height, width, channel)] / rms) + - beta); + gamma * + (input_data[tflite::Offset(input_shape, batch, height, width, channel)] / rms); } } } @@ -131,8 +122,7 @@ void RmsNorm::evalFloat() const for (int32_t i = 0; i < size; i++) { double gamma = single_gamma ? gamma_data[0] : gamma_data[i]; - double beta = single_beta ? beta_data[0] : beta_data[i]; - output_data[offset + i] = (gamma * (input_data[offset + i] / rms) + beta); + output_data[offset + i] = gamma * (input_data[offset + i] / rms); } } } diff --git a/compiler/luci-interpreter/src/kernels/RmsNorm.h b/compiler/luci-interpreter/src/kernels/RmsNorm.h index 66a58347a2a..c5c12cc0cee 100644 --- a/compiler/luci-interpreter/src/kernels/RmsNorm.h +++ b/compiler/luci-interpreter/src/kernels/RmsNorm.h @@ -28,12 +28,10 @@ namespace kernels class RmsNorm : public KernelWithParams { public: - RmsNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output, - const RmsNormParams ¶ms); + RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams ¶ms); const Tensor *input() const { return _inputs[0]; } const Tensor *gamma() const { return _inputs[1]; } - const Tensor *beta() const { return _inputs[2]; } Tensor *output() const { return _outputs[0]; } void configure() override; diff --git a/compiler/luci-interpreter/src/kernels/RmsNorm.test.cpp b/compiler/luci-interpreter/src/kernels/RmsNorm.test.cpp index cd39b593f95..8f8429f8ac4 100644 --- a/compiler/luci-interpreter/src/kernels/RmsNorm.test.cpp +++ b/compiler/luci-interpreter/src/kernels/RmsNorm.test.cpp @@ -39,13 +39,12 @@ TEST_F(RmsNormTest, Simple) Tensor input_tensor = makeInputTensor({1, 2, 2, 1}, {0, 1, 2, 3}, _memory_manager.get()); Tensor gamma_tensor = makeInputTensor({1}, {1}, _memory_manager.get()); - Tensor beta_tensor = makeInputTensor({1}, {0}, _memory_manager.get()); Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); RmsNormParams params{}; params.epsilon = 0.00001f; - RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params); + RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params); kernel.configure(); _memory_manager->allocate_memory(output_tensor); kernel.execute(); @@ -54,18 +53,17 @@ TEST_F(RmsNormTest, Simple) EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2, 1})); } -TEST_F(RmsNormTest, Default_gamma_beta) +TEST_F(RmsNormTest, Default_gamma) { Tensor input_tensor = makeInputTensor({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, _memory_manager.get()); Tensor gamma_tensor = makeInputTensor({1}, {1}, _memory_manager.get()); - Tensor beta_tensor = makeInputTensor({1}, {0}, _memory_manager.get()); Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); RmsNormParams params{}; params.epsilon = 0.001f; - RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params); + RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params); kernel.configure(); _memory_manager->allocate_memory(output_tensor); kernel.execute(); @@ -81,13 +79,12 @@ TEST_F(RmsNormTest, Have_gamma) Tensor input_tensor = makeInputTensor({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, _memory_manager.get()); Tensor gamma_tensor = makeInputTensor({1}, {2}, _memory_manager.get()); - Tensor beta_tensor = makeInputTensor({1}, {0}, _memory_manager.get()); Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); RmsNormParams params{}; params.epsilon = 0.001f; - RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params); + RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params); kernel.configure(); _memory_manager->allocate_memory(output_tensor); kernel.execute(); @@ -98,18 +95,17 @@ TEST_F(RmsNormTest, Have_gamma) EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2, 2})); } -TEST_F(RmsNormTest, Wrong_gamma_beta_dim_NEG) +TEST_F(RmsNormTest, Wrong_gamma_dim_NEG) { Tensor input_tensor = makeInputTensor({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, _memory_manager.get()); Tensor gamma_tensor = makeInputTensor({3}, {1, 1, 1}, _memory_manager.get()); - Tensor beta_tensor = makeInputTensor({3}, {0, 0, 0}, _memory_manager.get()); Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); RmsNormParams params{}; params.epsilon = 0.001f; - RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params); + RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params); EXPECT_ANY_THROW(kernel.configure()); } @@ -118,13 +114,12 @@ TEST_F(RmsNormTest, Unsupported_dims_NEG) Tensor input_tensor = makeInputTensor({2, 2}, {0, 1, 2, 3}, _memory_manager.get()); Tensor gamma_tensor = makeInputTensor({1}, {1}, _memory_manager.get()); - Tensor beta_tensor = makeInputTensor({1}, {0}, _memory_manager.get()); Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); RmsNormParams params{}; params.epsilon = 0.001f; - RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params); + RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params); EXPECT_ANY_THROW(kernel.configure()); } diff --git a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp index 0a3c30fff1f..f6b9d0e701f 100644 --- a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp +++ b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp @@ -1101,12 +1101,10 @@ TEST_F(KernelBuilderTest, RmsNorm) { auto *input = createInputNode(); auto *gamma = createInputNode(); - auto *beta = createInputNode(); auto *op = createNode(); op->input(input); op->gamma(gamma); - op->beta(beta); op->epsilon(1e-06); auto kernel = buildKernel(op); @@ -1114,7 +1112,6 @@ TEST_F(KernelBuilderTest, RmsNorm) checkTensor(kernel->input(), input); checkTensor(kernel->gamma(), gamma); - checkTensor(kernel->beta(), beta); checkTensor(kernel->output(), op); EXPECT_THAT(kernel->params().epsilon, Eq(op->epsilon())); } diff --git a/compiler/luci-interpreter/src/loader/nodes/RmsNorm.cpp b/compiler/luci-interpreter/src/loader/nodes/RmsNorm.cpp index 4ce54751757..d817a250cab 100644 --- a/compiler/luci-interpreter/src/loader/nodes/RmsNorm.cpp +++ b/compiler/luci-interpreter/src/loader/nodes/RmsNorm.cpp @@ -25,18 +25,17 @@ std::unique_ptr build_kernel_CircleRmsNorm(const luci::CircleNode *circl KernelBuilderHelper &helper) { const auto *node = loco::must_cast(circle_node); - assert(node->arity() == 3); + // assert(node->arity() == 2); const Tensor *input = helper.getInputTensor(node->input()); const Tensor *gamma = helper.getInputTensor(node->gamma()); - const Tensor *beta = helper.getInputTensor(node->beta()); Tensor *output = helper.getOutputTensor(node); RmsNormParams params{}; params.epsilon = node->epsilon(); - return std::make_unique(input, gamma, beta, output, params); + return std::make_unique(input, gamma, output, params); } } // namespace luci_interpreter