Skip to content

Commit

Permalink
Implemented for SpecInfer
Browse files Browse the repository at this point in the history
  • Loading branch information
yingchen21 committed Aug 4, 2024
1 parent 99f2879 commit c9d0fb1
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 386 deletions.
4 changes: 1 addition & 3 deletions include/flexflow/ops/spec_inc_multihead_self_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ class SpecIncMultiHeadSelfAttention : public Op {
BeamSearchBatchConfig const *bc,
int shard_id,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &bias);
GenericTensorAccessorW const &output);
Params get_params() const;

public:
Expand Down
5 changes: 1 addition & 4 deletions include/flexflow/ops/tree_inc_multihead_self_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ class TreeIncMultiHeadSelfAttention : public Op {
TreeVerifyBatchConfig const *bc,
int shard_id,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &bias);

GenericTensorAccessorW const &output);
Params get_params() const;

public:
Expand Down
11 changes: 3 additions & 8 deletions src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ __host__ void
assert(fused->op_num_outputs[op] == 1);
IncMultiHeadSelfAttentionMeta *m =
(IncMultiHeadSelfAttentionMeta *)metas->meta[op];
// TODO: why is op_num_weight still non-zero?
assert(fused->op_num_weights[op] ==
(1 + (int)(*m->qkv_bias || *m->final_bias)));
GenericTensorAccessorR biases;
Expand All @@ -461,9 +462,7 @@ __host__ void
bc,
task->index_point.point_data[0],
my_input_accessor[0],
// my_weight_accessor[0],
my_output_accessor[0]
// biases
);
break;
}
Expand All @@ -486,9 +485,7 @@ __host__ void
&tree_bc,
task->index_point.point_data[0],
my_input_accessor[0],
my_weight_accessor[0],
my_output_accessor[0],
biases);
my_output_accessor[0]);
break;
}
case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: {
Expand All @@ -512,9 +509,7 @@ __host__ void
&beam_bc,
task->index_point.point_data[0],
my_input_accessor[0],
my_weight_accessor[0],
my_output_accessor[0],
biases);
my_output_accessor[0]);
break;
}
case OP_LAYERNORM: {
Expand Down
225 changes: 84 additions & 141 deletions src/ops/spec_inc_multihead_self_attention.cc

Large diffs are not rendered by default.

37 changes: 14 additions & 23 deletions src/ops/spec_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -715,14 +715,14 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
stream);
// phase 1: Implement kernel to compute KQV for input tokens
// TODO WARNING: this is commented out only because we are fixing the inc_attn first
// compute_qkv_kernel(m,
// bc,
// shard_id,
// // input_ptr,
// weight_ptr,
// static_cast<DT *>(m->devQKVProjArray),
// bias_ptr,
// stream);
compute_qkv_kernel(m,
bc,
shard_id,
// input_ptr,
// weight_ptr,
static_cast<DT *>(m->devQKVProjArray),
// bias_ptr,
stream);
// phase 2: Update key/val cache
update_kv_cache_kernel<DT>(m, bc, stream);
if (bc->num_generation_tokens > 0) {
Expand Down Expand Up @@ -756,9 +756,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper(
BeamSearchBatchConfig const *bc,
int shard_id,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &bias) {
GenericTensorAccessorW const &output) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
bool use_bias = *m->qkv_bias || *m->final_bias;
Expand All @@ -770,35 +768,28 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper(
cudaEventRecord(t_start, stream);
}
assert(input.data_type == weight.data_type);
assert(input.data_type == output.data_type);
if (use_bias) {
assert(input.data_type == bias.data_type);
}
if (input.data_type == DT_HALF) {
half const *bias_ptr =
use_bias ? bias.get_half_ptr() : static_cast<half const *>(nullptr);
half const *bias_ptr = static_cast<half const *>(nullptr);
Kernels::SpecIncMultiHeadSelfAttention::inference_kernel(
m,
bc,
shard_id,
input.get_half_ptr(),
weight.get_half_ptr(),
static_cast<half const *>(nullptr),
output.get_half_ptr(),
bias_ptr,
static_cast<half const *>(nullptr),
stream);
} else if (input.data_type == DT_FLOAT) {
float const *bias_ptr =
use_bias ? bias.get_float_ptr() : static_cast<float const *>(nullptr);
Kernels::SpecIncMultiHeadSelfAttention::inference_kernel(
m,
bc,
shard_id,
input.get_float_ptr(),
weight.get_float_ptr(),
static_cast<float const *>(nullptr),
output.get_float_ptr(),
bias_ptr,
static_cast<float const *>(nullptr),
stream);
} else {
assert(false && "Unspported data type");
Expand Down
Loading

0 comments on commit c9d0fb1

Please sign in to comment.