From 1d8fca72ae9154eec0e1c0a75cfaac3c50f08e4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 27 May 2024 12:10:19 +0300 Subject: [PATCH] metal : add GGML_OP_REPEAT kernels (#7557) ggml-ci --- ggml-metal.m | 53 ++++++++++++++++++++++++++++++++++++++++++++---- ggml-metal.metal | 47 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 4 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 15fb68fc489af..ff9ae55aada74 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -35,6 +35,10 @@ GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_REPEAT_F32, + GGML_METAL_KERNEL_TYPE_REPEAT_F16, + GGML_METAL_KERNEL_TYPE_REPEAT_I32, + GGML_METAL_KERNEL_TYPE_REPEAT_I16, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_CLAMP, @@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); @@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: case GGML_OP_SQR: @@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute( switch (dst->op) { case GGML_OP_CONCAT: { - const int64_t nb = ne00; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; [encoder setComputePipelineState:pipeline]; @@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; const int nth = MIN(1024, ne0); @@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_MUL: case GGML_OP_DIV: { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + const size_t offs = 0; bool bcast_row = false; - int64_t nb = ne00; + int64_t nb = ne00; // used by the "row" kernels id pipeline = nil; @@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ASSERT(false); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ACC: { GGML_ASSERT(src0t == GGML_TYPE_F32); diff --git a/ggml-metal.metal b/ggml-metal.metal index ce51c74d5158d..174086b5b6293 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -168,6 +168,53 @@ kernel void kernel_div( } } +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row(