diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py index 0df4e2a62517..854e6509ba86 100644 --- a/python/perf-kernels/rmsnorm.py +++ b/python/perf-kernels/rmsnorm.py @@ -47,15 +47,29 @@ def get_autotune_config(): @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps, - BLOCK_SIZE: tl.constexpr): + BLOCK_SIZE: tl.constexpr, use_mask: tl.constexpr): row_start = tl.program_id(0) row_idx = row_start #Calculate squared mean by block row_start_ptr = input_ptr + row_idx * input_row_stride row_sum = 0.0 - for b in tl.range(0, n_cols, BLOCK_SIZE): - col_offsets = b + tl.arange(0, BLOCK_SIZE) + loop_num = tl.cdiv(n_cols, BLOCK_SIZE) + if use_mask: + loop_num -= 1 + #for b in tl.range(0, n_cols, BLOCK_SIZE): + loop_num_t = loop_num + for b in tl.range(0, loop_num_t): + col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, cache_modifier=".cg") + row_block = row_block * row_block #square every value the block + row_sum += (tl.sum(row_block, axis=-1) / n_cols + ) #tl.sum across row, divide by block_size and add it running sum + + if use_mask: + col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") @@ -67,9 +81,26 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride row_norm = tl.rsqrt(row_norm) #Blocked normalization + loop_num_t = loop_num output_row_start_ptr = output_ptr + row_idx * output_row_stride - for b in tl.range(0, n_cols, BLOCK_SIZE): - col_offsets = b + tl.arange(0, BLOCK_SIZE) + #for b in tl.range(0, n_cols, BLOCK_SIZE): + for b in tl.range(0, loop_num_t): + col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + #row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") #load block of input + #g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg") #load block of g + row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block of input + g = tl.load(g_ptr + col_offsets, cache_modifier=".cg") #load block of g + output = row_block * row_norm #element wise multiply with rms_norm + output = output * g #element wise multiplication with g + + output_ptrs = output_row_start_ptr + col_offsets + #tl.store(output_ptrs, output, mask=mask) + tl.store(output_ptrs, output) + + if use_mask: + col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") #load block of input @@ -81,7 +112,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride tl.store(output_ptrs, output, mask=mask) -def rmsnorm(x, epsilon=1e-6): +def rmsnorm(x, epsilon=1e-6, use_mask=1): n_rows, n_cols = x.shape #Restricting BLOCK_SIZE to 64Kb is an important optimization. Otherwise, #performance can drop significantly for larger n_cols. @@ -93,7 +124,7 @@ def rmsnorm(x, epsilon=1e-6): num_programs = n_rows grid = lambda meta: (num_programs, ) - rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) + rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, use_mask) return y