merge: Resolve merge conflicts pulling in from Microsoft upstream
saforem2 committed Oct 8, 2024
2 parents cef3fc7 + 4448492 commit fd94b37
Showing 5 changed files with 34 additions and 10 deletions.
6 changes: 6 additions & 0 deletions megatron/arguments.py
@@ -1627,5 +1627,11 @@ def _add_profiler_args(parser):
                        type=str,
                        default='2,3',
                        help="Which steps to profile. Format: <start step>,<end step>")
 
+    group.add_argument("--profile-ranks",
+                       type=int,
+                       nargs='+',
+                       default=None,
+                       help="Which ranks to profile. Format: 0 1 2 3")
+
     return parser
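
For reference, a minimal sketch (not from this repository) of how the new --profile-ranks option behaves under argparse: nargs='+' with type=int collects the rank list as integers, and the None default leaves profiling enabled on every rank.

import argparse

# Stand-alone parser mirroring only the new --profile-ranks argument.
parser = argparse.ArgumentParser()
parser.add_argument("--profile-ranks",
                    type=int,
                    nargs='+',
                    default=None,
                    help="Which ranks to profile. Format: 0 1 2 3")

args = parser.parse_args(["--profile-ranks", "0", "1", "2", "3"])
print(args.profile_ranks)   # [0, 1, 2, 3]

args = parser.parse_args([])
print(args.profile_ranks)   # None -> no rank filter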
11 changes: 8 additions & 3 deletions megatron/core/tensor_parallel/layers.py
@@ -286,6 +286,7 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
+        args = get_args()
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
 
@@ -367,9 +368,13 @@ def backward(ctx, grad_output):
         # grad_weight = None
         # else:
         # grad_weight = grad_output.t().matmul(total_input)
-        from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
-        WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
-        grad_weight = None
+        if args.enable_zbh1_pipeline:
+            from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
+            WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
 
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.sequence_parallel:
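The change above gates a deferred weight-gradient path behind --enable-zbh1-pipeline: instead of computing grad_weight immediately in backward, the tensors needed for the dW GEMM are handed to WeightGradStore so that GEMM can run later and fill pipeline bubbles. A conceptual sketch of that pattern follows; the class and function names here are illustrative assumptions, not the actual WeightGradStore API.

import torch
from collections import deque

class DeferredWeightGrad:
    _queue = deque()

    @classmethod
    def put(cls, total_input, grad_output, weight, update_fn):
        # Defer the dW computation; keep only what the GEMM needs.
        cls._queue.append((total_input, grad_output, weight, update_fn))

    @classmethod
    def flush(cls):
        # Later (e.g., in a pipeline bubble), compute and apply each stored dW.
        while cls._queue:
            total_input, grad_output, weight, update_fn = cls._queue.popleft()
            grad_weight = grad_output.t().matmul(total_input)
            update_fn(weight, grad_weight)

def apply_grad(weight, grad_weight):
    # Accumulate the deferred gradient onto the parameter.
    weight.grad = grad_weight if weight.grad is None else weight.grad + grad_weight

w = torch.nn.Parameter(torch.randn(4, 8))   # [out_features, in_features]
x = torch.randn(16, 8)                      # total_input: [tokens, in_features]
gy = torch.randn(16, 4)                     # grad_output: [tokens, out_features]
DeferredWeightGrad.put(x, gy, w, apply_grad)
DeferredWeightGrad.flush()
print(w.grad.shape)                         # torch.Size([4, 8])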
16 changes: 12 additions & 4 deletions megatron/model/fused_rmsnorm.py
@@ -1,16 +1,24 @@
+from megatron import get_args
+
 import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import intel_extension_for_pytorch as ipex  # noqa
 
 
 # Taken from facebookresearch/llama
 class RMSNorm(torch.nn.Module):
     # def __init__(self, dim: int, eps: float = 1e-6):
-    def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False):
+    def __init__(
+        self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False
+    ):
         super().__init__()
         self.eps = eps
-        self.weight = Parameter(torch.ones(dim))
+        self.weight = Parameter(torch.ones(dim, dtype=get_args().params_dtype))
+        self.sequence_parallel = sequence_parallel
+        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
 
     def forward(self, x):
-        output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps)
+        output = torch.xpu.IpexRmsNorm(
+            x, self.weight.shape, self.weight, self.eps
+        )
         return output
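
For comparison, a minimal unfused reference (an illustration, not part of this commit) of what the fused torch.xpu.IpexRmsNorm call is expected to compute, following the _norm() formula in megatron/model/rmsnorm.py with the learned weight applied afterwards.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return norm * weight

x = torch.randn(2, 4, 8)
w = torch.ones(8)
print(rmsnorm_reference(x, w).shape)   # torch.Size([2, 4, 8])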
7 changes: 5 additions & 2 deletions megatron/model/rmsnorm.py
@@ -10,7 +10,9 @@
 
 # Taken from facebookresearch/llama
 class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False):
+    def __init__(
+        self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False
+    ):
         super().__init__()
         self.eps = eps
         init_device = None
@@ -20,7 +22,8 @@ def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False)
             torch.empty(dim, device=init_device, dtype=get_args().params_dtype)
         )
         init.ones_(self.weight)
-        setattr(self.weight, "sequence_parallel", sequence_parallel)
+        self.sequence_parallel = sequence_parallel
+        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
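Both RMSNorm variants now record sequence_parallel on the module and stamp it onto the weight. A generic sketch of why such a marker is useful (an assumed usage pattern, not code from this repository): downstream code can select the parameters that need sequence-parallel gradient handling by checking the attribute.

import torch

def params_flagged_sequence_parallel(module: torch.nn.Module):
    # Collect every parameter carrying the sequence_parallel marker.
    return [p for p in module.parameters()
            if getattr(p, "sequence_parallel", False)]

layer = torch.nn.Linear(8, 8)
setattr(layer.weight, "sequence_parallel", True)       # mimic what RMSNorm.__init__ does
print(len(params_flagged_sequence_parallel(layer)))    # 1 -> only the flagged weight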
4 changes: 3 additions & 1 deletion megatron/profiler.py
@@ -36,7 +36,9 @@ def is_end_step():
     def is_capture_step():
         return cur_step >= start_step and cur_step <= end_step
 
-    if args.profile.startswith('pt'):
+    if args.profile.startswith('pt') and (
+        args.profile_ranks is None or torch.distributed.get_rank() in args.profile_ranks
+    ):
         schedule = torch.profiler.schedule(wait=0, warmup=0, active=active_steps, repeat=1)
         activities = [torch.profiler.ProfilerActivity.CPU]
         activities.extend([torch.profiler.ProfilerActivity.HPU] if device.startswith("hpu") else [])
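The extra condition only starts the PyTorch profiler when no --profile-ranks filter is given, or when the current rank is listed. A small sketch of that rule as a standalone helper (a hypothetical function, not part of the commit):

import torch

def should_profile_on_this_rank(profile_ranks):
    # Fall back to rank 0 when torch.distributed has not been initialized.
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    return profile_ranks is None or rank in profile_ranks

print(should_profile_on_this_rank(None))      # True: default profiles every rank
print(should_profile_on_this_rank([0, 1]))    # True on ranks 0 and 1, False elsewhere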
