Skip to content

Commit

Permalink
Merge branch 'master' into HeyangQin/fastgen_moe_h100
Browse files Browse the repository at this point in the history
  • Loading branch information
HeyangQin authored May 29, 2024
2 parents 7b8ba2b + 2fc702e commit 8274a01
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

// DeepSpeed Team

#include "quantize.h"
#include "fp_quantize.h"

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

#include <stdexcept>
#include "context.h"
#include "fp_quantize.h"
#include "memory_access_utils.h"
#include "quantize.h"
#include "reduction_utils.h"

#include <cuda.h>
Expand Down
File renamed without changes.
28 changes: 19 additions & 9 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team

import os
import re
from typing import Dict
import torch

Expand All @@ -21,6 +22,7 @@
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'

SEQUENTIAL_LAYERS = [
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
Expand All @@ -32,7 +34,13 @@

class DeepSpeedCheckpoint(object):

def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
def __init__(self,
dir,
tp_degree=None,
pp_degree=None,
dp_degree=None,
final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
self.final_layer_norm_idx = final_layer_norm_idx
self.dir = dir

pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
Expand Down Expand Up @@ -73,7 +81,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
self._build_global_state()

def is_change_tp_degree(self):
Expand Down Expand Up @@ -125,7 +133,7 @@ def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]

def get_final_norm_layer_id(self):
return self.layer_keys[FINAL_LAYER_NORM_INDEX]
return self.layer_keys[self.final_layer_norm_idx]

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
Expand Down Expand Up @@ -214,7 +222,7 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
def _build_pp_transformer_map(self):
data_map = {}
if self.pp_degree > 0:
transformer_layers = self.layer_keys[1:-1]
transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
Expand All @@ -229,7 +237,7 @@ def _dump_mapping(self, data_map, map_tag=None):
print(f'{k} = {v}')

def _build_transformer_file_map(self):
transformer_layer_keys = self.layer_keys[1:-1]
transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
file_map = {}
# XXX: this is not guaranteed
layers_per_pp = 1
Expand All @@ -238,7 +246,7 @@ def _build_transformer_file_map(self):
#print(f"{transformer_layer_keys} {layers_per_pp}")
for key_index, layer_key in enumerate(transformer_layer_keys):
pp_index = key_index // layers_per_pp
layer_files = get_files_with_prefix(self.layer_files, layer_key)
layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
layer_file_partitions = partition_data(layer_files, self.tp_degree)
for tp_index in range(self.tp_degree):
map_key = (tp_index, pp_index)
Expand All @@ -263,11 +271,13 @@ def validate_files(self):

def _get_layer_keys(self):
key_set = set()
key_len = len(LAYER_FILE_PREFIX) + 2
for file_path in self.layer_files:
_, fname = os.path.split(file_path)
key_set.add(fname[:key_len])
return sorted(list(key_set))
layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
key_set.add(layer_id)
sorted_ids = sorted(list(key_set), key=int)
layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids]
return layer_keys

def _merge_state_dicts(self, sd_list):
merged_sd = {}
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/mics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def wait(self) -> None:
"""
# let the current stream to op
try:
print("HANDLE", self.allgather_handle)
# print("HANDLE", self.allgather_handle)
instrument_w_nvtx(self.allgather_handle.wait)()
except (ValueError, RuntimeError) as e:
log_dist(
Expand Down
2 changes: 2 additions & 0 deletions docs/_sass/minimal-mistakes/_sidebar.scss
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
top: auto;
right: 0;
width: $right-sidebar-width-narrow;
margin-right: -1.5 * $right-sidebar-width-narrow;
padding-left: 1em;
z-index: 10;

Expand All @@ -93,6 +94,7 @@

@include breakpoint($x-large) {
width: $right-sidebar-width;
margin-right: -1.5 * $right-sidebar-width-narrow;
}
}

Expand Down
4 changes: 2 additions & 2 deletions op_builder/fp_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def filter_ccs(self, ccs):

def sources(self):
return [
"csrc/fp_quantizer/quantize.cu",
"csrc/fp_quantizer/quantize.cpp",
"csrc/fp_quantizer/fp_quantize.cu",
"csrc/fp_quantizer/fp_quantize.cpp",
]

def extra_ldflags(self):
Expand Down

0 comments on commit 8274a01

Please sign in to comment.