[luci-interpreter] Removed beta(bias) of RmsNorm (#14185)
This commit removes the beta (bias) input of RmsNorm from the kernel and loader.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
seockho-kim authored Oct 10, 2024
1 parent 2dc21f4 commit 8c1c9fa
Showing 5 changed files with 15 additions and 36 deletions.
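
For reference, RMS normalization without a bias term matches the commonly used formulation: each value is scaled by the root mean square over the channel axis and then by gamma. A sketch of the per-channel math (assuming epsilon sits under the square root, which the diff below does not show explicitly):

$$ y_c = \gamma_c \cdot \frac{x_c}{\sqrt{\tfrac{1}{C}\sum_{j=1}^{C} x_j^{2} + \epsilon}} $$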
20 changes: 5 additions & 15 deletions compiler/luci-interpreter/src/kernels/RmsNorm.cpp
@@ -26,9 +26,9 @@ namespace luci_interpreter
 namespace kernels
 {
 
-RmsNorm::RmsNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output,
+RmsNorm::RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output,
                  const RmsNormParams &params)
-  : KernelWithParams<RmsNormParams>({input, gamma, beta}, {output}, params)
+  : KernelWithParams<RmsNormParams>({input, gamma}, {output}, params)
 {
 }

@@ -38,13 +38,9 @@ void RmsNorm::configure()
   LUCI_INTERPRETER_CHECK(num_dims == 3 || num_dims == 4);
   LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
   LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type());
-  LUCI_INTERPRETER_CHECK(beta()->element_type() == input()->element_type());
   LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1);
-  LUCI_INTERPRETER_CHECK(beta()->shape().num_dims() == 1);
   LUCI_INTERPRETER_CHECK((gamma()->shape().dim(0) == input()->shape().dim(num_dims - 1)) ||
                          (gamma()->shape().dim(0) == 1));
-  LUCI_INTERPRETER_CHECK((beta()->shape().dim(0) == input()->shape().dim(num_dims - 1)) ||
-                         (beta()->shape().dim(0) == 1));
 
   output()->resize(input()->shape());
 }
@@ -70,9 +66,6 @@ void RmsNorm::evalFloat() const
   const float *gamma_data = getTensorData<float>(gamma());
   auto gamma_shape = getTensorShape(gamma());
   bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
-  const float *beta_data = getTensorData<float>(beta());
-  auto beta_shape = getTensorShape(beta());
-  bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1;
   float *output_data = getTensorData<float>(output());
 
   if (input_shape.DimensionsCount() == 4)
@@ -99,11 +92,9 @@ void RmsNorm::evalFloat() const
         for (int32_t channel = 0; channel < channels; channel++)
         {
           double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
-          double beta = single_beta ? beta_data[0] : beta_data[channel];
           output_data[tflite::Offset(output_shape, batch, height, width, channel)] =
-            (gamma *
-               (input_data[tflite::Offset(input_shape, batch, height, width, channel)] / rms) +
-             beta);
+            gamma *
+            (input_data[tflite::Offset(input_shape, batch, height, width, channel)] / rms);
         }
       }
     }
@@ -131,8 +122,7 @@ void RmsNorm::evalFloat() const
       for (int32_t i = 0; i < size; i++)
      {
         double gamma = single_gamma ? gamma_data[0] : gamma_data[i];
-        double beta = single_beta ? beta_data[0] : beta_data[i];
-        output_data[offset + i] = (gamma * (input_data[offset + i] / rms) + beta);
+        output_data[offset + i] = gamma * (input_data[offset + i] / rms);
      }
    }
  }
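To make the simplified kernel concrete, here is a minimal standalone sketch of the no-beta evalFloat path above, flattened to an [outer, channels] view. The helper name rms_norm and the placement of epsilon under the square root are illustrative assumptions, not luci-interpreter API:

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical free function for illustration (not luci-interpreter API):
// normalizes `input`, viewed as [outer, channels], the way the kernel's
// innermost loops do once beta is gone.
std::vector<float> rms_norm(const std::vector<float> &input, int32_t outer,
                            int32_t channels, const std::vector<float> &gamma,
                            float epsilon)
{
  std::vector<float> output(input.size());
  const bool single_gamma = gamma.size() == 1;
  for (int32_t o = 0; o < outer; ++o)
  {
    // Mean square over the channel axis, then the root.
    double mean_square = 0.0;
    for (int32_t c = 0; c < channels; ++c)
    {
      const double v = input[o * channels + c];
      mean_square += v * v;
    }
    const double rms = std::sqrt(mean_square / channels + epsilon);
    // Scale by gamma only; the trailing "+ beta" is what this commit removes.
    for (int32_t c = 0; c < channels; ++c)
    {
      const double g = single_gamma ? gamma[0] : gamma[c];
      output[o * channels + c] = static_cast<float>(g * (input[o * channels + c] / rms));
    }
  }
  return output;
}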
4 changes: 1 addition & 3 deletions compiler/luci-interpreter/src/kernels/RmsNorm.h
@@ -28,12 +28,10 @@ namespace kernels
 class RmsNorm : public KernelWithParams<RmsNormParams>
 {
 public:
-  RmsNorm(const Tensor *input, const Tensor *gamma, const Tensor *beta, Tensor *output,
-          const RmsNormParams &params);
+  RmsNorm(const Tensor *input, const Tensor *gamma, Tensor *output, const RmsNormParams &params);
 
   const Tensor *input() const { return _inputs[0]; }
   const Tensor *gamma() const { return _inputs[1]; }
-  const Tensor *beta() const { return _inputs[2]; }
   Tensor *output() const { return _outputs[0]; }
 
   void configure() override;
19 changes: 7 additions & 12 deletions compiler/luci-interpreter/src/kernels/RmsNorm.test.cpp
@@ -39,13 +39,12 @@ TEST_F(RmsNormTest, Simple)
   Tensor input_tensor =
     makeInputTensor<DataType::FLOAT32>({1, 2, 2, 1}, {0, 1, 2, 3}, _memory_manager.get());
   Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1}, _memory_manager.get());
-  Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({1}, {0}, _memory_manager.get());
   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
 
   RmsNormParams params{};
   params.epsilon = 0.00001f;
 
-  RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+  RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params);
   kernel.configure();
   _memory_manager->allocate_memory(output_tensor);
   kernel.execute();
@@ -54,18 +53,17 @@ TEST_F(RmsNormTest, Simple)
   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2, 1}));
 }
 
-TEST_F(RmsNormTest, Default_gamma_beta)
+TEST_F(RmsNormTest, Default_gamma)
 {
   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                                                            _memory_manager.get());
   Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1}, _memory_manager.get());
-  Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({1}, {0}, _memory_manager.get());
   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
 
   RmsNormParams params{};
   params.epsilon = 0.001f;
 
-  RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+  RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params);
   kernel.configure();
   _memory_manager->allocate_memory(output_tensor);
   kernel.execute();
@@ -81,13 +79,12 @@ TEST_F(RmsNormTest, Have_gamma)
   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                                                            _memory_manager.get());
   Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({1}, {2}, _memory_manager.get());
-  Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({1}, {0}, _memory_manager.get());
   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
 
   RmsNormParams params{};
   params.epsilon = 0.001f;
 
-  RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+  RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params);
   kernel.configure();
   _memory_manager->allocate_memory(output_tensor);
   kernel.execute();
@@ -98,18 +95,17 @@ TEST_F(RmsNormTest, Have_gamma)
   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2, 2}));
 }
 
-TEST_F(RmsNormTest, Wrong_gamma_beta_dim_NEG)
+TEST_F(RmsNormTest, Wrong_gamma_dim_NEG)
 {
   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                                                            _memory_manager.get());
   Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1, 1, 1}, _memory_manager.get());
-  Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({3}, {0, 0, 0}, _memory_manager.get());
   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
 
   RmsNormParams params{};
   params.epsilon = 0.001f;
 
-  RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+  RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params);
   EXPECT_ANY_THROW(kernel.configure());
 }

@@ -118,13 +114,12 @@ TEST_F(RmsNormTest, Unsupported_dims_NEG)
   Tensor input_tensor =
     makeInputTensor<DataType::FLOAT32>({2, 2}, {0, 1, 2, 3}, _memory_manager.get());
   Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1}, _memory_manager.get());
-  Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({1}, {0}, _memory_manager.get());
   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
 
   RmsNormParams params{};
   params.epsilon = 0.001f;
 
-  RmsNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+  RmsNorm kernel(&input_tensor, &gamma_tensor, &output_tensor, params);
   EXPECT_ANY_THROW(kernel.configure());
 }

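A quick hand-check of the Simple test above (a sketch, assuming the kernel computes rms = sqrt(mean(x^2) + epsilon) over the channel axis): with one channel per position, each value is divided by roughly its own magnitude, so the expected output is approximately {0, 1, 1, 1}.

#include <cmath>
#include <cstdio>

int main()
{
  const float eps = 0.00001f;         // params.epsilon from the Simple test
  const float input[] = {0, 1, 2, 3}; // shape {1, 2, 2, 1}: one channel per position
  for (float x : input)
  {
    // With a single channel, mean(x^2) == x^2, so y ~= x / |x| for x != 0.
    const float y = 1.0f * x / std::sqrt(x * x + eps); // gamma == 1
    std::printf("%g -> %g\n", x, y);                   // prints ~0, ~1, ~1, ~1
  }
  return 0;
}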
3 changes: 0 additions & 3 deletions compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
@@ -1101,20 +1101,17 @@ TEST_F(KernelBuilderTest, RmsNorm)
 {
   auto *input = createInputNode();
   auto *gamma = createInputNode();
-  auto *beta = createInputNode();
 
   auto *op = createNode<luci::CircleRmsNorm>();
   op->input(input);
   op->gamma(gamma);
-  op->beta(beta);
   op->epsilon(1e-06);
 
   auto kernel = buildKernel<kernels::RmsNorm>(op);
   ASSERT_THAT(kernel, NotNull());
 
   checkTensor(kernel->input(), input);
   checkTensor(kernel->gamma(), gamma);
-  checkTensor(kernel->beta(), beta);
   checkTensor(kernel->output(), op);
   EXPECT_THAT(kernel->params().epsilon, Eq(op->epsilon()));
 }
5 changes: 2 additions & 3 deletions compiler/luci-interpreter/src/loader/nodes/RmsNorm.cpp
@@ -25,18 +25,17 @@ std::unique_ptr<Kernel> build_kernel_CircleRmsNorm(const luci::CircleNode *circle_node,
                                                    KernelBuilderHelper &helper)
 {
   const auto *node = loco::must_cast<const luci::CircleRmsNorm *>(circle_node);
-  assert(node->arity() == 3);
+  // assert(node->arity() == 2);
 
   const Tensor *input = helper.getInputTensor(node->input());
   const Tensor *gamma = helper.getInputTensor(node->gamma());
-  const Tensor *beta = helper.getInputTensor(node->beta());
 
   Tensor *output = helper.getOutputTensor(node);
 
   RmsNormParams params{};
   params.epsilon = node->epsilon();
 
-  return std::make_unique<kernels::RmsNorm>(input, gamma, beta, output, params);
+  return std::make_unique<kernels::RmsNorm>(input, gamma, output, params);
 }
 
 } // namespace luci_interpreter
