Skip to content

Commit

Permalink
register fused rmsnorm as pytorch custom op
Browse files Browse the repository at this point in the history
ghstack-source-id: 401d968feaa2e58eedb573c07739694358a8d4a6
Pull Request resolved: #296
  • Loading branch information
tianyu-l committed May 2, 2024
1 parent e34d2ac commit faa78a4
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 72 deletions.
1 change: 1 addition & 0 deletions run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

CUDA_LAUNCH_BLOCKING=1 \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
283 changes: 211 additions & 72 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor

import triton
import triton.language as tl
Expand Down Expand Up @@ -213,47 +215,95 @@ def _rms_norm_bwd_kernel_sm(
tl.store(DW + row_block_id * N + cols, dw, mask=mask)


class TritonFusedRMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, eps):
x_shape_start = x.shape
def fused_rmsnorm_forward(x, weight, eps):
x_shape_start = x.shape

# Flatten input
x = x.view(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if weight.stride(-1) != 1:
weight = weight.contiguous()
# Flatten input
x = x.view(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if weight.stride(-1) != 1:
weight = weight.contiguous()

M, N = x.shape
y = torch.empty_like(x)
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
M, N = x.shape
y = torch.empty_like(x)
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

max_size = 65536 // x.element_size()
block_N = min(max_size, triton.next_power_of_2(N))
max_size = 65536 // x.element_size()
block_N = min(max_size, triton.next_power_of_2(N))

if N > block_N:
raise ValueError(f"N {N} must be <= {block_N=}")
if N > block_N:
raise ValueError(f"N {N} must be <= {block_N=}")

grid = lambda meta: (M,)
_rms_norm_fwd_kernel[grid](
x,
x.stride(0),
y,
y.stride(0),
weight,
rstd,
eps,
M,
N,
block_N,
)
grid = lambda meta: (M,)
_rms_norm_fwd_kernel[grid](
x,
x.stride(0),
y,
y.stride(0),
weight,
rstd,
eps,
M,
N,
block_N,
)

y = y.reshape(x_shape_start)
return y, rstd


def fused_rmsnorm_backward(dy, x, weight, eps, rstd, x_shape_start):
# Flatten input and output gradients
dy = dy.view(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()

M, N = dy.shape
dx = torch.empty_like(x)
dw = torch.empty_like(weight)

sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

max_size = 65536 // x.element_size()
block_N = min(max_size, triton.next_power_of_2(N))
rows_per_sm = math.ceil(M / sm_count)

if N > block_N:
raise ValueError(f"N {N} must be <= {block_N=}")

grid = lambda meta: (sm_count,)
_rms_norm_bwd_kernel_sm[grid](
x,
x.stride(0),
weight,
dy,
dy.stride(0),
dx,
dx.stride(0),
rstd,
_dw,
eps,
M,
N,
rows_per_sm,
block_N,
)
dw = _dw.sum(0).to(weight.dtype)
dx = dx.view(x_shape_start)
return dx, dw, None


class TritonFusedRMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, eps):
y, rstd = fused_rmsnorm_forward(x, weight, eps)

ctx.eps = eps
ctx.x_shape_start = x.shape
ctx.save_for_backward(x, weight, rstd)
ctx.x_shape_start = x_shape_start

y = y.reshape(x_shape_start)
return y

@staticmethod
Expand All @@ -262,55 +312,144 @@ def backward(ctx, dy):
eps = ctx.eps
x_shape_start = ctx.x_shape_start

# Flatten input and output gradients
dy = dy.view(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()

M, N = dy.shape
dx = torch.empty_like(x)
dw = torch.empty_like(weight)

sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

max_size = 65536 // x.element_size()
block_N = min(max_size, triton.next_power_of_2(N))
rows_per_sm = math.ceil(M / sm_count)

if N > block_N:
raise ValueError(f"N {N} must be <= {block_N=}")

grid = lambda meta: (sm_count,)
_rms_norm_bwd_kernel_sm[grid](
x,
x.stride(0),
weight,
dy,
dy.stride(0),
dx,
dx.stride(0),
rstd,
_dw,
eps,
M,
N,
rows_per_sm,
block_N,
)
dw = _dw.sum(0).to(weight.dtype)
dx = dx.view(x_shape_start)
return dx, dw, None
return fused_rmsnorm_backward(dy, x, weight, eps, rstd, x_shape_start)


# expose fusedRMSNorm as a function
def fused_rms_norm_fn(
x,
weight,
eps=1e-6,
eps,
):
# option 1: register forward and backward separately

return TritonFusedRMSNorm.apply(
x,
weight,
eps,
)

# option 2: register forward only, and register backward using torch.library.register_autograd

# args = (x, weight, eps,)
# torch.library.opcheck(fused_rmsnorm_forward, args)
# return fused_rmsnorm_forward(x, weight, eps)[0]


# @torch.library.custom_op("torchtitan::fused_rmsnorm", mutates_args=())
# def fused_rmsnorm_forward(x: Tensor, weight: Tensor, eps: float) -> Tuple[Tensor, Tensor]:
# x_shape_start = x.shape

# # Flatten input
# x = x.view(-1, x.shape[-1])
# if x.stride(-1) != 1:
# x = x.contiguous()
# if weight.stride(-1) != 1:
# weight = weight.contiguous()

# M, N = x.shape
# y = torch.empty_like(x)
# rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

# max_size = 65536 // x.element_size()
# block_N = min(max_size, triton.next_power_of_2(N))

# if N > block_N:
# raise ValueError(f"N {N} must be <= {block_N=}")

# grid = lambda meta: (M,)
# _rms_norm_fwd_kernel[grid](
# x,
# x.stride(0),
# y,
# y.stride(0),
# weight,
# rstd,
# eps,
# M,
# N,
# block_N,
# )

# y = y.reshape(x_shape_start)
# # return y
# return y, rstd


# def setup_context(ctx, inputs, output) -> Tensor:
# x, weight, eps = inputs
# y, rstd = output

# x_shape_start = x.shape

# ctx.eps = eps
# ctx.save_for_backward(x, weight, rstd)
# ctx.x_shape_start = x_shape_start


# def fused_rmsnorm_backward(ctx, dy, drstd):
# x, weight, rstd = ctx.saved_tensors
# eps = ctx.eps
# x_shape_start = ctx.x_shape_start

# # Flatten input and output gradients
# dy = dy.view(-1, dy.shape[-1])
# if dy.stride(-1) != 1:
# dy = dy.contiguous()

# M, N = dy.shape
# dx = torch.empty_like(x)
# dw = torch.empty_like(weight)

# sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
# _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

# max_size = 65536 // x.element_size()
# block_N = min(max_size, triton.next_power_of_2(N))
# rows_per_sm = math.ceil(M / sm_count)

# if N > block_N:
# raise ValueError(f"N {N} must be <= {block_N=}")

# grid = lambda meta: (sm_count,)
# _rms_norm_bwd_kernel_sm[grid](
# x,
# x.stride(0),
# weight,
# dy,
# dy.stride(0),
# dx,
# dx.stride(0),
# rstd,
# _dw,
# eps,
# M,
# N,
# rows_per_sm,
# block_N,
# )
# dw = _dw.sum(0).to(weight.dtype)
# dx = dx.view(x_shape_start)
# return dx, dw, None


# torch.library.register_autograd("torchtitan::fused_rmsnorm", fused_rmsnorm_backward, setup_context=setup_context)


# @torch.library.register_fake("torchtitan::fused_rmsnorm")
# def _(x, weight, eps):
# x_shape_start = x.shape

# # Flatten input
# x = x.view(-1, x.shape[-1])
# if x.stride(-1) != 1:
# x = x.contiguous()
# # if weight.stride(-1) != 1:
# # weight = weight.contiguous()

# M, N = x.shape
# y = torch.empty_like(x)
# rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

# y = y.reshape(x_shape_start)
# return y, rstd

0 comments on commit faa78a4

Please sign in to comment.