
Binary KQ mask #28

Draft: wants to merge 13 commits into base: main
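
The core idea, as implemented in the ggml/src/ggml-cuda/softmax.cu changes further down, is to pack the KQ mask into one bit per position instead of one f16/f32 value per position: a set bit means the position is visible and contributes 0.0f to the logit, a cleared bit means it is masked out with -INFINITY. A minimal host-side sketch of that lookup (the helper name here is hypothetical; only the bit layout comes from the diff):

```cpp
#include <cstdint>
#include <limits>

// Hypothetical host-side mirror of the mask_value<uint32_t> specialization
// added in ggml/src/ggml-cuda/softmax.cu: bit (iy & 31) of word (iy >> 5)
// decides whether KQ position iy is visible.
static float binary_kq_mask_value(const uint32_t * mask, int64_t iy) {
    const bool visible = mask[iy >> 5] & (1u << (iy & 31));
    return visible ? 0.0f : -std::numeric_limits<float>::infinity();
}
```
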
7 changes: 7 additions & 0 deletions common/common.cpp
@@ -808,6 +808,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.flash_attn = true;
return true;
}
if (arg == "-bkq" || arg == "--binary-kq") {
params.binary_kq = true;
return true;
}
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1442,6 +1446,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-bkq, --binary-kq", "enable binary KQ mask (default: %s)", params.binary_kq ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2265,6 +2270,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.binary_kq = params.binary_kq;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3261,6 +3267,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "binary_kq: %s # default: false\n", params.binary_kq ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
1 change: 1 addition & 0 deletions common/common.h
@@ -173,6 +173,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool binary_kq = false; // use binary KQ mask (if allowed in the given context)

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
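
The common.cpp and common.h hunks above only plumb a new boolean from the argument parser into the context parameters, mirroring the existing flash_attn handling. A minimal sketch of the intended usage, assuming the plumbing shown above:

```cpp
#include "common.h"  // gpt_params, llama_context_params_from_gpt_params

int main() {
    gpt_params params;
    params.binary_kq = true;  // what -bkq / --binary-kq toggles

    // The flag is copied into the context params next to flash_attn,
    // as in the llama_context_params_from_gpt_params hunk above.
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    return cparams.binary_kq ? 0 : 1;
}
```
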
34 changes: 31 additions & 3 deletions examples/llama-bench/llama-bench.cpp
@@ -231,6 +231,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<bool> binary_kq;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* binary_kq */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
} else if (arg == "-bkq" || arg == "--binary-kq") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<bool>(argv[i], split_delim);
params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -614,6 +625,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool binary_kq;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -653,6 +665,7 @@ struct cmd_params_instance {
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.binary_kq = binary_kq;
cparams.embeddings = embeddings;

return cparams;
@@ -677,6 +690,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & bkq : params.binary_kq)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -697,6 +711,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -723,6 +738,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -749,6 +765,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -787,6 +804,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool binary_kq;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -813,6 +831,7 @@ struct test {
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
binary_kq = inst.binary_kq;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -884,7 +903,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn",
"main_gpu", "no_kv_offload", "flash_attn", "binary-kq",
"tensor_split", "use_mmap", "embeddings",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -906,7 +925,7 @@
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -940,7 +959,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1103,6 +1122,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return 2;
}
if (field == "binary-kq") {
return 3;
}
if (field == "use_mmap") {
return 4;
}
@@ -1134,6 +1156,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return "fa";
}
if (field == "binary-kq") {
return "bkq";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1183,6 +1208,9 @@ struct markdown_printer : public printer {
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) {
fields.emplace_back("binary-kq");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
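
In llama-bench the new option follows the flash_attn pattern: `-bkq` takes a comma-separated list of 0/1 values, each value becomes its own `cmd_params_instance` through the added `bkq` loop, and the result table gains a `bkq` column. A rough stand-in for what `string_split<bool>` does with such an argument (illustration only, assumed behaviour):

```cpp
#include <sstream>
#include <string>
#include <vector>

// Illustration only: approximates how "-bkq 0,1" expands into the benchmark
// grid; the real helper is string_split<bool> in llama-bench.cpp.
static std::vector<bool> split_bools(const std::string & s, char delim = ',') {
    std::vector<bool> out;
    std::string item;
    std::stringstream ss(s);
    while (std::getline(ss, item, delim)) {
        out.push_back(std::stoi(item) != 0);  // "0" -> false, "1" -> true
    }
    return out;  // "0,1" -> {false, true}: one run with and one without the binary mask
}
```
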
40 changes: 25 additions & 15 deletions ggml/src/ggml-cuda/softmax.cu
@@ -2,13 +2,18 @@
#include "softmax.cuh"

template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
return (float) val;
static __device__ __forceinline__ float mask_value(float slope, const T * mask, int iy) {
return mask ? slope * (float)mask[iy] : 0.0f;
}

template <>
__device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val);
__device__ __forceinline__ float mask_value(float slope, const half * mask, int iy) {
return mask ? slope * __half2float(mask[iy]) : 0.0f;
}

template <>
__device__ __forceinline__ float mask_value(float, const uint32_t * mask, int iy) {
return mask[iy >> 5] & (1u << (iy & 31)) ? 0.0f : -INFINITY;
}

template <bool vals_smem, int ncols_template, int block_size_template, typename T>
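
The I32 specialization above also pins down the memory layout: 32 KQ positions per uint32_t word, set bit = visible. Note that it ignores the slope argument, so a binary mask can only express visible/masked and not the position-dependent ALiBi bias that the f16/f32 mask paths carry; presumably this is the "if allowed in the given context" caveat from common.h. A hypothetical host-side packer (not part of this diff) producing such a mask:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical packer for the uint32_t mask consumed by the specialization
// above: position iy is visible when bit (iy & 31) of word (iy >> 5) is set;
// cleared bits end up as -INFINITY in the softmax input.
static std::vector<uint32_t> pack_binary_kq_mask(const std::vector<bool> & visible) {
    std::vector<uint32_t> packed((visible.size() + 31) / 32, 0u);
    for (size_t iy = 0; iy < visible.size(); ++iy) {
        if (visible[iy]) {
            packed[iy >> 5] |= 1u << (iy & 31);
        }
    }
    return packed;
}
```
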
@@ -44,8 +49,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;

const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + mask_value(slope, mask, iy) :
scale*x[ix] + mask_value(slope, mask, iy);

vals[col] = val;
max_val = max(max_val, val);
@@ -181,7 +186,7 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -194,14 +199,17 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32);

if (use_f16) {
if (use_i32) {
const uint32_t * mask = (const uint32_t *)src1_d;
soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
else if (use_f16) {
const half * src1_dd = (const half *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
}
@@ -219,7 +227,7 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -229,15 +237,17 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
memcpy(params, dst->op_params, sizeof(params));

const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
//printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32);

if (use_f16) {
if (use_i32) {
const uint32_t * mask = (const uint32_t *)src1_d;
soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
else if (use_f16) {
const half * src1_dd = (const half *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
} else {
const float * src1_dd = (const float *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}