diff --git a/common/common.cpp b/common/common.cpp index 3b45d066..85baa5e2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -808,6 +808,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-bkq" || arg == "--binary-kq") { + params.binary_kq = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1442,6 +1446,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-bkq, --binary-kq", "enable binary KQ mask (default: %s)", params.binary_kq ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2265,6 +2270,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3261,6 +3267,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "binary_kq: %s # default: false\n", params.binary_kq ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 50035897..28b56471 100644 --- a/common/common.h +++ b/common/common.h @@ -173,6 +173,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool binary_kq = false; // use binary KQ mask (if allowed in the given context) bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 813d7bae..0736e393 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -231,6 +231,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; + std::vector binary_kq; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* binary_kq */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-bkq" || arg == "--binary-kq") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -614,6 +625,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; 
bool use_mmap; bool embeddings; @@ -653,6 +665,7 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.binary_kq = binary_kq; cparams.embeddings = embeddings; return cparams; @@ -677,6 +690,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & bkq : params.binary_kq) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -697,6 +711,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -723,6 +738,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -749,6 +765,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -787,6 +804,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -813,6 +831,7 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + binary_kq = inst.binary_kq; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -884,7 +903,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", + "main_gpu", "no_kv_offload", "flash_attn", "binary-kq", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -906,7 +925,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -940,7 +959,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1103,6 +1122,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "binary-kq") { + return 3; + } if (field == "use_mmap") { return 4; } @@ -1134,6 +1156,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "binary-kq") { + return "bkq"; + } if (field == 
"use_mmap") { return "mmap"; } @@ -1183,6 +1208,9 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { fields.emplace_back("flash_attn"); } + if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) { + fields.emplace_back("binary-kq"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 6f3056e6..e4a31fa2 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -2,13 +2,18 @@ #include "softmax.cuh" template -static __device__ __forceinline__ float t2f32(T val) { - return (float) val; +static __device__ __forceinline__ float mask_value(float slope, const T * mask, int iy) { + return mask ? slope * (float)mask[iy] : 0.0f; } template <> -__device__ float __forceinline__ t2f32(half val) { - return __half2float(val); +__device__ __forceinline__ float mask_value(float slope, const half * mask, int iy) { + return mask ? slope * __half2float(mask[iy]) : 0.0f; +} + +template <> +__device__ __forceinline__ float mask_value(float, const uint32_t * mask, int iy) { + return mask[iy >> 5] & (1u << (iy & 31)) ? 0.0f : -INFINITY; } template @@ -44,8 +49,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) : - scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + mask_value(slope, mask, iy) : + scale*x[ix] + mask_value(slope, mask, iy); vals[col] = val; max_val = max(max_val, val); @@ -181,7 +186,7 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -194,14 +199,17 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); - if (use_f16) { + if (use_i32) { + const uint32_t * mask = (const uint32_t *)src1_d; + soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + } + else if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } } @@ -219,7 +227,7 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 
contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -229,15 +237,17 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds memcpy(params, dst->op_params, sizeof(params)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); - if (use_f16) { + if (use_i32) { + const uint32_t * mask = (const uint32_t *)src1_d; + soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + } + else if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } } diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 83bd76f9..51b223c7 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -67,10 +67,14 @@ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -576,10 +580,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, soft_max_u32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, soft_max_u32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, soft_cap_max_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, soft_cap_max_u32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, soft_cap_max_u32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -1629,19 +1637,22 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); int nth = 32; // SIMD width id pipeline = nil; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); if (ne00%4 == 0) { while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; @@ -1650,7 +1661,9 @@ static enum ggml_status ggml_metal_graph_compute( while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; @@ -1694,19 +1707,22 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_CAP_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); int nth = 32; // SIMD width id pipeline = nil; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); if (ne00%4 == 0) { while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; @@ -1715,7 +1731,9 @@ static enum ggml_status ggml_metal_graph_compute( while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f9c88a37..8bd4e5c2 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -453,6 +453,198 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +kernel void kernel_soft_max_u32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - 
i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] = pmask[i00 >> 5] & (1u << (i00 & 31)) ? psrc0[i00]*scale : -INFINITY; + lmax = MAX(lmax, pdst[i00]); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp(pdst[i00] - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +kernel void kernel_soft_max_u32_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + int idx = 4*i00; + uint8_t m4 = pmask[idx >> 5] >> (idx & 31); + float4 val = psrc4[i00]*scale; + val[0] = m4 & 1 ? val[0] : -INFINITY; + val[1] = m4 & 2 ? val[1] : -INFINITY; + val[2] = m4 & 4 ? val[2] : -INFINITY; + val[3] = m4 & 8 ? 
val[3] : -INFINITY; + lmax4 = fmax(lmax4, val); + pdst4[i00] = val; + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp(pdst4[i00] - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + template kernel void kernel_soft_max( device const char * src0, @@ -661,6 +853,101 @@ kernel void kernel_soft_max_4( } } +kernel void kernel_soft_cap_max_u32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, + constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float * ) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float * pdst = (device float * ) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float lmax = -INFINITY; + + const float tot_scale = scale * s_after; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + float val = pmask[i00 >> 5] & (1u << (i00 & 31)) ? 
precise::tanh(s_before*psrc0[i00])*tot_scale : -INFINITY; + lmax = MAX(lmax, val); + pdst[i00] = val; + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp(pdst[i00] - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max( device const char * src0, @@ -767,6 +1054,116 @@ kernel void kernel_soft_cap_max( } } +kernel void kernel_soft_cap_max_u32_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, + constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + const float tot_scale = scale * s_after; + + // parallel max + float4 lmax4 = -INFINITY; + float4 vinf = lmax4; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + float4 val = precise::tanh(s_before*psrc4[i00])*tot_scale; + int idx = 4*i00; + uint8_t m = pmask[idx >> 5] >> (idx & 31); + bool4 m4 = { m & 1 ? true : false, m & 2 ? true : false, m & 4 ? true : false, m & 8 ? true : false }; + //bool4 m4 = ((pmask[idx >> 5] >> (idx & 31)) & 0xf) * 0x01010101; + val = select(vinf, val, m4); + //uint32_t m = pmask[idx >> 5] >> (idx & 31); + //val[0] = m & 1 ? val[0] : -INFINITY; + //val[1] = m & 2 ? val[1] : -INFINITY; + //val[2] = m & 4 ? val[2] : -INFINITY; + //val[3] = m & 8 ? 
val[3] : -INFINITY; + lmax4 = fmax(lmax4, val); + pdst4[i00] = val; + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp(pdst4[i00] - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max_4( device const char * src0, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index cebac584..939f5215 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2043,6 +2043,162 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +#ifdef __AVX512F__ +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + __m512 vslope = _mm512_set1_ps(slope); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_fmadd_ps(vslope, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x + j)), _mm512_loadu_ps(y + 16*j)); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + float max = _mm512_reduce_max_ps(vmax); + for (int i = 16*(n/16); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + __m512 vslope = _mm512_set1_ps(slope); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_fmadd_ps(vslope, _mm512_loadu_ps(x + 16*j), _mm512_loadu_ps(y + 16*j)); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + float max = _mm512_reduce_max_ps(vmax); + for (int i = 16*(n/16); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +#elif defined __AVX2__ +static inline float hmax_f32x8(__m256 v) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + return _mm_cvtss_f32( max4 ); +} +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + __m256 vmax = _mm256_set1_ps(-INFINITY); + if (fabsf(slope - 1.0f) < 1e-5f) { + for (int j = 0; j < 
n/8; ++j) { + __m256 vmask = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j)); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; + } + __m256 vslope = _mm256_set1_ps(slope); + for (int j = 0; j < n/8; ++j) { +#ifdef __FMA__ + __m256 v = _mm256_fmadd_ps(vslope, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j)), _mm256_loadu_ps(y + 8*j)); +#else + __m256 vmask = _mm256_mul_ps(vslope, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j))); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); +#endif + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + __m256 vmax = _mm256_set1_ps(-INFINITY); + if (fabsf(slope - 1.0f) < 1e-5f) { + for (int j = 0; j < n/8; ++j) { + __m256 vmask = _mm256_loadu_ps(x + 8*j); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; + } + __m256 vslope = _mm256_set1_ps(slope); + for (int j = 0; j < n/8; ++j) { +#ifdef __FMA__ + __m256 v = _mm256_fmadd_ps(vslope, _mm256_loadu_ps(x + 8*j), _mm256_loadu_ps(y + 8*j)); +#else + __m256 vmask = _mm256_mul_ps(vslope, _mm256_loadu_ps(x + 8*j)); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); +#endif + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +#elif defined __ARM_NEON +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + float32x4_t vslope = vdupq_n_f32(slope); + float32x4_t vmax = vdupq_n_f32(-INFINITY); + for (int j = 0; j < n/4; ++j) { + float32x4_t val = vmlaq_f32(vld1q_f32(y + 4*j), vslope, vcvt_f32_f16(vld1_f16((const float16_t *)x + 4*j))); + vmax = vmaxq_f32(vmax, val); + vst1q_f32(y + 4*j, val); + } + float max = vmaxvq_f32(vmax); + for (int i = 4*(n/4); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + float32x4_t vslope = vdupq_n_f32(slope); + float32x4_t vmax = vdupq_n_f32(-INFINITY); + for (int j = 0; j < n/4; ++j) { + float32x4_t val = vmlaq_f32(vld1q_f32(y + 4*j), vslope, vld1q_f32(x + 4*j)); + vmax = vmaxq_f32(vmax, val); + vst1q_f32(y + 4*j, val); + } + float max = vmaxvq_f32(vmax); + for (int i = 4*(n/4); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +#else +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + y[i] += slope * GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + y[i] += slope * x[i]; + max = MAX(max, y[i]); + 
} + return max; +} +#endif + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -2614,6 +2770,13 @@ inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) { return _mm512_mul_ps(th, s_after); } +inline static __m512 ggml_v_softcap_mask(__m512 x, __m512 s_before, __m512 s_after, __m512 src, __mmask16 mask) { + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before)); + const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + return _mm512_mask_mul_ps(src, mask, th, s_after); +} + inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) { const __m512 one = _mm512_set1_ps(1.0f); __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); @@ -2851,6 +3014,108 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } +#ifdef __AVX512F__ +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vs_before = _mm512_set1_ps(2.f*s_before); + __m512 vs_after = _mm512_set1_ps(s_after); + for (int i = 0; i < n/16; ++i) { + __m512 v = ggml_v_softcap_mask(_mm512_loadu_ps(x + 16*i), vs_before, vs_after, vinf, m16[i]); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vscale = _mm512_set1_ps(scale); + for (int i = 0; i < n/16; ++i) { + __m512 v = _mm512_mask_mul_ps(vinf, m16[i], vscale, _mm512_loadu_ps(x + 16*i)); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} +//#elif __ARM_NEON +//static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { +// //const uint16_t * mask16 = (const uint16_t *)mask; +// const uint8_t * mask8 = (const uint8_t *)mask; +// float32x4_t vinf = vdupq_n_f32(-INFINITY); +// float32x4x4_t vmax = { vinf, vinf, vinf, vinf }; +// float32x4_t vs_before = vdupq_n_f32(s_before); +// float32x4_t vs_after = vdupq_n_f32(s_after ); +// const uint8x16_t vmask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); +// //const uint8x8_t vmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201)); +// //static const uint32_t k_shuffle[8] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// // 0x04040404, 0x05050505, 0x06060606, 0x07070707 }; +// //const uint8x8x4_t vtab = vld1_u8_x4((const uint8_t *)k_shuffle); +// //for (int i = 0; i < n/16; ++i) { +// // float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// // uint8x8_t m1 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+0]), vmask), vmask); +// // uint8x8_t m2 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+1]), vmask), vmask); +// // uint8x16x4_t mk = { vcombine_u8(vqtbl1_u8(m1, vtab.val[0]), vqtbl1_u8(m1, vtab.val[1])), +// // for (int k = 0; k < 4; ++k) { +// // vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// // //uint8x16_t mk = vcombine(vqtbl1_u8(m1, vtab.val[k]), +// // uint8x16_t v_on = 
vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// // uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// // vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// // vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// // vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// // } +// //} +// static const uint32_t k_shuffle[16] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// 0x04040404, 0x05050505, 0x06060606, 0x07070707, +// 0x08080808, 0x09090909, 0x0a0a0a0a, 0x0b0b0b0b, +// 0x0c0c0c0c, 0x0d0d0d0d, 0x0e0e0e0e, 0x0f0f0f0f}; +// const uint8x16x4_t vtab = vld1q_u8_x4((const uint8_t *)k_shuffle); +// for (int i = 0; i < n/16; ++i) { +// float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// uint8x16_t m = vcombine_u8(vdup_n_u8(mask8[2*i+0]), vdup_n_u8(mask8[2*i+1])); +// m = vceqq_u8(vandq_u8(m, vmask), vmask); +// for (int k = 0; k < 4; ++k) { +// vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// uint8x16_t mk = vqtbl1q_u8(m, vtab.val[k]); +// uint8x16_t v_on = vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// } +// } +// float max = vmaxvq_f32(vmax.val[0]); +// for (int k = 1; k < 4; ++k) { +// float maxk = vmaxvq_f32(vmax.val[k]); +// max = MAX(max, maxk); +// } +// return max; +//} +#else +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(s_before); + GGML_UNUSED(s_after); + GGML_ASSERT(false); + return 0.f; +} +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(scale); + GGML_ASSERT(false); + return 0.f; +} +#endif + static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -6013,10 +6278,10 @@ static struct ggml_tensor * ggml_softcap_max_impl( GGML_ASSERT(ggml_is_padded_1d(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(mask->ne[0] == a->ne[0]); + //GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); } @@ -6784,10 +7049,14 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(mask->ne[0] == a->ne[0]); + if (mask->type == GGML_TYPE_I32) { + GGML_ASSERT(mask->ne[0] == (a->ne[0] + 31)/32); + } else { + GGML_ASSERT(mask->ne[0] == a->ne[0]); + } GGML_ASSERT(mask->ne[1] >= a->ne[1]); } @@ -13718,7 +13987,7 @@ static void ggml_compute_forward_softcap( default: { GGML_ASSERT(false); - } break; + } } } @@ -13767,6 +14036,7 @@ static void ggml_compute_forward_softcap_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = 
(src1 && src1->type == GGML_TYPE_F16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -13777,21 +14047,29 @@ static void ggml_compute_forward_softcap_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + const int mask_row = i1%ne01; - ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + float max = -INFINITY; + if (use_i32) { + int n32 = (ne00 + 31)/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; + max = ggml_vec_cpy_softcap_mask_f32(nc, sp, wp, mp_u32, values[2], values[0]*values[3]); + } else { - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + + if (src1) { + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else { + float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); } } + else { + ggml_vec_max_f32(nc, &max, wp); + } } #ifndef NDEBUG @@ -13801,8 +14079,8 @@ static void ggml_compute_forward_softcap_max_f32( } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + //float max = -INFINITY; + //ggml_vec_max_f32(nc, &max, wp); ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); @@ -13834,7 +14112,7 @@ static void ggml_compute_forward_softcap_max( default: { GGML_ASSERT(false); - } break; + } } } @@ -14570,6 +14848,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -14579,33 +14858,54 @@ static void ggml_compute_forward_soft_max_f32( float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? 
(float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + float max = -INFINITY; + if (use_u32) { + int n32 = ne00/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + (i1%ne01)*n32; + max = ggml_vec_cpy_soft_mask_f32(nc, sp, wp, mp_u32, scale); + } else { + + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (src1) { + // broadcast the mask across rows + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else { + float * mp_f32 = (float *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); } } - } + else { + ggml_vec_max_f32(nc, &max, wp); + } + + //// broadcast the mask across rows + //ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + //float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + //if (mp_f32) { + // if (use_f16) { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + // } + // } else { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + // } + // } + //} #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + ggml_vec_max_f32(nc, &max, wp); + } ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); diff --git a/include/llama.h b/include/llama.h index a9af4c48..dd13d657 100644 --- a/include/llama.h +++ b/include/llama.h @@ -340,6 +340,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool binary_kq; // whether to use binary KQ mask [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 8a85144e..ad3febef 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2348,6 +2348,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool binary_kq; enum llama_pooling_type pooling_type; @@ -8446,6 +8447,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const bool binary_kq; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8495,6 +8497,7 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + binary_kq (cparams.binary_kq), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -8687,25 +8690,31 @@ struct llm_build_context { } struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { - lctx.inp_KQ_mask = causal - ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto nx = causal ? 
n_kv : n_tokens; + // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + //bool can_be_binary = binary_kq && !lctx.is_encoding && !flash_attn && !hparams.use_alibi && nx%32 == 0; + //auto type = can_be_binary ? GGML_TYPE_I32 : flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + if (type == GGML_TYPE_I32) nx /= 32; + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; + return lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) { GGML_ASSERT(hparams.n_swa > 0); - - lctx.inp_KQ_mask_swa = causal - ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto nx = causal ? n_kv : n_tokens; + // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + //bool can_be_binary = binary_kq && !lctx.is_encoding && !flash_attn && !hparams.use_alibi && nx%32 == 0; + //auto type = can_be_binary ? GGML_TYPE_I32 : flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + if (type == GGML_TYPE_I32) nx /= 32; + lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; + return lctx.inp_KQ_mask_swa; } struct ggml_tensor * build_inp_mean() { @@ -14259,71 +14268,161 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; - - float * data = nullptr; - float * data_swa = nullptr; + if (lctx.inp_KQ_mask && lctx.inp_KQ_mask_swa) { + GGML_ASSERT(lctx.inp_KQ_mask->type == lctx.inp_KQ_mask_swa->type); + } if (lctx.inp_KQ_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - data = (float *) lctx.inp_KQ_mask->data; } - if (lctx.inp_KQ_mask_swa) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); - data_swa = (float *) lctx.inp_KQ_mask_swa->data; } - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the batch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { + auto mask_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; + GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32 || mask_type == GGML_TYPE_F16); + + if (mask_type == GGML_TYPE_I32) { + // in order this to be true, we are not using alibi + GGML_ASSERT(!hparams.use_alibi); + uint32_t * h_data = lctx.inp_KQ_mask ? (uint32_t *)lctx.inp_KQ_mask->data : nullptr; + uint32_t * h_data_swa = lctx.inp_KQ_mask_swa ? 
(uint32_t *)lctx.inp_KQ_mask_swa->data : nullptr; for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; + uint32_t u = 0, u_swa = 0; + uint32_t m = 1; + for (int i = 0; i < n_kv; ++i) { - float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); - } else { - f = 0.0f; - } + if (lctx.kv_self.cells[i].pos > pos || !lctx.kv_self.cells[i].has_seq_id(seq_id)) { + u |= m; u_swa |= m; } + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) u_swa |= m; + m <<= 1; + if (!m) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; + u = u_swa = 0; m = 1; + } + } + if (m > 1) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; + } - if (data) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + } + + auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_pad > n_tokens) { + auto n_kv_32 = (n_kv + 31)/32; + if (h_data) std::memset(h_data, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); + if (h_data_swa) std::memset(h_data_swa, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); + } + } + + else if (mask_type == GGML_TYPE_F16) { + ggml_fp16_t * h_data = lctx.inp_KQ_mask ? (ggml_fp16_t *)lctx.inp_KQ_mask->data : nullptr; + ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr; + ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f); + ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + ggml_fp16_t f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) f = h_inf; + else f = hparams.use_alibi ? ggml_fp32_to_fp16(-std::abs(lctx.kv_self.cells[i].pos - pos)) : h_zero; + if (h_data) h_data[j*n_kv + i] = f; + if (h_data_swa) h_data_swa[j*n_kv + i] = pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa ? h_inf : f; + } + } + auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_pad > n_tokens) { + if (h_data) { + for (int j = 0; j < n_kv; ++j) h_data[n_tokens*n_kv + j] = h_inf; + for (int i = n_tokens+1; i < n_pad; ++i) { + std::memcpy(h_data + i*n_kv, h_data + n_tokens*n_kv, n_kv*sizeof(ggml_fp16_t)); + } + } + if (h_data_swa) { + for (int j = 0; j < n_kv; ++j) h_data_swa[n_tokens*n_kv + j] = h_inf; + for (int i = n_tokens+1; i < n_pad; ++i) { + std::memcpy(h_data_swa + i*n_kv, h_data_swa + n_tokens*n_kv, n_kv*sizeof(ggml_fp16_t)); } + } + } + } + + else { + + float * data = nullptr; + float * data_swa = nullptr; + + if (lctx.inp_KQ_mask) { + data = (float *) lctx.inp_KQ_mask->data; + } + + if (lctx.inp_KQ_mask_swa) { + data_swa = (float *) lctx.inp_KQ_mask_swa->data; + } - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the batch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
+ for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } - data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } - } - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } } } - } - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } } } } } } else { + // TODO + GGML_ASSERT(false); + // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; @@ -16634,6 +16733,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.binary_kq =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -16808,6 +16908,10 @@ struct llama_context * llama_new_context_with_model( return nullptr; } + if (params.binary_kq && params.flash_attn) { + LLAMA_LOG_WARN("%s: binary-KQ mask is currently not used in flash_attn\n", __func__); + } + llama_context * ctx = new llama_context(*model); const auto & hparams = model->hparams; @@ -16824,6 +16928,7 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -16890,6 +16995,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: binary_kq = %d\n", __func__, cparams.binary_kq); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
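
Note (not part of the patch): the diff above uses one bit-packing convention across the CUDA, Metal and CPU soft-max paths — KV position i is stored in bit (i & 31) of 32-bit word i/32, a set bit means the position is visible (mask contribution 0.0f), and a clear bit means the position is excluded (treated as -INFINITY before the softmax). On the host side, llama_set_inputs accumulates the excluded positions in `u`/`u_swa` and stores the complement (`~u`), which yields exactly this convention. The sketch below is a minimal, self-contained CPU illustration of that layout and of the masked softmax it feeds; it is not code from the patch, and pack_kq_mask, soft_max_binary_mask and the main() driver are hypothetical names used only for this example.

    // Illustrative sketch only (not part of the patch).
    // Convention mirrored from the diff: bit i of word i/32 is SET when KV
    // position i is visible, CLEAR when it must be masked out (-INFINITY).
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Pack a boolean visibility mask (one entry per KV position) into u32 words.
    static std::vector<uint32_t> pack_kq_mask(const std::vector<bool> & visible) {
        const size_t n_words = (visible.size() + 31) / 32;
        std::vector<uint32_t> packed(n_words, 0u);
        for (size_t i = 0; i < visible.size(); ++i) {
            if (visible[i]) {
                packed[i >> 5] |= 1u << (i & 31);
            }
        }
        return packed;
    }

    // Scalar reference of the masked softmax: scale the logits, force masked
    // positions to -INFINITY, then apply a numerically stable softmax.
    static void soft_max_binary_mask(const float * x, float * y, int n,
                                     const uint32_t * mask, float scale) {
        float max_val = -INFINITY;
        for (int i = 0; i < n; ++i) {
            const bool on = mask[i >> 5] & (1u << (i & 31));
            y[i] = on ? scale * x[i] : -INFINITY;
            max_val = std::max(max_val, y[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            y[i] = std::exp(y[i] - max_val); // exp(-INF - max) == 0 for masked slots
            sum += y[i];
        }
        for (int i = 0; i < n; ++i) {
            y[i] /= sum;
        }
    }

    int main() {
        // A causal row: only the first 5 of 8 KV positions are visible.
        std::vector<bool> visible = {true, true, true, true, true, false, false, false};
        std::vector<uint32_t> packed = pack_kq_mask(visible);

        const float logits[8] = {0.1f, 0.7f, -0.3f, 1.2f, 0.0f, 9.0f, 9.0f, 9.0f};
        float probs[8];
        // 1/sqrt(64) stands in for a typical head-dimension scale.
        soft_max_binary_mask(logits, probs, 8, packed.data(), 1.0f / std::sqrt(64.0f));

        for (int i = 0; i < 8; ++i) {
            printf("p[%d] = %.4f\n", i, probs[i]); // masked positions come out as 0
        }
        return 0;
    }

As in build_inp_KQ_mask above, the binary path is only taken when the mask width is a multiple of 32 and neither flash attention nor ALiBi is in use; otherwise the F16/F32 mask with per-position additive biases is kept.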