RMSNorm Blocked Implementation #638

Open · wants to merge 1 commit into base: main_perf
94 changes: 63 additions & 31 deletions python/perf-kernels/rmsnorm.py
@@ -46,35 +46,67 @@ def get_autotune_config():

 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
+def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
                BLOCK_SIZE: tl.constexpr):
-    row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
-        input_ptrs = row_start_ptr + col_offsets
-        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
-        row_norm = row * row  #square each value
-        row_norm = tl.sum(row_norm, axis=-1)  #sum across columns (axis=-1)
-        row_norm = row_norm / n_cols  #divide by n_cols
-        row_norm = row_norm + epsilon  #add epsilon
-        row_norm = tl.rsqrt(row_norm)  #take rsqrt, this is the normalization value
-        rms_norm = row * row_norm  #multiply each x by the normalization value
-        rms_norm = rms_norm * g  #element-wise multiplication with g
-
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
-        output_ptrs = output_row_start_ptr + col_offsets
-        output_ptrs = tl.multiple_of(output_ptrs, (16, ))
-        tl.store(output_ptrs, rms_norm, mask=mask)
+    row_idx = tl.program_id(0)
+
+    #Calculate squared mean by block
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_sum = 0.0
+    n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")
+        row_block = row_block * row_block  #square every value in the block
+        row_sum += (tl.sum(row_block, axis=-1) / n_cols)  #sum across the row
+
+    #Masked tail block of the squared-mean pass
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+    row_block = row_block * row_block  #square every value in the block
+    row_sum += (tl.sum(row_block, axis=-1) / n_cols)  #sum across the row
+
+    row_norm = row_sum + eps
+    row_norm = tl.rsqrt(row_norm)
+
+    #Blocked normalization
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load a block of input
+        g = tl.load(g_ptr + col_offsets, cache_modifier=".cg")  #load a block of g
+        output = row_block * row_norm  #element-wise multiply with the normalization value
+        output = output * g  #element-wise multiplication with g
+
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, output)
+
+    #Masked tail block of the normalization pass
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load a block of input
+    g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  #load a block of g
+    output = row_block * row_norm  #element-wise multiply with the normalization value
+    output = output * g  #element-wise multiplication with g
+
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, output, mask=mask)
 
 
 def triton_rmsnorm(x, g, epsilon=1e-6):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    #Restricting BLOCK_SIZE to 64 KB is an important optimization. Otherwise,
+    #performance can drop significantly for larger n_cols.
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
 
     y = torch.empty_like(x, device='cuda')
 
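To make the kernel change easier to follow, here is a minimal host-side sketch (not part of the PR) of what the blocked kernel computes for a single row. The helper name blocked_rmsnorm_row and the block_size default are illustrative assumptions, and plain Python loops stand in for the Triton program and block structure:

import torch

def blocked_rmsnorm_row(row, g, eps=1e-6, block_size=16384):
    # Pass 1: accumulate mean(x^2) one block at a time, mirroring the kernel's
    # full blocks followed by a masked tail block.
    n_cols = row.numel()
    row_sum = 0.0
    for start in range(0, n_cols, block_size):
        blk = row[start:start + block_size]  # the tail slice is simply shorter
        row_sum += float((blk * blk).sum()) / n_cols
    inv_rms = (row_sum + eps) ** -0.5  # rsqrt(mean(x^2) + eps)

    # Pass 2: re-read each block and scale by inv_rms and the per-column weight g.
    out = torch.empty_like(row)
    for start in range(0, n_cols, block_size):
        end = min(start + block_size, n_cols)
        out[start:end] = row[start:end] * inv_rms * g[start:end]
    return out

As in the kernel, eps is added to the mean of squares before taking the reciprocal square root, and the input is read twice: once to compute the statistic and once to normalize.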
@@ -84,7 +116,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):
 
     return y
 
-
 def torch_rmsnorm(x, g):
     M, N = x.shape
     if hasattr(torch.nn, 'RMSNorm'):
@@ -95,15 +126,17 @@ def torch_rmsnorm(x, g):
         rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g
     return rms_norm
 
-
+# yapf: disable
 @pytest.mark.parametrize('M, N', [
-    (1, 4),
-    (2, 10),
-    (8192, 4096),
-    (4096, 8192),
-    (1, 8192),
-    (873, 1245),
-])
+    (1, 4),
+    (2, 10),
+    (8192, 4096),
+    (4096, 8192),
+    (1, 8192),
+    (873, 1245),
+    (1, 98304)
+])
+# yapf: enable
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
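The new (1, 98304) case appears to exercise the blocked path specifically: for the fp32 inputs this test creates, the wrapper caps BLOCK_SIZE well below the row length, so one row spans several full blocks plus a masked tail. A rough back-of-the-envelope check (assuming the fp32 default of torch.randn, 4 bytes per element):

# Assuming fp32 input (4 bytes per element), as produced by torch.randn above:
MAX_FUSED_SIZE = 65536 // 4                  # 16384 elements
BLOCK_SIZE = min(MAX_FUSED_SIZE, 131072)     # triton.next_power_of_2(98304) == 131072 -> 16384
n_cols_blks = -(-98304 // BLOCK_SIZE) - 1    # cdiv(98304, 16384) - 1 == 5 full blocks
# 5 unmasked blocks plus one masked tail block of 16384 columns cover the row.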
@@ -114,7 +147,6 @@ def test_rmsnorm(M, N):
 
     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
 
-
 #Benchmark
 arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
 
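For reference, a quick way to exercise the updated wrapper outside pytest might look like the following sketch; the import path is an assumption based on the file location (python/perf-kernels/rmsnorm.py on PYTHONPATH), and the shapes are arbitrary:

import torch

from rmsnorm import torch_rmsnorm, triton_rmsnorm  # assumes python/perf-kernels is importable

M, N = 4096, 8192
x = torch.randn(M, N, device='cuda')
g = torch.randn(N, device='cuda')

y_triton = triton_rmsnorm(x, g)  # blocked Triton kernel
y_torch = torch_rmsnorm(x, g)    # PyTorch reference
print(torch.allclose(y_triton, y_torch))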