Skip to content

Commit

Permalink
refactor ut
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 16, 2024
1 parent 9426e06 commit 64a7492
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
17 changes: 17 additions & 0 deletions onnxruntime/test/common/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/common/common.h"
#include "core/common/optional.h"
#include "core/common/type_utils.h"
#include "core/framework/int4.h"
#include "test/util/include/test_random_seed.h"

namespace onnxruntime {
Expand Down Expand Up @@ -108,6 +109,22 @@ class RandomValueGenerator {
return val;
}

template <typename TInt4>
typename std::enable_if<
std::is_same_v<TInt4, Int4x2> || std::is_same_v<TInt4, UInt4x2>,
std::vector<TInt4>>::type
Uniform(gsl::span<const int64_t> dims, TInt4 min, TInt4 max) {
using UnpackedType = typename TInt4::UnpackedType;
std::vector<UnpackedType> data_int8 = Uniform<UnpackedType>(dims, min.GetElem(0), max.GetElem(0));
std::vector<TInt4> data(TInt4::CalcNumInt4Pairs(data_int8.size()));
for (size_t i = 0; i < data_int8.size(); i++) {
size_t r = i >> 1;
size_t c = i & 0x1;
data[r].SetElem(c, data_int8[i]);
}
return data;
}

// Gaussian distribution for float
template <typename TFloat>
typename std::enable_if<
Expand Down
16 changes: 0 additions & 16 deletions onnxruntime/test/optimizer/graph_transform_test_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,6 @@ class ModelTestBuilder {
return MakeInput<bool>(shape, data);
}

template <typename TInt4>
typename std::enable_if<
std::is_same_v<TInt4, Int4x2> || std::is_same_v<TInt4, UInt4x2>,
NodeArg*>::type
MakeInputInt4(const std::vector<int64_t>& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) {
using UnpackedType = typename TInt4::UnpackedType;
std::vector<UnpackedType> data_int8 = rand_gen_.Uniform<UnpackedType>(shape, min, max);
std::vector<TInt4> data(TInt4::CalcNumInt4Pairs(data_int8.size()));
for (size_t i = 0; i < data_int8.size(); i++) {
size_t r = i >> 1;
size_t c = i & 0x1;
data[r].SetElem(c, data_int8[i]);
}
return MakeInput<TInt4>(shape, data);
}

template <typename T>
NodeArg* MakeInput(const std::optional<std::vector<int64_t>>& shape,
std::optional<std::string> input_name = std::nullopt) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/optimizer/qdq_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector<int64_t>& input_shape,
NodeArg* input_arg = nullptr;

if constexpr (std::is_same_v<InputType, Int4x2> || std::is_same_v<InputType, UInt4x2>) {
input_arg = builder.MakeInputInt4<InputType>(input_shape, InputType::min_val, InputType::max_val);
input_arg = builder.MakeInput(input_shape, InputType(InputType::min_val, 0), InputType(InputType::max_val, 0));
dq_zp = InputType(static_cast<std::byte>(InputType::max_val / 2));
q_zp = OutputType(static_cast<std::byte>(OutputType::max_val / 2));
} else {
Expand Down

0 comments on commit 64a7492

Please sign in to comment.