Add use mask for loads
Rahul Batra committed Sep 23, 2024
1 parent bbdc10b commit 207f73a
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions python/perf-kernels/rmsnorm.py
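Editor's note: the commit rewrites both per-row loops in rms_kernel to iterate over whole BLOCK_SIZE blocks with unmasked tl.load calls and, when the new use_mask constexpr is set, peels the final (possibly partial) block off into a separate masked load. A minimal sketch of the peeling pattern in plain Python (illustrative only; the names mirror the kernel but nothing below is from the commit):

import math

def sum_of_squares_peeled(row, block_size, use_mask=True):
    loop_num = math.ceil(len(row) / block_size)
    if use_mask:
        loop_num -= 1                        # peel off the last (possibly partial) block
    total = 0.0
    for b in range(loop_num):                # full blocks: no bounds check needed
        block = row[b * block_size:(b + 1) * block_size]
        total += sum(v * v for v in block)
    if use_mask:                             # tail block: bounds-checked, like the masked tl.load
        block = row[loop_num * block_size:len(row)]
        total += sum(v * v for v in block)
    return total

Note that with use_mask off, every block is loaded without a bounds check, which is only safe when BLOCK_SIZE evenly divides n_cols.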
@@ -47,15 +47,29 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
 def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
-               BLOCK_SIZE: tl.constexpr):
+               BLOCK_SIZE: tl.constexpr, use_mask: tl.constexpr):
     row_start = tl.program_id(0)
     row_idx = row_start
 
     #Calculate squared mean by block
     row_start_ptr = input_ptr + row_idx * input_row_stride
     row_sum = 0.0
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE)
+    if use_mask:
+        loop_num -= 1  #peel the last block off the main loop
+    #for b in tl.range(0, n_cols, BLOCK_SIZE):
+    loop_num_t = loop_num
+    for b in tl.range(0, loop_num_t):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         mask = col_offsets < n_cols
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #unmasked load of a full block
+        row_block = row_block * row_block  #square every value in the block
+        row_sum += (tl.sum(row_block, axis=-1) / n_cols
+                    )  #tl.sum across the block, divide by n_cols, add to running sum
+
+    if use_mask:
+        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
         row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
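In this first hunk the mean of squares is accumulated block by block: full blocks are loaded without a mask (the mask computed inside the main loop is left unused there), while the peeled tail reloads with mask and other=0.0 so out-of-bounds lanes contribute zero to the sum. A NumPy reference for the quantity being accumulated (a checking sketch, not part of the commit):

import numpy as np

def mean_of_squares_blocked(row, block_size):
    # Accumulates sum(x*x) / n_cols block by block, as rms_kernel does.
    n_cols = row.shape[0]
    row_sum = 0.0
    for b in range(0, n_cols, block_size):
        block = row[b:b + block_size]   # slicing handles the partial tail
        row_sum += float(np.sum(block * block)) / n_cols
    return row_sum                      # equals np.mean(row * row) up to rounding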
@@ -67,9 +81,26 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
     row_norm = tl.rsqrt(row_norm)
 
     #Blocked normalization
+    loop_num_t = loop_num
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    #for b in tl.range(0, n_cols, BLOCK_SIZE):
+    for b in tl.range(0, loop_num_t):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         mask = col_offsets < n_cols
+        #row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load block of input
+        #g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  #load block of g
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load block of input
+        g = tl.load(g_ptr + col_offsets, cache_modifier=".cg")  #load block of g
+        output = row_block * row_norm  #element-wise multiply with row_norm
+        output = output * g  #element-wise multiply with g
+
+        output_ptrs = output_row_start_ptr + col_offsets
+        #tl.store(output_ptrs, output, mask=mask)
+        tl.store(output_ptrs, output)
+
+    if use_mask:
+        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
         row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load block of input
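The second hunk applies the same peeling to the normalization loop, which computes output = x * rsqrt(mean(x^2) + eps) * g per row. A row-wise NumPy reference one could validate the kernel against (a sketch; g is the gain vector the kernel loads from g_ptr):

import numpy as np

def rmsnorm_ref(x, g, eps=1e-6):
    # x: (n_rows, n_cols), g: (n_cols,); mirrors rms_kernel's math.
    row_norm = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * row_norm * g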
@@ -81,7 +112,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
         tl.store(output_ptrs, output, mask=mask)
 
 
-def rmsnorm(x, epsilon=1e-6):
+def rmsnorm(x, epsilon=1e-6, use_mask=1):
     n_rows, n_cols = x.shape
     #Restricting BLOCK_SIZE to 64KB is an important optimization. Otherwise,
     #performance can drop significantly for larger n_cols.
@@ -93,7 +124,7 @@ def rmsnorm(x, epsilon=1e-6):
 
     num_programs = n_rows
     grid = lambda meta: (num_programs, )
-    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
+    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, use_mask)
 
     return y
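For completeness, a hypothetical way to exercise the new flag. This assumes a CUDA/ROCm device, that the file is importable as rmsnorm, and that the wrapper's hidden lines allocate g as all ones; none of that is shown in this diff:

import torch
from rmsnorm import rmsnorm  # import path assumed

x = torch.randn(32, 8190, device='cuda', dtype=torch.float32)
y = rmsnorm(x)  # default use_mask=1: the ragged tail block is bounds-checked
ref = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + 1e-6)
print(torch.max(torch.abs(y - ref)))  # small if g is all ones in the wrapper

Passing use_mask=0 drops the peeled tail and loads every block unmasked, so it is only safe when BLOCK_SIZE evenly divides n_cols.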
