
Commit

fix.
xinhaoc committed Nov 5, 2023
1 parent 87a294e commit a9e75e5
Showing 3 changed files with 200 additions and 144 deletions.
125 changes: 99 additions & 26 deletions include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh
@@ -13,6 +13,27 @@ struct half4 {
half w;
};

// Eight half values grouped into one struct for wider vectorized access.
struct half8 {
half x;
half y;
half z;
half w;
half a;
half b;
half c;
half d;
};

// fp32 counterpart of half8; used when accumulating in float for accuracy.
struct float8 {
float x;
float y;
float z;
float w;
float a;
float b;
float c;
float d;
};
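
For context, a minimal sketch (not part of this commit) of how an 8-wide half vector is typically used: grouping eight consecutive half values into one 16-byte struct lets a thread move them with a single struct assignment; whether the compiler turns that into a single 128-bit transaction depends on alignment guarantees outside this snippet. The helper name below is illustrative.

// Illustrative only: copy eight consecutive half values through the half8 type.
// Assumes src and dst are 16-byte aligned; the caller must guarantee this.
inline __device__ void copy_half8(half *dst, half const *src) {
  *reinterpret_cast<half8 *>(dst) = *reinterpret_cast<half8 const *>(src);
}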

////////////////data type///////////////
template <typename DT, int VECPSIZE>
struct VEC_K {};
@@ -69,6 +90,10 @@ template <>
struct Vec_fp32_<half4> {
using Type = float4;
};
template <>
struct Vec_fp32_<half8> {
using Type = float8;
};

template <typename DT>
struct VEC_V {};
@@ -78,7 +103,7 @@ struct VEC_V<float> {
};
template <>
struct VEC_V<half> {
using Type = half4;
using Type = half8;
};

////////////////data structures half///////////////
@@ -181,8 +206,17 @@ inline __device__ float4 fma(float a, float4 b, float4 c) {
return d;
}

inline __device__ float4 fma(half a, float4 b, float4 c) {
assert(false);
inline __device__ float8 fma(float a, float8 f1, float8 f2) {
float8 res;
res.x = fma(a, f1.x, f2.x);
res.y = fma(a, f1.y, f2.y);
res.z = fma(a, f1.z, f2.z);
res.w = fma(a, f1.w, f2.w);
res.a = fma(a, f1.a, f2.a);
res.b = fma(a, f1.b, f2.b);
res.c = fma(a, f1.c, f2.c);
res.d = fma(a, f1.d, f2.d);
return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -211,6 +245,19 @@ inline __device__ float4 add(float4 a, float4 b) {
return c;
}

inline __device__ float8 add(float8 f1, float8 f2) {
float8 res;
res.x = add(f1.x, f2.x);
res.y = add(f1.y, f2.y);
res.z = add(f1.z, f2.z);
res.w = add(f1.w, f2.w);
res.a = add(f1.a, f2.a);
res.b = add(f1.b, f2.b);
res.c = add(f1.c, f2.c);
res.d = add(f1.d, f2.d);
return res;
}

inline __device__ float sum(float v) {
return v;
}
@@ -271,6 +318,18 @@ inline __device__ float4 cast_to_float(half4 u) {
tmp.w = __half2float(u.w);
return tmp;
}
inline __device__ float8 cast_to_float(half8 u) {
float8 tmp;
tmp.x = __half2float(u.x);
tmp.y = __half2float(u.y);
tmp.z = __half2float(u.z);
tmp.w = __half2float(u.w);
tmp.a = __half2float(u.a);
tmp.b = __half2float(u.b);
tmp.c = __half2float(u.c);
tmp.d = __half2float(u.d);
return tmp;
}

inline __device__ void convert_from_float(float4 &dst, float4 src) {
dst = src;
@@ -281,13 +340,27 @@ inline __device__ void convert_from_float(float &dst, float src) {
inline __device__ void convert_from_float(float2 &dst, float2 src) {
dst = src;
}
inline __device__ void convert_from_float(float8 &dst, float8 src) {
dst = src;
}

inline __device__ void convert_from_float(half4 &dst, float4 src) {
dst.x = __float2half(src.x);
dst.y = __float2half(src.y);
dst.z = __float2half(src.z);
dst.w = __float2half(src.w);
}

inline __device__ void convert_from_float(half8 &dst, float8 src) {
dst.x = __float2half(src.x);
dst.y = __float2half(src.y);
dst.z = __float2half(src.z);
dst.w = __float2half(src.w);
dst.a = __float2half(src.a);
dst.b = __float2half(src.b);
dst.c = __float2half(src.c);
dst.d = __float2half(src.d);
}
inline __device__ void convert_from_float(half2 &dst, float2 src) {
dst.x = __float2half(src.x);
dst.y = __float2half(src.y);
@@ -317,7 +390,8 @@ inline __device__ float qk_dot_(K_vec const (&q)[N], K_vec const (&k)[N]) {
// use float32 to get better accuracy
using Vec_sum = typename Vec_fp32_<K_vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
Vec_sum qk_vec = mul<Vec_sum, Vec_sum, Vec_sum>(cast_to_float(q[0]), cast_to_float(k[0]));
Vec_sum qk_vec =
mul<Vec_sum, Vec_sum, Vec_sum>(cast_to_float(q[0]), cast_to_float(k[0]));
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = FlexFlow::fma(cast_to_float(q[ii]), cast_to_float(k[ii]), qk_vec);
@@ -375,7 +449,6 @@ inline __device__ float block_sum(float *red_smem, float sum) {
return __shfl_sync(uint32_t(-1), sum, 0);
}

// utils
template <typename DT>
inline size_t smem_size_in_bytes(int hidden_size_per_head,
int max_sequence_length,
@@ -384,69 +457,69 @@ inline size_t smem_size_in_bytes(int hidden_size_per_head,
// The amount of shared memory needed to store the Q*K^T values in float.

size_t qk_sz = div_up(max_sequence_length + 1, 4) * 16;
// The extra memory needed if we are not using floats for the final logits.

// store the extra memory if using half precision
size_t logits_sz = qk_sz;
// if (sizeof(DT) != 4) {
// logits_sz = div_up(max_sequence_length + 1, 4) * 4 * sizeof(DT);
// }

// The total size needed during softmax.
size_t softmax_sz = qk_sz + logits_sz;
size_t q_size = hidden_size_per_head * sizeof(DT);

// The number of partial rows to reduce in the final reduction.
int rows_per_red = threads_per_block / threads_per_value;
// The amount of storage needed to finalize the outputs.
size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(DT) / 2;

size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2;
// The max.
return max(softmax_sz, red_sz);
return max(softmax_sz, red_sz) + q_size;
}
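
As a worked example of the arithmetic above (the numbers are illustrative, not taken from the commit): with max_sequence_length = 1023, qk_sz = div_up(1024, 4) * 16 = 4096 bytes and logits_sz = 4096, so softmax_sz = 8192; with hidden_size_per_head = 128 and DT = half, q_size = 256 bytes; with threads_per_block = 128 and threads_per_value = 16, rows_per_red = 8 and red_sz = 8 * 128 * 4 / 2 = 2048 bytes. The function then returns max(8192, 2048) + 256 = 8448 bytes of dynamic shared memory.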

template <typename DT>
inline void smem_size_in_bytes_tree(int hidden_size_per_head,
int max_sequence_length,
int threads_per_value,
int threads_per_block,
TreeVerifyBatchConfig const *bc,
int shared_mem[]) {

int max_query_length = 0;
int max_total_length = 0;
// int max_total_length = 0;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
}
max_query_length =
max(max_query_length, bc->requestsInfo[i].num_tokens_in_batch);
max_total_length = max(max_total_length,
bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch);
// max_total_length = max(max_total_length,
// bc->requestsInfo[i].first_token_depth_in_request +
// bc->requestsInfo[i].num_tokens_in_batch);
}

int max_qk_length = max_query_length * max_total_length;
// TODO: fix this hard-coded upper bound
int max_qk_length = 1200;

// The amount of shared memory needed to store the Q*K^T values in float.
size_t qk_sz = div_up(max_qk_length + 1, 4) * 16;
// The extra memory needed if we are not using floats for the final logits.

// store the extra memory if using half precision
size_t logits_sz = qk_sz;
// if (sizeof(DT) != 4) {
// logits_sz = div_up(max_qk_length + 1, 4) * 4 * sizeof(DT);
// }

// The total size needed during softmax.
size_t softmax_sz = qk_sz + logits_sz;

size_t q_size = hidden_size_per_head * sizeof(DT);

// The number of partial rows to reduce in the final reduction.
int rows_per_red = threads_per_block / threads_per_value;
// The amount of storage needed to finalize the outputs.
size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(DT) / 2;
// use float (4 bytes) per element for the reduction storage
size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2;

// The max.
shared_mem[0] = qk_sz;
shared_mem[1] = max(softmax_sz, red_sz);
shared_mem[1] = max(softmax_sz, red_sz) + q_size;
}
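
A hedged sketch of how a two-element requirement like this is typically consumed on the host side; the helper name, its parameters, and the 48 KB opt-in threshold below are illustrative assumptions, not part of this commit.

// Illustrative only: query the shared-memory requirement and raise the device
// limit before launching whichever tree-verify attention kernel uses it.
template <typename DT>
int configure_tree_attention_smem(void const *kernel_func,
                                  int hidden_size_per_head,
                                  int max_sequence_length,
                                  int threads_per_value,
                                  int threads_per_block,
                                  TreeVerifyBatchConfig const *bc) {
  int smem[2];
  smem_size_in_bytes_tree<DT>(hidden_size_per_head, max_sequence_length,
                              threads_per_value, threads_per_block, bc, smem);
  // smem[0] is the Q*K^T scratch size; smem[1] is the total per-block request.
  if (smem[1] > 48 * 1024) {
    // Dynamic shared memory above 48 KB requires an explicit opt-in.
    cudaFuncSetAttribute(kernel_func,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem[1]);
  }
  return smem[1]; // pass as the dynamic shared-memory argument at launch
}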

template <typename T, int Dh>
struct threads_per_value_t {
static int const value = Dh * sizeof(T) / 16;
};
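
A small worked example of the constant above (values chosen for illustration): with Dh = 128 and T = half (2 bytes per element), one value row is 256 bytes, so 256 / 16 = 16 threads cooperate on each value, with every thread covering 16 bytes.

// Illustrative only: 128-element half rows give 16 threads per value.
static_assert(threads_per_value_t<half, 128>::value == 16,
              "each thread should cover 16 bytes of one value row");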

} // namespace FlexFlow
#endif // _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_UTILS_H
