Skip to content

Commit

Permalink
Everything works
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 11, 2023
1 parent 8b492e8 commit 9e1ce2b
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 24 deletions.
6 changes: 0 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,6 @@ class AttentionCPUBase : public AttentionBase {
});
}

std::cout << "Probs before softmax.";
for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
std::cout << attention_probs[i] << " ";
}
std::cout << std::endl;

// attention_probs(B, N, S, T) = Softmax(attention_probs)
{
const int N = batch_size * num_heads_ * sequence_length;
Expand Down
18 changes: 0 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,30 +139,12 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
});
}

std::cout << "After bias add.";
std::cout << std::endl;
auto tensor = qkv_with_bias.GetMutable<Tensor>();
auto data = tensor->MutableData<float>();
for (size_t i = 0; i < batch_size * sequence_length * hidden_size; ++i) {
std::cout << data[i] << " ";
}
std::cout << std::endl;

// Reshape Q from BxSxD to BxSxNxH
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));

// Transpose Q from BxSxNxH to BxNxSxH
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));

std::cout << "After transpose.";
std::cout << std::endl;
tensor = qkv_with_bias_transposed.GetMutable<Tensor>();
data = tensor->MutableData<float>();
for (size_t i = 0; i < batch_size * sequence_length * hidden_size; ++i) {
std::cout << data[i] << " ";
}
std::cout << std::endl;

return Status::OK();
}

Expand Down

0 comments on commit 9e1ce2b

Please sign in to comment.