Skip to content

Commit

Permalink
try fixing Mac CI
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 17, 2024
1 parent e8ce6b9 commit a82b7a0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
}
}

DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool)
DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool)
: accuracy_level_{accuracy_level},
domain_{kMSDomain},
op_type_{"MatMulNBits"},
Expand All @@ -291,7 +291,7 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_
}

NodeAttributes
DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_state) const {
DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const {
NodeAttributes extra_attributes;

const auto* dq_node = runtime_state.selected_nodes.Input(0);
Expand All @@ -308,9 +308,9 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_stat
return extra_attributes;
}

Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
ORT_RETURN_IF_NOT(intra_op_thread_pool_, "Passed in thread pool should not be null");
const auto* dq_node = selected_nodes.Input(0);
const auto* weight_arg = dq_node->InputDefs()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ struct MatMulReplaceWithQLinear : public Action {
};

// used together with DQMatMulNodeGroupSelector, which does the sanity check
struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew {
DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool);
struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool);

private:
std::string OpType(const RuntimeState&) const override { return op_type_; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
// 2 nodes. DQ -> MatMul. DQ is the second input to MatMul.
// DQ's weight is int4/uint4. DQ's scale is float/float16.
// DQ is block-quantized along axis 0, with block_size >= 16 and block_size being a power of 2.
const std::string action_name{"DQMatMul"};
const std::string action_name{"DQMatMulToMatMulNBits"};

std::unique_ptr<Action> action =
std::make_unique<QDQ::DQMatMulReplaceWithMatMulNBits>(qdq_matmulnbits_accuracy_level,
intra_op_thread_pool);
std::make_unique<QDQ::DQMatMulToMatMulNBitsAction>(qdq_matmulnbits_accuracy_level,
intra_op_thread_pool);

#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DQMatMulToMatMulNBitsSelector>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ RunDQMatMulConverted(const std::vector<int64_t>& input1_shape,
}

TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) {
if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) {
GTEST_SKIP() << "Skipping the test";
}
// DQ contrib op schema is not updated to support blocked quantization
RunDQMatMulConverted<Int4x2, true>({12, 12}, {12, 37}, {37, 12}, 0, 16, 0);
RunDQMatMulConverted<Int4x2, false>({12, 12}, {12, 37}, {37, 12}, 0, 16, 0);
Expand Down

0 comments on commit a82b7a0

Please sign in to comment.