Skip to content

Commit

Permalink
fixed outdated assert
Browse files Browse the repository at this point in the history
  • Loading branch information
yingchen21 committed Aug 7, 2024
1 parent e590af8 commit 4acab6c
Showing 1 changed file with 3 additions and 19 deletions.
22 changes: 3 additions & 19 deletions src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -447,16 +447,10 @@ __host__ void
case OP_INC_MULTIHEAD_SELF_ATTENTION: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
IncMultiHeadSelfAttentionMeta *m =
(IncMultiHeadSelfAttentionMeta *)metas->meta[op];
// TODO: why is op_num_weight still non-zero?
assert(fused->op_num_weights[op] ==
(1 + (int)(*m->qkv_bias || *m->final_bias)));
GenericTensorAccessorR biases;
if (*m->qkv_bias || *m->final_bias) {
assert(fused->op_num_weights[op] == 2);
biases = my_weight_accessor[1];
}
IncMultiHeadSelfAttention::inference_kernel_wrapper(
m,
bc,
Expand All @@ -469,17 +463,12 @@ __host__ void
case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
TreeIncMultiHeadSelfAttentionMeta *m =
(TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op];
TreeVerifyBatchConfig const &tree_bc =
Future(task->futures[0]).get_result<TreeVerifyBatchConfig>();
assert(fused->op_num_weights[op] ==
(1 + (int)(*m->qkv_bias || *m->final_bias)));
GenericTensorAccessorR biases;
if (*m->qkv_bias || *m->final_bias) {
assert(fused->op_num_weights[op] == 2);
biases = my_weight_accessor[1];
}
TreeIncMultiHeadSelfAttention::inference_kernel_wrapper(
m,
&tree_bc,
Expand All @@ -491,19 +480,14 @@ __host__ void
case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
SpecIncMultiHeadSelfAttentionMeta const *m =
(SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op];
// BeamSearchBatchConfig const *beam_bc =
// (BeamSearchBatchConfig *)task->args;
BeamSearchBatchConfig const &beam_bc =
Future(task->futures[0]).get_result<BeamSearchBatchConfig>();
assert(fused->op_num_weights[op] ==
(1 + (int)(*m->qkv_bias || *m->final_bias)));
GenericTensorAccessorR biases;
if (*m->qkv_bias || *m->final_bias) {
assert(fused->op_num_weights[op] == 2);
biases = my_weight_accessor[1];
}
SpecIncMultiHeadSelfAttention::inference_kernel_wrapper(
m,
&beam_bc,
Expand Down

0 comments on commit 4acab6c

Please sign in to comment.