From 4acab6c560ae52e7f57e256b103bc49b4229644e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 Aug 2024 23:36:32 +0000 Subject: [PATCH] fixed outdated assert --- src/ops/fused.cu | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 782bd9fbe6..509a423f52 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -447,16 +447,10 @@ __host__ void case OP_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); IncMultiHeadSelfAttentionMeta *m = (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - // TODO: why is op_num_weight still non-zero? - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } IncMultiHeadSelfAttention::inference_kernel_wrapper( m, bc, @@ -469,17 +463,12 @@ __host__ void case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); TreeIncMultiHeadSelfAttentionMeta *m = (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; TreeVerifyBatchConfig const &tree_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &tree_bc, @@ -491,19 +480,14 @@ __host__ void case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); SpecIncMultiHeadSelfAttentionMeta const *m = (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; // BeamSearchBatchConfig const *beam_bc = // (BeamSearchBatchConfig *)task->args; BeamSearchBatchConfig const &beam_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &beam_bc,