Skip to content

Commit

Permalink
remaining tests
Browse files Browse the repository at this point in the history
  • Loading branch information
richagadgil committed Nov 23, 2024
1 parent 6e87c99 commit 141bd27
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
56 changes: 56 additions & 0 deletions test/onnx/parse/gemm_bf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <migraphx/apply_alpha_beta.hpp>

TEST_CASE(gemm_bf16_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::bf16_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::bf16_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::bf16_type, {6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bf16_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bf16_type}}), l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);

auto prog = optimize_onnx("gemm_bf16_test.onnx");
EXPECT(p == prog);
}
34 changes: 34 additions & 0 deletions test/onnx/parse/group_norm_3d_bf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_3d_bf16_test)
{
migraphx::program p = make_group_norm(
{1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::bf16_type);
auto prog = optimize_onnx("group_norm_3d_bf16_test.onnx");
EXPECT(p == prog);
}
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,88 @@ TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_half_weighted_test)
optimize_onnx("negativeloglikelihoodloss_kd_mean_reduction_half_weighted_test.onnx");
EXPECT(p == prog);
}

TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_bf16_weighted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t batch_size = 4;
size_t class_size = 4;

auto scores = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::bf16_type, {batch_size, class_size, 2, 2}});
auto labels =
mm->add_parameter("1", migraphx::shape{migraphx::shape::int32_type, {class_size, 2, 2}});
auto weights =
mm->add_parameter("2", migraphx::shape{migraphx::shape::bf16_type, {class_size}});

auto weights_dflt = mm->add_literal(
migraphx::literal(migraphx::shape(migraphx::shape::bf16_type, {1}, {0}), {1}));
auto labels_idx = mm->add_literal(migraphx::literal(
migraphx::shape(migraphx::shape::int32_type, {class_size}, {1}), {0, 1, 2, 3}));

// Index variables used for gather on k dimensions that span their dimension
auto kd_1 = mm->add_literal(
migraphx::literal(migraphx::shape(migraphx::shape::int32_type, {2}, {1}), {0, 1}));
auto kd_2 = mm->add_literal(
migraphx::literal(migraphx::shape(migraphx::shape::int32_type, {2}, {1}), {0, 1}));

mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {class_size}}}),
weights_dflt);

weights = mm->add_instruction(migraphx::make_op("neg"), weights);
auto unsq_labels =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {-1}}}), labels);

auto unsq_labels_idx =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), labels_idx);
auto bc_unsq_labels_idx = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", unsq_labels->get_shape().lens()}}),
unsq_labels_idx);

auto unsq_labels_idx2 =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2, 3}}}), kd_1);
auto bc_unsq_labels_idx2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", unsq_labels->get_shape().lens()}}),
unsq_labels_idx2);

auto unsq_labels_idx3 =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1, 3}}}), kd_2);
auto bc_unsq_labels_idx3 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", unsq_labels->get_shape().lens()}}),
unsq_labels_idx3);

auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", -1}}),
bc_unsq_labels_idx,
bc_unsq_labels_idx2,
bc_unsq_labels_idx3,
unsq_labels);

auto transpose = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), scores);

auto gathernd = mm->add_instruction(migraphx::make_op("gathernd"), transpose, concat);
auto unsq_mb_weights =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2, 3}}}), weights);
auto unsq_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", scores->get_shape().lens()}}),
unsq_mb_weights);
auto transpose2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), unsq_mb);
auto gathernd2 = mm->add_instruction(migraphx::make_op("gathernd"), transpose2, concat);

auto weighted_loss = mm->add_instruction(migraphx::make_op("mul"), gathernd, gathernd2);

auto loss_x =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2}}}), weighted_loss);
auto loss_w =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2}}}), gathernd2);

loss_w = mm->add_instruction(migraphx::make_op("neg"), loss_w);

mm->add_instruction(migraphx::make_op("div"), loss_x, loss_w);

auto prog =
optimize_onnx("negativeloglikelihoodloss_kd_mean_reduction_bf16_weighted_test.onnx");
EXPECT(p == prog);
}

Check warning on line 350 in test/onnx/parse/negativeloglikelihoodloss_kd_all_reduction_weighted_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

no newline at end of file [clang-diagnostic-newline-eof,-warnings-as-errors]

0 comments on commit 141bd27

Please sign in to comment.