diff --git a/torch_patches/models/dlrm_mlperf_v1.0_training.diff b/torch_patches/models/dlrm_mlperf_v1.0_training.diff new file mode 100644 index 000000000..21a0484f2 --- /dev/null +++ b/torch_patches/models/dlrm_mlperf_v1.0_training.diff @@ -0,0 +1,1904 @@ +diff --git a/data_loader_terabyte.py b/data_loader_terabyte.py +index b91c9fb..01c1b7a 100644 +--- a/data_loader_terabyte.py ++++ b/data_loader_terabyte.py +@@ -14,6 +14,7 @@ import time + import math + from tqdm import tqdm + import argparse ++import extend_distributed as ext_dist + + + class DataLoader: +@@ -191,6 +192,82 @@ def _test(): + ) + ) + ++class HybridParallelCriteoBinDataset(Dataset): ++ """Hybrid parallel binary version of criteo dataset.""" ++ def __init__(self, data_file, counts_file, ++ batch_size=1, max_ind_range=-1, bytes_per_feature=4, sparse_dense_boundary=2048): ++ # dataset ++ self.tar_fea = 1 # single target ++ self.den_fea = 13 # 13 dense features ++ with np.load(counts_file) as data: ++ self.counts = data["counts"] ++ self.global_sparse_embs = [] ++ self.global_dense_embs = [] ++ num = 0 ++ for count in self.counts: ++ if count >= sparse_dense_boundary: ++ self.global_sparse_embs.append(num) ++ else: ++ self.global_dense_embs.append(num) ++ num +=1 ++ data_file_size = os.path.getsize(data_file) ++ self.bytes_per_feature = bytes_per_feature ++ self.sparse_index_bytes_per_rank = batch_size * bytes_per_feature ++ self.ddp_tot_fea = self.tar_fea + self.den_fea + len(self.global_dense_embs) ++ ++ self.max_ind_range = max_ind_range ++ self.ddp_bytes_per_sample = self.bytes_per_feature * (self.tar_fea + self.den_fea + len(self.global_dense_embs)) ++ self.ddp_bytes_per_rank = batch_size // ext_dist.my_size * self.ddp_bytes_per_sample ++ self.num_samples = data_file_size // self.ddp_bytes_per_sample ++ self.ddp_bytes_per_batch = self.ddp_bytes_per_sample * batch_size ++ self.num_batches = math.ceil(data_file_size / self.ddp_bytes_per_batch) ++ if ext_dist.my_size > 1 and self.num_batches * self.ddp_bytes_per_batch > data_file_size: ++ self.last_batch = (data_file_size % self.ddp_bytes_per_batch) // self.ddp_bytes_per_sample ++ self.ddp_bytes_last_batch = self.last_batch // ext_dist.my_size * self.ddp_bytes_per_sample ++ ++ ++ n_emb_sparse = len(self.global_sparse_embs) ++ self.num_grps = ext_dist.my_size // len(self.global_sparse_embs) ++ self.n_local_emb_sparse, self.n_sparse_emb_per_rank = ext_dist.get_split_lengths(n_emb_sparse, split=True) ++ self.local_ln_emb_sparse_slice = ext_dist.get_my_slice(n_emb_sparse) ++ self.local_ln_emb_sparse = self.global_sparse_embs[self.local_ln_emb_sparse_slice] ++ ++ self.file = open(data_file, 'rb') ++ self.emb_num_2_fd = dict() ++ ++ for num in self.local_ln_emb_sparse: ++ emb_index_file = data_file.split('_data_parallel')[0] + '_sparse_embedding_index_{}.bin'.format(num) ++ self.emb_num_2_fd[num] = open(emb_index_file, 'rb') ++ index_num_batches = math.ceil(os.path.getsize(emb_index_file) / (batch_size * self.bytes_per_feature)) ++ assert(index_num_batches == self.num_batches) ++ # hardcoded for now ++ self.m_den = 13 ++ ++ def __len__(self): ++ return self.num_batches ++ ++ def __getitem__(self, idx): ++ with torch.autograd.profiler.record_function("HybridParallelCriteoBinDataset:__getitem__"): ++ my_rank = ext_dist.dist.get_rank() if ext_dist.my_size > 1 else 0 ++ rank_ddp_size = self.ddp_bytes_last_batch if idx == (self.num_batches - 1) else self.ddp_bytes_per_rank ++ rank_sparse_index_size = (self.last_batch // ext_dist.my_size) * ext_dist.my_size * self.bytes_per_feature if 
idx == (self.num_batches - 1) else self.sparse_index_bytes_per_rank ++ self.file.seek(idx * self.ddp_bytes_per_batch + rank_ddp_size * my_rank, 0) ++ raw_ddp_data = self.file.read(rank_ddp_size) ++ ++ array_ddp = np.frombuffer(raw_ddp_data, dtype=np.int32) ++ tensor = torch.from_numpy(array_ddp).view((-1, self.ddp_tot_fea)) ++ for num in self.local_ln_emb_sparse: ++ self.emb_num_2_fd[num].seek(idx * self.sparse_index_bytes_per_rank, 0) ++ global_emb_index = self.emb_num_2_fd[num].read(rank_sparse_index_size) ++ index_numpy = np.frombuffer(global_emb_index, dtype = np.int32) ++ index_tensor = torch.from_numpy(index_numpy).reshape(ext_dist.my_size, -1).t() ++ tensor = torch.cat((tensor, index_tensor), dim=1) ++ ++ return _transform_features(x_int_batch=tensor[:, 1:14], ++ x_cat_batch=tensor[:, 14:], ++ y_batch=tensor[:, 0], ++ max_ind_range=self.max_ind_range, ++ flag_input_torch_tensor=True) + + class CriteoBinDataset(Dataset): + """Binary version of criteo dataset.""" +@@ -214,7 +291,21 @@ class CriteoBinDataset(Dataset): + bytes_per_sample = bytes_per_feature * self.tot_fea + self.num_samples = data_file_size // bytes_per_sample + +- print('data file:', data_file, 'number of batches:', self.num_batches) ++ if ext_dist.my_size > 1: ++ self.bytes_per_rank = self.bytes_per_batch // ext_dist.my_size ++ else: ++ self.bytes_per_rank = self.bytes_per_batch ++ ++ if ext_dist.my_size > 1 and self.num_batches * self.bytes_per_batch > data_file_size: ++ last_batch = (data_file_size % self.bytes_per_batch) // bytes_per_sample ++ self.bytes_last_batch = last_batch // ext_dist.my_size * bytes_per_sample ++ else: ++ self.bytes_last_batch = self.bytes_per_rank ++ ++ if self.bytes_last_batch == 0: ++ self.num_batches = self.num_batches - 1 ++ self.bytes_last_batch = self.bytes_per_rank ++ + self.file = open(data_file, 'rb') + + with np.load(counts_file) as data: +@@ -227,16 +318,19 @@ class CriteoBinDataset(Dataset): + return self.num_batches + + def __getitem__(self, idx): +- self.file.seek(idx * self.bytes_per_batch, 0) +- raw_data = self.file.read(self.bytes_per_batch) +- array = np.frombuffer(raw_data, dtype=np.int32) +- tensor = torch.from_numpy(array).view((-1, self.tot_fea)) +- +- return _transform_features(x_int_batch=tensor[:, 1:14], +- x_cat_batch=tensor[:, 14:], +- y_batch=tensor[:, 0], +- max_ind_range=self.max_ind_range, +- flag_input_torch_tensor=True) ++ with torch.autograd.profiler.record_function("CriteoBinDataset:__getitem__"): ++ my_rank = ext_dist.dist.get_rank() if ext_dist.my_size > 1 else 0 ++ rank_size = self.bytes_last_batch if idx == (self.num_batches - 1) else self.bytes_per_rank ++ self.file.seek(idx * self.bytes_per_batch + rank_size * my_rank, 0) ++ raw_data = self.file.read(rank_size) ++ array = np.frombuffer(raw_data, dtype=np.int32) ++ tensor = torch.from_numpy(array).view((-1, self.tot_fea)) ++ ++ return _transform_features(x_int_batch=tensor[:, 1:14], ++ x_cat_batch=tensor[:, 14:], ++ y_batch=tensor[:, 0], ++ max_ind_range=self.max_ind_range, ++ flag_input_torch_tensor=True) + + + def numpy_to_binary(input_files, output_file_path, split='train'): +diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py +index ee63897..c0535ad 100644 +--- a/dlrm_data_pytorch.py ++++ b/dlrm_data_pytorch.py +@@ -33,6 +33,8 @@ from numpy import random as ra + # pytorch + import torch + from torch.utils.data import Dataset, RandomSampler ++import extend_distributed as ext_dist ++ + + import data_loader_terabyte + import mlperf_logger +@@ -382,26 +384,53 @@ def 
make_criteo_data_and_loaders(args): + if args.mlperf_logging and args.memory_map and args.data_set == "terabyte": + # more efficient for larger batches + data_directory = path.dirname(args.raw_data_file) +- + if args.mlperf_bin_loader: + lstr = args.processed_data_file.split("/") + d_path = "/".join(lstr[0:-1]) + "/" + lstr[-1].split(".")[0] + train_file = d_path + "_train.bin" + test_file = d_path + "_test.bin" ++ if ext_dist.my_size > 1 and args.use_hybridparallel_dataset: ++ train_file = d_path + "_train_data_parallel.bin" ++ test_file = d_path + "_test_data_parallel.bin" + # val_file = d_path + "_val.bin" + counts_file = args.raw_data_file + '_fea_count.npz' +- + if any(not path.exists(p) for p in [train_file, + test_file, + counts_file]): + ensure_dataset_preprocessed(args, d_path) + +- train_data = data_loader_terabyte.CriteoBinDataset( +- data_file=train_file, +- counts_file=counts_file, +- batch_size=args.mini_batch_size, +- max_ind_range=args.max_ind_range +- ) ++ train_data = None ++ test_data = None ++ if ext_dist.my_size > 1 and args.use_hybridparallel_dataset: ++ train_data = data_loader_terabyte.HybridParallelCriteoBinDataset( ++ data_file=train_file, ++ counts_file=counts_file, ++ batch_size=args.mini_batch_size, ++ max_ind_range=args.max_ind_range, ++ sparse_dense_boundary=args.sparse_dense_boundary ++ ) ++ test_data = data_loader_terabyte.HybridParallelCriteoBinDataset( ++ data_file=test_file, ++ counts_file=counts_file, ++ batch_size=args.test_mini_batch_size, ++ max_ind_range=args.max_ind_range, ++ sparse_dense_boundary=args.sparse_dense_boundary ++ ++ ) ++ ++ else: ++ train_data = data_loader_terabyte.CriteoBinDataset( ++ data_file=train_file, ++ counts_file=counts_file, ++ batch_size=args.mini_batch_size, ++ max_ind_range=args.max_ind_range ++ ) ++ test_data = data_loader_terabyte.CriteoBinDataset( ++ data_file=test_file, ++ counts_file=counts_file, ++ batch_size=args.test_mini_batch_size, ++ max_ind_range=args.max_ind_range ++ ) + + mlperf_logger.log_event(key=mlperf_logger.constants.TRAIN_SAMPLES, + value=train_data.num_samples) +@@ -418,12 +447,7 @@ def make_criteo_data_and_loaders(args): + sampler=RandomSampler(train_data) if args.mlperf_bin_shuffle else None + ) + +- test_data = data_loader_terabyte.CriteoBinDataset( +- data_file=test_file, +- counts_file=counts_file, +- batch_size=args.test_mini_batch_size, +- max_ind_range=args.max_ind_range +- ) ++ + + mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_SAMPLES, + value=test_data.num_samples) +diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py +index b19f670..264f54e 100644 +--- a/dlrm_s_pytorch.py ++++ b/dlrm_s_pytorch.py +@@ -80,6 +80,17 @@ import torch.nn as nn + from torch.nn.parallel.parallel_apply import parallel_apply + from torch.nn.parallel.replicate import replicate + from torch.nn.parallel.scatter_gather import gather, scatter ++ ++# For distributed run ++import extend_distributed as ext_dist ++ ++try: ++ import intel_pytorch_extension as ipex ++ from intel_pytorch_extension import core ++except: ++ pass ++from lamb import Lamb, log_lamb_rs ++ + # quotient-remainder trick + from tricks.qr_embedding_bag import QREmbeddingBag + # mixed-dimension trick +@@ -106,7 +117,11 @@ class LRPolicyScheduler(_LRScheduler): + if self.decay_start_step < self.num_warmup_steps: + sys.exit("Learning rate warmup must finish before the decay starts") + +- super(LRPolicyScheduler, self).__init__(optimizer) ++ if isinstance(optimizer, tuple): ++ for opt in optimizer: ++ super(LRPolicyScheduler, self).__init__(opt) ++ 
else: ++ super(LRPolicyScheduler, self).__init__(optimizer) + + def get_lr(self): + step_count = self._step_count +@@ -133,32 +148,23 @@ class LRPolicyScheduler(_LRScheduler): + return lr + + +-def coalesce_sparse_grads(model, large_grad_threshold=256 * 1000 * 1000): +- """ +- For every sparse gradient of the model, either convert it to dense +- or coalesce, depending on its size and the large_grad_threshold parameter. +- Using coalesced or dense updates has better numerical properties than +- using sparse uncoalesced weight update. +- +- :param model: model the sparse gradients of which need to be coalesced +- :param large_grad_threshold: gradients of shape greater than this will be +- coalesced instead of being converted to dense, this will be slower +- but will also save memory +- :return: None +- """ +- for p in model.parameters(): +- if not p.grad.is_sparse: +- continue +- +- numel = p.shape[0] * p.shape[1] +- if numel < large_grad_threshold: +- # faster but larger memory footprint +- p.grad = p.grad.to_dense() +- else: +- # slower but saves memory for the large parameters +- p.grad = p.grad.coalesce() +- +- ++class Cast(nn.Module): ++ __constants__ = ['to_dtype'] ++ ++ def __init__(self, to_dtype): ++ super(Cast, self).__init__() ++ self.to_dtype = to_dtype ++ ++ def forward(self, input): ++ if input.is_mkldnn: ++ return input.to_dense(self.to_dtype) ++ else: ++ return input.to(self.to_dtype) ++ ++ def extra_repr(self): ++ return 'to(%s)' % self.to_dtype ++ ++ + ### define dlrm in PyTorch ### + class DLRM_Net(nn.Module): + def create_mlp(self, ln, sigmoid_layer): +@@ -169,7 +175,10 @@ class DLRM_Net(nn.Module): + m = ln[i + 1] + + # construct fully connected operator +- LL = nn.Linear(int(n), int(m), bias=True) ++ if self.use_ipex and self.bf16: ++ LL = ipex.IpexMLPLinear(int(n), int(m), bias=True, output_stays_blocked=(i < ln.size - 2), default_blocking=32) ++ else: ++ LL = nn.Linear(int(n), int(m), bias=True) + + # initialize the weights + # with torch.no_grad(): +@@ -188,22 +197,43 @@ class DLRM_Net(nn.Module): + # approach 3 + # LL.weight = Parameter(torch.tensor(W),requires_grad=True) + # LL.bias = Parameter(torch.tensor(bt),requires_grad=True) ++ ++ if self.bf16 and ipex.is_available(): ++ LL.to(torch.bfloat16) ++ # prepack weight for IPEX Linear ++ if hasattr(LL, 'reset_weight_shape'): ++ LL.reset_weight_shape(block_for_dtype=torch.bfloat16) ++ + layers.append(LL) + + # construct sigmoid or relu operator + if i == sigmoid_layer: ++ if self.bf16: ++ layers.append(Cast(torch.float32)) + layers.append(nn.Sigmoid()) + else: +- layers.append(nn.ReLU()) ++ if self.use_ipex and self.bf16: ++ LL.set_activation_type('relu') ++ else: ++ layers.append(nn.ReLU()) + + # approach 1: use ModuleList + # return layers + # approach 2: use Sequential container to wrap all layers + return torch.nn.Sequential(*layers) + +- def create_emb(self, m, ln): ++ def create_emb(self, m, ln, local_ln_emb_sparse=None, ln_emb_dense=None): + emb_l = nn.ModuleList() +- for i in range(0, ln.size): ++ # save the numpy random state ++ np_rand_state = np.random.get_state() ++ emb_dense = nn.ModuleList() ++ emb_sparse = nn.ModuleList() ++ embs = range(len(ln)) ++ if local_ln_emb_sparse or ln_emb_dense: ++ embs = local_ln_emb_sparse + ln_emb_dense ++ for i in embs: ++ # Use per table random seed for Embedding initialization ++ np.random.seed(self.l_emb_seeds[i]) + n = ln[i] + # construct embedding operator + if self.qr_flag and n > self.qr_threshold: +@@ -220,23 +250,46 @@ class DLRM_Net(nn.Module): + 
EE.embs.weight.data = torch.tensor(W, requires_grad=True) + + else: +- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True) +- + # initialize embeddings + # nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) +- W = np.random.uniform( +- low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) +- ).astype(np.float32) ++ + # approach 1 +- EE.weight.data = torch.tensor(W, requires_grad=True) ++ if n >= self.sparse_dense_boundary: ++ # For sparse embs, split the table across ranks along sparse dimension ++ if (ext_dist.my_size > 1) and (ext_dist.my_size > len(self.ln_emb_sparse)): ++ new_m = m // (ext_dist.my_size // len(self.ln_emb_sparse)) ++ else: ++ new_m = m ++ ++ W = np.random.uniform( ++ low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, new_m) ++ ).astype(np.float32) ++ ++ EE = nn.EmbeddingBag(n, new_m, mode="sum", sparse=True, _weight=torch.tensor(W, requires_grad=True)) ++ else: ++ W = np.random.uniform( ++ low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) ++ ).astype(np.float32) ++ ++ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=False, _weight=torch.tensor(W, requires_grad=True)) + # approach 2 + # EE.weight.data.copy_(torch.tensor(W)) + # approach 3 + # EE.weight = Parameter(torch.tensor(W),requires_grad=True) ++ if self.bf16 and ipex.is_available(): ++ EE.to(torch.bfloat16) ++ ++ if ext_dist.my_size > 1: ++ if n >= self.sparse_dense_boundary: ++ emb_sparse.append(EE) ++ else: ++ emb_dense.append(EE) + + emb_l.append(EE) + +- return emb_l ++ # Restore the numpy random state ++ np.random.set_state(np_rand_state) ++ return emb_l, emb_dense, emb_sparse + + def __init__( + self, +@@ -257,6 +310,9 @@ class DLRM_Net(nn.Module): + qr_threshold=200, + md_flag=False, + md_threshold=200, ++ bf16=False, ++ use_ipex=False, ++ sparse_dense_boundary = 2048 + ): + super(DLRM_Net, self).__init__() + +@@ -277,6 +333,9 @@ class DLRM_Net(nn.Module): + self.arch_interaction_itself = arch_interaction_itself + self.sync_dense_params = sync_dense_params + self.loss_threshold = loss_threshold ++ self.bf16 = bf16 ++ self.use_ipex = use_ipex ++ self.sparse_dense_boundary = sparse_dense_boundary + # create variables for QR embedding if applicable + self.qr_flag = qr_flag + if self.qr_flag: +@@ -287,9 +346,28 @@ class DLRM_Net(nn.Module): + self.md_flag = md_flag + if self.md_flag: + self.md_threshold = md_threshold ++ ++ # generate np seeds for Emb table initialization ++ self.l_emb_seeds = np.random.randint(low=0, high=100000, size=len(ln_emb)) ++ ++ #If running distributed, get local slice of embedding tables ++ if ext_dist.my_size > 1: ++ n_emb = len(ln_emb) ++ self.n_global_emb = n_emb ++ self.rank = ext_dist.dist.get_rank() ++ self.ln_emb_dense = [i for i in range(n_emb) if ln_emb[i] < self.sparse_dense_boundary] ++ self.ln_emb_sparse = [i for i in range(n_emb) if ln_emb[i] >= self.sparse_dense_boundary] ++ n_emb_sparse = len(self.ln_emb_sparse) ++ self.n_local_emb_sparse, self.n_sparse_emb_per_rank = ext_dist.get_split_lengths(n_emb_sparse, split=True) ++ self.local_ln_emb_sparse_slice = ext_dist.get_my_slice(n_emb_sparse) ++ self.local_ln_emb_sparse = self.ln_emb_sparse[self.local_ln_emb_sparse_slice] + # create operators + if ndevices <= 1: +- self.emb_l = self.create_emb(m_spa, ln_emb) ++ if ext_dist.my_size > 1: ++ _, self.emb_dense, self.emb_sparse = self.create_emb(m_spa, ln_emb, self.local_ln_emb_sparse, self.ln_emb_dense) ++ else: ++ self.emb_l, _, _ = self.create_emb(m_spa, ln_emb) ++ + self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) + self.top_l = self.create_mlp(ln_top, 
sigmoid_top) + +@@ -299,7 +377,13 @@ class DLRM_Net(nn.Module): + # x = layer(x) + # return x + # approach 2: use Sequential container to wrap all layers +- return layers(x) ++ need_padding = self.use_ipex and self.bf16 and x.size(0) % 2 == 1 ++ if need_padding: ++ x = torch.nn.functional.pad(input=x, pad=(0,0,0,1), mode='constant', value=0) ++ ret = layers(x) ++ return(ret[:-1,:]) ++ else: ++ return layers(x) + + def apply_emb(self, lS_o, lS_i, emb_l): + # WARNING: notice that we are processing the batch at once. We implicitly +@@ -326,27 +410,32 @@ class DLRM_Net(nn.Module): + return ly + + def interact_features(self, x, ly): ++ x = x.to(ly[0].dtype) + if self.arch_interaction_op == "dot": +- # concatenate dense and sparse features +- (batch_size, d) = x.shape +- T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) +- # perform a dot product +- Z = torch.bmm(T, torch.transpose(T, 1, 2)) +- # append dense feature with the interactions (into a row vector) +- # approach 1: all +- # Zflat = Z.view((batch_size, -1)) +- # approach 2: unique +- _, ni, nj = Z.shape +- # approach 1: tril_indices +- # offset = 0 if self.arch_interaction_itself else -1 +- # li, lj = torch.tril_indices(ni, nj, offset=offset) +- # approach 2: custom +- offset = 1 if self.arch_interaction_itself else 0 +- li = torch.tensor([i for i in range(ni) for j in range(i + offset)]) +- lj = torch.tensor([j for i in range(nj) for j in range(i + offset)]) +- Zflat = Z[:, li, lj] +- # concatenate dense features and interactions +- R = torch.cat([x] + [Zflat], dim=1) ++ if self.bf16: ++ T = [x] + ly ++ R = ipex.interaction(*T) ++ else: ++ # concatenate dense and sparse features ++ (batch_size, d) = x.shape ++ T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) ++ # perform a dot product ++ Z = torch.bmm(T, torch.transpose(T, 1, 2)) ++ # append dense feature with the interactions (into a row vector) ++ # approach 1: all ++ # Zflat = Z.view((batch_size, -1)) ++ # approach 2: unique ++ _, ni, nj = Z.shape ++ # approach 1: tril_indices ++ # offset = 0 if self.arch_interaction_itself else -1 ++ # li, lj = torch.tril_indices(ni, nj, offset=offset) ++ # approach 2: custom ++ offset = 1 if self.arch_interaction_itself else 0 ++ li = torch.tensor([i for i in range(ni) for j in range(i + offset)]) ++ lj = torch.tensor([j for i in range(nj) for j in range(i + offset)]) ++ Zflat = Z[:, li, lj] ++ # concatenate dense features and interactions ++ R = torch.cat([x] + [Zflat], dim=1) + elif self.arch_interaction_op == "cat": + # concatenation features (into a row vector) + R = torch.cat([x] + ly, dim=1) +@@ -360,7 +449,11 @@ class DLRM_Net(nn.Module): + return R + + def forward(self, dense_x, lS_o, lS_i): +- if self.ndevices <= 1: ++ if self.bf16: ++ dense_x = dense_x.bfloat16() ++ if ext_dist.my_size > 1: ++ return self.distributed_forward(dense_x, lS_o, lS_i) ++ elif self.ndevices <= 1: + return self.sequential_forward(dense_x, lS_o, lS_i) + else: + return self.parallel_forward(dense_x, lS_o, lS_i) +@@ -392,6 +485,83 @@ class DLRM_Net(nn.Module): + + return z + ++ def distributed_forward(self, dense_x, lS_o, lS_i): ++ batch_size = dense_x.size()[0] ++ # WARNING: # of ranks must be <= batch size in distributed_forward call ++ #if batch_size < ext_dist.my_size: ++ # sys.exit("ERROR: batch_size (%d) must be larger than number of ranks (%d)" % (batch_size, ext_dist.my_size)) ++ ++ lS_o_dense = [lS_o[i] for i in self.ln_emb_dense] ++ lS_o_sparse = [lS_o[i] for i in self.ln_emb_sparse] # partition sparse table in one group ++ lS_i_dense = [] 
++ g_i_sparse = [] ++ if args.use_hybridparallel_dataset: ++ dense_embs_num = len(self.ln_emb_dense) ++ lS_i_dense = [lS_i[i] for i in range(dense_embs_num)] ++ for i in range(len(self.local_ln_emb_sparse)): ++ global_sparse_index = lS_i[dense_embs_num + i * ext_dist.my_size : dense_embs_num + (i + 1) * ext_dist.my_size] ++ global_sparse_index = global_sparse_index.reshape(-1) ++ g_i_sparse.append(global_sparse_index) ++ offset = torch.arange(batch_size * ext_dist.my_size).to(device) ++ g_o_sparse = [offset for i in range(self.n_local_emb_sparse)] ++ else: ++ lS_i_dense = [lS_i[i] for i in self.ln_emb_dense] ++ lS_i_sparse = [lS_i[i] for i in self.ln_emb_sparse] ++ # Replicate the indices for sparse embs for the split grps ++ if ext_dist.my_size > len(self.ln_emb_sparse): ++ num_split_grps = ext_dist.my_size // len(self.ln_emb_sparse) ++ for j in range(num_split_grps-1): ++ for i in range(len(self.ln_emb_sparse)): ++ lS_i_sparse.append(lS_i_sparse[i]) ++ ++ #lS_i_sparse = ext_dist.shuffle_data(lS_i_sparse) ++ #g_i_sparse = [lS_i_sparse[:, i * batch_size:(i + 1) * batch_size].reshape(-1) for i in range(len(self.local_ln_emb_sparse))] ++ input = torch.cat(lS_i_sparse) ++ output = input.new_empty(input.size()) ++ req = ext_dist.dist.all_to_all_single(output, input, async_op=True) ++ offset = torch.arange(batch_size * ext_dist.my_size).to(device) ++ g_o_sparse = [offset for i in range(self.n_local_emb_sparse)] ++ req.wait() ++ lS_i_sparse = output.reshape(ext_dist.my_size, -1) ++ g_i_sparse = [lS_i_sparse[:, i * batch_size:(i + 1) * batch_size].reshape(-1) for i in range(len(self.local_ln_emb_sparse))] ++ ++ ++ if (len(self.local_ln_emb_sparse) != len(g_o_sparse)) or (len(self.local_ln_emb_sparse) != len(g_i_sparse)): ++ sys.exit("ERROR 0 : corrupted model input detected in distributed_forward call") ++ # sparse embeddings ++ ly_sparse = self.apply_emb(g_o_sparse, g_i_sparse, self.emb_sparse) ++ a2a_req = ext_dist.alltoall(ly_sparse, self.n_sparse_emb_per_rank) ++ # dense embeddings ++ ly_dense = self.apply_emb(lS_o_dense, lS_i_dense, self.emb_dense) ++ ++ # bottom mlp ++ x = self.apply_mlp(dense_x, self.bot_l) ++ ly_sparse = a2a_req.wait() ++ ++ # concat emb data for split sparse embs ++ ly_sparse_full = [] ++ if ext_dist.my_size > len(self.ln_emb_sparse): ++ for i in range(len(self.ln_emb_sparse)): ++ ly_sparse_split = torch.cat([ly_sparse[j] for j in range(i, ext_dist.my_size, 16)], 1) ++ ly_sparse_full.append(ly_sparse_split) ++ else: ++ ly_sparse_full = list(ly_sparse) ++ ++ ly = ly_dense + ly_sparse_full ++ # interactions ++ z = self.interact_features(x, ly) ++ # top mlp ++ p = self.apply_mlp(z, self.top_l) ++ # clamp output if needed ++ if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: ++ z = torch.clamp( ++ p, min=self.loss_threshold, max=(1.0 - self.loss_threshold) ++ ) ++ else: ++ z = p ++ ++ return z ++ + def parallel_forward(self, dense_x, lS_o, lS_i): + ### prepare model (overwrite) ### + # WARNING: # of devices must be >= batch size in parallel_forward call +@@ -510,6 +680,7 @@ if __name__ == "__main__": + + ### import packages ### + import sys ++ import os + import argparse + + ### parse arguments ### +@@ -570,6 +741,8 @@ if __name__ == "__main__": + parser.add_argument("--save-onnx", action="store_true", default=False) + # gpu + parser.add_argument("--use-gpu", action="store_true", default=False) ++ # distributed run ++ parser.add_argument("--dist-backend", type=str, default="") + # debugging and profiling + parser.add_argument("--print-freq", type=int, default=1) + 
parser.add_argument("--test-freq", type=int, default=-1) +@@ -579,7 +752,10 @@ if __name__ == "__main__": + parser.add_argument("--debug-mode", action="store_true", default=False) + parser.add_argument("--enable-profiling", action="store_true", default=False) + parser.add_argument("--plot-compute-graph", action="store_true", default=False) ++ parser.add_argument("--profiling-start-iter", type=int, default=50) ++ parser.add_argument("--profiling-num-iters", type=int, default=100) + # store/load model ++ parser.add_argument("--out-dir", type=str, default=".") + parser.add_argument("--save-model", type=str, default="") + parser.add_argument("--load-model", type=str, default="") + # mlperf logging (disables other output and stops early) +@@ -590,16 +766,25 @@ if __name__ == "__main__": + parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0) + parser.add_argument("--mlperf-bin-loader", action='store_true', default=False) + parser.add_argument("--mlperf-bin-shuffle", action='store_true', default=False) +- parser.add_argument("--mlperf-coalesce-sparse-grads", action='store_true', default=False) +- # mlperf gradient accumulation iterations +- parser.add_argument("--mlperf-grad-accum-iter", type=int, default=1) + # LR policy + parser.add_argument("--lr-num-warmup-steps", type=int, default=0) + parser.add_argument("--lr-decay-start-step", type=int, default=0) + parser.add_argument("--lr-num-decay-steps", type=int, default=0) ++ # embedding table is sparse table only if sparse_dense_boundary >= 2048 ++ parser.add_argument("--sparse-dense-boundary", type=int, default=2048) ++ # bf16 option ++ parser.add_argument("--bf16", action='store_true', default=False) ++ # ipex option ++ parser.add_argument("--use-ipex", action="store_true", default=False) ++ # lamb ++ parser.add_argument("--optimizer", type=int, default=0, help='optimizer:[0:sgd, 1:lamb/sgd, 2:adagrad, 3:sparseadam]') ++ parser.add_argument("--lamblr", type=float, default=0.01, help='lr for lamb') ++ parser.add_argument("--use-hybridparallel-dataset", action="store_true", default=False) + args = parser.parse_args() + +- if args.mlperf_logging: ++ ext_dist.init_distributed(backend=args.dist_backend) ++ ++ if args.mlperf_logging and ext_dist.my_size >1 and ext_dist.dist.get_rank() == 0: + print('command line args: ', json.dumps(vars(args))) + + ### some basic setup ### +@@ -614,14 +799,29 @@ if __name__ == "__main__": + if (args.test_num_workers < 0): + # if the parameter is not set, use the same parameter for training + args.test_num_workers = args.num_workers ++ if (args.mini_batch_size % ext_dist.my_size !=0 or args.test_mini_batch_size % ext_dist.my_size != 0): ++ print("Either test minibatch (%d) or train minibatch (%d) does not split across %d ranks" % (args.test_mini_batch_size, args.mini_batch_size, ext_dist.my_size)) ++ sys.exit(1) + + use_gpu = args.use_gpu and torch.cuda.is_available() ++ use_ipex = args.use_ipex + if use_gpu: + torch.cuda.manual_seed_all(args.numpy_rand_seed) + torch.backends.cudnn.deterministic = True +- device = torch.device("cuda", 0) +- ngpus = torch.cuda.device_count() # 1 ++ if ext_dist.my_size > 1: ++ ngpus = torch.cuda.device_count() # 1 ++ if ext_dist.my_local_size > torch.cuda.device_count(): ++ print("Not sufficient GPUs available... 
local_size = %d, ngpus = %d" % (ext_dist.my_local_size, ngpus)) ++ sys.exit(1) ++ ngpus = 1 ++ device = torch.device("cuda", ext_dist.my_local_rank) ++ else: ++ device = torch.device("cuda", 0) ++ ngpus = torch.cuda.device_count() # 1 + print("Using {} GPU(s)...".format(ngpus)) ++ elif use_ipex: ++ device = torch.device("dpcpp") ++ print("Using IPEX...") + else: + device = torch.device("cpu") + print("Using CPU...") +@@ -630,12 +830,7 @@ if __name__ == "__main__": + ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") + # input data + +- mlperf_logger.barrier() +- mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP) +- mlperf_logger.barrier() +- mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START) +- mlperf_logger.barrier() +- ++ + if (args.data_generation == "dataset"): + train_data, train_ld, test_data, test_ld = \ + dp.make_criteo_data_and_loaders(args) +@@ -808,7 +1003,11 @@ if __name__ == "__main__": + qr_threshold=args.qr_threshold, + md_flag=args.md_flag, + md_threshold=args.md_threshold, ++ sparse_dense_boundary=args.sparse_dense_boundary, ++ bf16 = args.bf16, ++ use_ipex = args.use_ipex + ) ++ + print('Model created!') + # test prints + if args.debug_mode: +@@ -817,6 +1016,9 @@ if __name__ == "__main__": + print(param.detach().cpu().numpy()) + # print(dlrm) + ++ if args.use_ipex: ++ dlrm = dlrm.to(device) ++ + if use_gpu: + # Custom Model-Data Parallel + # the mlps are replicated and use data parallelism, while +@@ -825,6 +1027,17 @@ if __name__ == "__main__": + if dlrm.ndevices > 1: + dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb) + ++ if ext_dist.my_size > 1: ++ if use_gpu: ++ device_ids = [ext_dist.my_local_rank] ++ dlrm.bot_l = ext_dist.DDP(dlrm.bot_l, device_ids=device_ids) ++ dlrm.top_l = ext_dist.DDP(dlrm.top_l, device_ids=device_ids) ++ else: ++ dlrm.bot_l = ext_dist.DDP(dlrm.bot_l) ++ dlrm.top_l = ext_dist.DDP(dlrm.top_l) ++ for i in range(len(dlrm.emb_dense)): ++ dlrm.emb_dense[i] = ext_dist.DDP(dlrm.emb_dense[i]) ++ + # specify the loss function + if args.loss_function == "mse": + loss_fn = torch.nn.MSELoss(reduction="mean") +@@ -838,9 +1051,49 @@ if __name__ == "__main__": + + if not args.inference_only: + # specify the optimizer algorithm +- optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate) ++ optimizer_list = ([[torch.optim.SGD], ([ipex.Lamb, False], torch.optim.SGD), ++ torch.optim.Adagrad, ([torch.optim.Adam, None], torch.optim.SparseAdam)], ++ [[ipex.SplitSGD], ([ipex.Lamb, True], ipex.SplitSGD)]) ++ optimizers = optimizer_list[args.bf16 and ipex.is_available()][args.optimizer] ++ #print('Chosen optimizer(s): %s' % str(optimizers)) ++ ++ if ext_dist.my_size == 1: ++ if len(optimizers) == 1: ++ optimizer = optimizers[0](dlrm.parameters(), lr=args.learning_rate) ++ else: ++ optimizer_dense = optimizers[0][0]([ ++ {"params": dlrm.bot_l.parameters(), "lr": args.learning_rate}, ++ {"params": dlrm.top_l.parameters(), "lr": args.learning_rate} ++ ], lr=args.learning_rate) ++ if optimizers[0][1] is not None: ++ optimizer_dense.set_bf16(optimizers[0][1]) ++ optimizer_sparse = optimizers[1]([ ++ {"params": [p for emb in dlrm.emb_l for p in emb.parameters()], "lr": args.learning_rate}, ++ ], lr=args.learning_rate) ++ optimizer = (optimizer_dense, optimizer_sparse) ++ else: ++ if len(optimizers) == 1: ++ optimizer = optimizers[0]([ ++ {"params": [p for emb in dlrm.emb_sparse for p in emb.parameters()], ++ "lr": args.learning_rate / ext_dist.my_size}, ++ {"params": [p for emb in dlrm.emb_dense for p in emb.parameters()], "lr": 
args.learning_rate}, ++ {"params": dlrm.bot_l.parameters(), "lr": args.learning_rate}, ++ {"params": dlrm.top_l.parameters(), "lr": args.learning_rate} ++ ], lr=args.learning_rate) ++ else: ++ optimizer_dense = optimizers[0][0]([ ++ {"params": [p for emb in dlrm.emb_dense for p in emb.parameters()], "lr": args.learning_rate}, ++ {"params": dlrm.bot_l.parameters(), "lr": args.learning_rate}, ++ {"params": dlrm.top_l.parameters(), "lr": args.learning_rate} ++ ], lr=args.lamblr, bf16=args.bf16) ++ optimizer_sparse = optimizers[1]([ ++ {"params": [p for emb in dlrm.emb_sparse for p in emb.parameters()], ++ "lr": args.learning_rate / ext_dist.my_size}, ++ ], lr=args.learning_rate) ++ optimizer = (optimizer_dense, optimizer_sparse) ++ + lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step, +- args.lr_num_decay_steps) ++ args.lr_num_decay_steps) + + ### main loop ### + def time_wrap(use_gpu): +@@ -848,8 +1101,8 @@ if __name__ == "__main__": + torch.cuda.synchronize() + return time.time() + +- def dlrm_wrap(X, lS_o, lS_i, use_gpu, device): +- if use_gpu: # .cuda() ++ def dlrm_wrap(X, lS_o, lS_i, use_gpu, use_ipex, device): ++ if use_gpu or use_ipex: # .cuda() + # lS_i can be either a list of tensors or a stacked tensor. + # Handle each case below: + lS_i = [S_i.to(device) for S_i in lS_i] if isinstance(lS_i, list) \ +@@ -864,9 +1117,9 @@ if __name__ == "__main__": + else: + return dlrm(X, lS_o, lS_i) + +- def loss_fn_wrap(Z, T, use_gpu, device): ++ def loss_fn_wrap(Z, T, use_gpu, use_ipex, device): + if args.loss_function == "mse" or args.loss_function == "bce": +- if use_gpu: ++ if use_gpu or use_ipex: + return loss_fn(Z, T.to(device)) + else: + return loss_fn(Z, T) +@@ -889,6 +1142,7 @@ if __name__ == "__main__": + skip_upto_epoch = 0 + skip_upto_batch = 0 + total_time = 0 ++ total_data_time = 0 + total_loss = 0 + total_accu = 0 + total_iter = 0 +@@ -957,8 +1211,14 @@ if __name__ == "__main__": + ld_gL_test, ld_gA_test * 100 + ) + ) ++ ext_dist.barrier() ++ mlperf_logger.barrier() ++ mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP) ++ mlperf_logger.barrier() ++ mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START) ++ mlperf_logger.barrier() + +- print("time/loss/accuracy (if enabled):") ++ #print("time/loss/accuracy (if enabled):") + + # LR is logged twice for now because of a compliance checker bug + mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR, value=args.learning_rate) +@@ -971,6 +1231,18 @@ if __name__ == "__main__": + mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_steps', value=args.lr_num_decay_steps) + mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_poly_power', value=2) + ++ # record_shapes=True ++ # if hasattr(torch.autograd.profiler.profile, "resume"): ++ # prof_support_suspend_resume = True ++ # prof_arg_dict = {"start_suspended": True} ++ # else: ++ # prof_support_suspend_resume = False ++ # prof_arg_dict = { } ++ ++ # prof_start_iter = args.profiling_start_iter ++ # prof_end_iter = prof_start_iter + args.profiling_num_iters ++ ++ # with torch.autograd.profiler.profile(args.enable_profiling, use_gpu, record_shapes=record_shapes, **prof_arg_dict) as prof: + with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof: + while k < args.nepochs: + mlperf_logger.barrier() +@@ -989,21 +1261,28 @@ if __name__ == "__main__": + if args.mlperf_logging: + previous_iteration_time = None + ++ end_time = time_wrap(use_gpu) + for j, (X, lS_o, lS_i, T) in enumerate(train_ld): + if j 
== 0 and args.save_onnx: + (X_onnx, lS_o_onnx, lS_i_onnx) = (X, lS_o, lS_i) +- ++ if args.enable_profiling and j >= 1000: ++ break + if j < skip_upto_batch: + continue + + if args.mlperf_logging: + current_time = time_wrap(use_gpu) ++ data_time = current_time - end_time + if previous_iteration_time: + iteration_time = current_time - previous_iteration_time + else: + iteration_time = 0 + previous_iteration_time = current_time ++ # if prof and prof_support_suspend_resume and j == prof_start_iter: prof.resume() ++ # if prof and prof_support_suspend_resume and j == prof_end_iter: prof.suspend() + else: ++ # ext_dist.barrier() ++ # if prof and prof_support_suspend_resume and j >= prof_start_iter and j < prof_end_iter: prof.resume() + t1 = time_wrap(use_gpu) + + # early exit if nbatches was set by the user and has been exceeded +@@ -1020,10 +1299,10 @@ if __name__ == "__main__": + ''' + + # forward pass +- Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device) ++ Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, use_ipex, device) + + # loss +- E = loss_fn_wrap(Z, T, use_gpu, device) ++ E = loss_fn_wrap(Z, T, use_gpu, use_ipex, device) + ''' + # debug prints + print("output and loss") +@@ -1039,7 +1318,11 @@ if __name__ == "__main__": + + if not args.inference_only: + # scaled error gradient propagation +- if (args.mlperf_logging and (j + 1) % args.mlperf_grad_accum_iter == 0) or not args.mlperf_logging: ++ # (where we do not accumulate gradients across mini-batches) ++ if args.optimizer == 1 or args.optimizer == 3: ++ optimizer_dense.zero_grad() ++ optimizer_sparse.zero_grad() ++ else: + optimizer.zero_grad() + # backward pass + E.backward() +@@ -1048,15 +1331,19 @@ if __name__ == "__main__": + # if hasattr(l, 'weight'): + # print(l.weight.grad.norm().item()) + +- if args.mlperf_coalesce_sparse_grads: +- coalesce_sparse_grads(dlrm) +- + # optimizer +- if (args.mlperf_logging and (j + 1) % args.mlperf_grad_accum_iter == 0) or not args.mlperf_logging: ++ if args.optimizer == 1 or args.optimizer == 3: ++ with torch.autograd.profiler.record_function("optimizer_dense:step"): ++ optimizer_dense.step() ++ with torch.autograd.profiler.record_function("optimizer_sparse:step"): ++ optimizer_sparse.step() ++ else: + optimizer.step() +- lr_scheduler.step() ++ lr_scheduler.step() ++ + if args.mlperf_logging: + total_time += iteration_time ++ total_data_time += data_time + else: + t2 = time_wrap(use_gpu) + total_time += t2 - t1 +@@ -1077,6 +1364,9 @@ if __name__ == "__main__": + gT = 1000.0 * total_time / total_iter if args.print_time else -1 + total_time = 0 + ++ gDT = 1000.0 * total_data_time / total_iter if args.print_time else -1 ++ total_data_time = 0 ++ + gA = total_accu / total_samp + total_accu = 0 + +@@ -1084,12 +1374,14 @@ if __name__ == "__main__": + total_loss = 0 + + str_run_type = "inference" if args.inference_only else "training" +- print( +- "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, ".format( +- str_run_type, j + 1, nbatches, k, gT ++ if ext_dist.my_size > 1 and ext_dist.dist.get_rank() == 0: ++ print( ++ "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, ".format( ++ str_run_type, j + 1, nbatches, k, gT ++ ) ++ + "loss {:.6f}, accuracy {:3.3f} %".format(gL, gA * 100) ++ + " data: {:.2f} ms/it".format(gDT) + ) +- + "loss {:.6f}, accuracy {:3.3f} %".format(gL, gA * 100) +- ) + # Uncomment the line below to print out the total time with overhead + # print("Accumulated time so far: {}" \ + # .format(time_wrap(use_gpu) - accum_time_begin)) +@@ -1125,16 +1417,19 @@ if __name__ == "__main__": + + # forward pass + 
Z_test = dlrm_wrap( +- X_test, lS_o_test, lS_i_test, use_gpu, device ++ X_test, lS_o_test, lS_i_test, use_gpu, use_ipex, device + ) + if args.mlperf_logging: ++ if ext_dist.my_size > 1: ++ Z_test = ext_dist.all_gather(Z_test, None) ++ T_test = ext_dist.all_gather(T_test, None) + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + scores.append(S_test) + targets.append(T_test) + else: + # loss +- E_test = loss_fn_wrap(Z_test, T_test, use_gpu, device) ++ E_test = loss_fn_wrap(Z_test, T_test, use_gpu, use_ipex, device) + + # compute loss and accuracy + L_test = E_test.detach().cpu().numpy() # numpy array +@@ -1152,51 +1447,54 @@ if __name__ == "__main__": + scores = np.concatenate(scores, axis=0) + targets = np.concatenate(targets, axis=0) + +- metrics = { +- 'loss' : sklearn.metrics.log_loss, +- 'recall' : lambda y_true, y_score: +- sklearn.metrics.recall_score( +- y_true=y_true, +- y_pred=np.round(y_score) +- ), +- 'precision' : lambda y_true, y_score: +- sklearn.metrics.precision_score( +- y_true=y_true, +- y_pred=np.round(y_score) +- ), +- 'f1' : lambda y_true, y_score: +- sklearn.metrics.f1_score( +- y_true=y_true, +- y_pred=np.round(y_score) +- ), +- 'ap' : sklearn.metrics.average_precision_score, +- 'roc_auc' : sklearn.metrics.roc_auc_score, +- 'accuracy' : lambda y_true, y_score: +- sklearn.metrics.accuracy_score( +- y_true=y_true, +- y_pred=np.round(y_score) +- ), +- # 'pre_curve' : sklearn.metrics.precision_recall_curve, +- # 'roc_curve' : sklearn.metrics.roc_curve, +- } +- +- # print("Compute time for validation metric : ", end="") +- # first_it = True + validation_results = {} +- for metric_name, metric_function in metrics.items(): +- # if first_it: +- # first_it = False +- # else: +- # print(", ", end="") +- # metric_compute_start = time_wrap(False) +- validation_results[metric_name] = metric_function( +- targets, +- scores +- ) +- # metric_compute_end = time_wrap(False) +- # met_time = metric_compute_end - metric_compute_start +- # print("{} {:.4f}".format(metric_name, 1000 * (met_time)), +- # end="") ++ if args.use_ipex: ++ validation_results['roc_auc'], validation_results['loss'], validation_results['accuracy'] = \ ++ core.roc_auc_score(torch.from_numpy(targets).reshape(-1), torch.from_numpy(scores).reshape(-1)) ++ else: ++ metrics = { ++ 'loss' : sklearn.metrics.log_loss, ++ 'recall' : lambda y_true, y_score: ++ sklearn.metrics.recall_score( ++ y_true=y_true, ++ y_pred=np.round(y_score) ++ ), ++ 'precision' : lambda y_true, y_score: ++ sklearn.metrics.precision_score( ++ y_true=y_true, ++ y_pred=np.round(y_score) ++ ), ++ 'f1' : lambda y_true, y_score: ++ sklearn.metrics.f1_score( ++ y_true=y_true, ++ y_pred=np.round(y_score) ++ ), ++ 'ap' : sklearn.metrics.average_precision_score, ++ 'roc_auc' : sklearn.metrics.roc_auc_score, ++ 'accuracy' : lambda y_true, y_score: ++ sklearn.metrics.accuracy_score( ++ y_true=y_true, ++ y_pred=np.round(y_score) ++ ), ++ } ++ ++ # print("Compute time for validation metric : ", end="") ++ # first_it = True ++ for metric_name, metric_function in metrics.items(): ++ # if first_it: ++ # first_it = False ++ # else: ++ # print(", ", end="") ++ # metric_compute_start = time_wrap(False) ++ validation_results[metric_name] = metric_function( ++ targets, ++ scores ++ ) ++ # metric_compute_end = time_wrap(False) ++ # met_time = metric_compute_end - metric_compute_start ++ # print("{} {:.4f}".format(metric_name, 1000 * (met_time)), ++ # end="") ++ + # print(" ms") + gA_test = 
validation_results['accuracy'] + gL_test = validation_results['loss'] +@@ -1236,26 +1534,21 @@ if __name__ == "__main__": + mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_ACCURACY, + value=float(validation_results['roc_auc']), + metadata={mlperf_logger.constants.EPOCH_NUM: epoch_num_float}) +- print( +- "Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k) +- + " loss {:.6f}, recall {:.4f}, precision {:.4f},".format( +- validation_results['loss'], +- validation_results['recall'], +- validation_results['precision'] +- ) +- + " f1 {:.4f}, ap {:.4f},".format( +- validation_results['f1'], +- validation_results['ap'], +- ) +- + " auc {:.4f}, best auc {:.4f},".format( +- validation_results['roc_auc'], +- best_auc_test ++ if ext_dist.my_size > 1 and ext_dist.dist.get_rank() == 0: ++ print( ++ "Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k) ++ + " loss {:.6f},".format( ++ validation_results['loss'] ++ ) ++ + " auc {:.6f}, best auc {:.6f},".format( ++ validation_results['roc_auc'], ++ best_auc_test ++ ) ++ + " accuracy {:3.3f} %, best accuracy {:3.3f} %".format( ++ validation_results['accuracy'] * 100, ++ best_gA_test * 100 ++ ) + ) +- + " accuracy {:3.3f} %, best accuracy {:3.3f} %".format( +- validation_results['accuracy'] * 100, +- best_gA_test * 100 +- ) +- ) + else: + print( + "Testing at - {}/{} of epoch {},".format(j + 1, nbatches, 0) +@@ -1290,6 +1583,8 @@ if __name__ == "__main__": + metadata={ + mlperf_logger.constants.STATUS: mlperf_logger.constants.SUCCESS}) + break ++ #ext_dist.barrier() ++ end_time = time_wrap(use_gpu) + + mlperf_logger.barrier() + mlperf_logger.log_end(key=mlperf_logger.constants.EPOCH_STOP, +@@ -1298,18 +1593,25 @@ if __name__ == "__main__": + mlperf_logger.log_end(key=mlperf_logger.constants.BLOCK_STOP, + metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: k + 1}) + k += 1 # nepochs ++ if args.enable_profiling: ++ print(prof.key_averages().table(sort_by="cpu_time_total")) ++ ++ print('DLRM training summary: best_auc_test = %.6f ' % best_auc_test) + + if args.mlperf_logging and best_auc_test <= args.mlperf_auc_threshold: + mlperf_logger.barrier() + mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP, + metadata={mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED}) +- ++ + # profiling + if args.enable_profiling: +- with open("dlrm_s_pytorch.prof", "w") as prof_f: ++ file_prefix = "%s/dlrm_s_pytorch_r%d" % (".", ext_dist.dist.get_rank()) ++ #with open("dlrm_s_pytorch.prof", "w") as prof_f: ++ with open("%s.prof" % file_prefix, "w") as prof_f: + prof_f.write(prof.key_averages().table(sort_by="cpu_time_total")) + prof.export_chrome_trace("./dlrm_s_pytorch.json") + # print(prof.key_averages().table(sort_by="cpu_time_total")) ++ + + # plot compute graph + if args.plot_compute_graph: +diff --git a/extend_distributed.py b/extend_distributed.py +new file mode 100644 +index 0000000..b88f413 +--- /dev/null ++++ b/extend_distributed.py +@@ -0,0 +1,434 @@ ++import os ++import builtins ++import numpy as np ++import torch ++from torch.autograd import Function ++from torch.nn.parallel import DistributedDataParallel as DDP ++import torch.distributed as dist ++try: ++ import torch_ccl ++except ImportError as e: ++ #print(e) ++ torch_ccl = False ++ ++my_rank = 0 ++my_size = 1 ++my_local_rank = -1 ++my_local_size = -1 ++alltoall_supported = False ++allgatherv_supported = False ++a2a_impl = os.environ.get('DLRM_ALLTOALL_IMPL', '') ++ ++myreq = None ++ ++def env2int(env_list, default = -1): ++ for e in env_list: ++ val = 
int(os.environ.get(e, -1)) ++ if val >= 0: return val ++ return default ++ ++def get_my_slice(n): ++ grp_size = dist.get_world_size() ++ grp_rank = dist.get_rank() ++ if dist.get_world_size() > n: ++ num_split_grps = dist.get_world_size() // n ++ grp_size = dist.get_world_size() // num_split_grps ++ grp_rank = dist.get_rank() % grp_size ++ k, m = divmod(n, grp_size) ++ return slice(grp_rank * k + min(grp_rank, m), (grp_rank+1) * k + min(grp_rank+1, m), 1) ++ ++def get_split_lengths(n, split=False): ++ grp_size = dist.get_world_size() ++ if split: ++ if dist.get_world_size() > n: ++ num_split_grps = dist.get_world_size() // n ++ grp_size = dist.get_world_size() // num_split_grps ++ k, m = divmod(n, grp_size) ++ if m == 0: ++ splits = None ++ my_len = k ++ else: ++ my_rank = dist.get_rank() ++ splits = [(k+1) if i < m else k for i in range(grp_size)] ++ my_len = splits[my_rank] ++ return (my_len, splits) ++ ++def init_distributed(rank = -1, size = -1, backend=''): ++ global myreq ++ #global my_rank ++ global my_size ++ global my_local_rank ++ global my_local_size ++ global a2a_impl ++ global alltoall_supported ++ global allgatherv_supported ++ ++ # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2) ++ num_mpi_ranks = env2int(['PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'WORLD_SIZE']) ++ if backend == '' and num_mpi_ranks > 1: ++ if torch_ccl and env2int(['CCL_WORKER_COUNT']) > 0: ++ backend = 'ccl' ++ elif dist.is_mpi_available(): ++ backend = 'mpi' ++ else: ++ print("WARNING: MPI multi-process launch detected but PyTorch MPI backend not available.") ++ backend = 'gloo' ++ if backend != '': ++ #guess Rank and size ++ if rank == -1: ++ rank = env2int(['PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'MV2_COMM_WORLD_RANK', 'RANK'], 0) ++ if size == -1: ++ size = env2int(['PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'WORLD_SIZE'], 1) ++ if not os.environ.get('RANK', None) and rank != -1: os.environ['RANK'] = str(rank) ++ if not os.environ.get('WORLD_SIZE', None) and size != -1: os.environ['WORLD_SIZE'] = str(size) ++ if not os.environ.get('MASTER_PORT', None): os.environ['MASTER_PORT'] = '29500' ++ if not os.environ.get('MASTER_ADDR', None): ++ local_size = env2int(['MPI_LOCALNRANKS', 'OMPI_COMM_WORLD_LOCAL_SIZE', 'MV2_COMM_WORLD_LOCAL_SIZE'], 1) ++ if local_size != size and backend != 'mpi': ++ print("Warning: Looks like distributed multinode run but MASTER_ADDR env not set, using '127.0.0.1' as default") ++ print("If this run hangs, try exporting rank 0's hostname as MASTER_ADDR") ++ os.environ['MASTER_ADDR'] = '127.0.0.1' ++ if size > 1: ++ dist.init_process_group(backend, rank=rank, world_size=size) ++ my_rank = dist.get_rank() ++ my_size = dist.get_world_size() ++ my_local_rank = env2int(['MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'MV2_COMM_WORLD_LOCAL_RANK'], 0) ++ my_local_size = env2int(['MPI_LOCALNRANKS', 'OMPI_COMM_WORLD_LOCAL_SIZE', 'MV2_COMM_WORLD_LOCAL_SIZE'], 1) ++ if my_rank == 0: print("Running on %d ranks using %s backend" % (my_size, backend)) ++ if hasattr(dist, 'all_to_all_single'): ++ try: ++ # dist.all_to_all_single(torch.empty([0]), torch.empty([0])) ++ alltoall_supported = True ++ except RuntimeError: ++ pass ++ if a2a_impl == 'alltoall' and alltoall_supported == False: ++ print("Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it, use scatter/gather based alltoall" % (a2a_impl, backend)) ++ a2a_impl = 'scatter' ++ if a2a_impl != '': print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl) ++ try: ++ x = torch.ones([my_rank]) ++ y = 
torch.zeros([(my_size*(my_size-1))//2]) ++ y = list(y.split([r for r in range(my_size)])) ++ dist.all_gather(y, x) ++ allgatherv_supported = True ++ except RuntimeError: ++ pass ++ else: ++ my_rank = 0 ++ my_size = 1 ++ my_local_rank = 0 ++ my_local_size = 1 ++ myreq = Request() ++ ++class Request(object): ++ def __init__(self): ++ self.req = None ++ self.tensor = None ++ self.WaitFunction = All2All_Scatter_Wait ++ ++ def wait(self): ++ ret = self.WaitFunction.apply(*self.tensor) ++ self.req = None ++ self.tensor = None ++ return ret ++ ++class All2All_ScatterList_Req(Function): ++ @staticmethod ++ def forward(ctx, a2ai, *inputs): ++ global myreq ++ my_rank = dist.get_rank() ++ #print("All2All_ScatterList_Req:forward") ++ mb_split_lengths = a2ai.gNS if a2ai.gNS else a2ai.lN ++ emb_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size ++ gather_list = [] ++ req_list = [] ++ for i in range(my_size): ++ for j in range(emb_split_lengths[i]): ++ out_tensor = inputs[0].new_empty([a2ai.lN, a2ai.E]) ++ scatter_list = list(inputs[j].split(mb_split_lengths, dim = 0)) if i == my_rank else [] ++ req = dist.scatter(out_tensor, scatter_list, src=i, async_op=True) ++ gather_list.append(out_tensor) ++ req_list.append(req) ++ myreq.req = req_list ++ myreq.tensor = tuple(gather_list) ++ myreq.a2ai = a2ai ++ return myreq.tensor ++ ++ @staticmethod ++ def backward(ctx, *grad_output): ++ global myreq ++ #print("All2All_ScatterList_Req:backward") ++ for r in myreq.req: ++ r.wait() ++ myreq.req = None ++ grad_inputs = myreq.tensor ++ myreq.tensor = None ++ return (None, *grad_inputs) ++ ++ ++class All2All_ScatterList_Wait(Function): ++ @staticmethod ++ def forward(ctx, *output): ++ global myreq ++ #print("All2All_Scatter_Wait:forward") ++ ctx.a2ai = myreq.a2ai ++ for r in myreq.req: ++ r.wait() ++ myreq.req = None ++ myreq.tensor = None ++ return output ++ ++ @staticmethod ++ def backward(ctx, *grad_output): ++ global myreq ++ my_rank = dist.get_rank() ++ a2ai = ctx.a2ai ++ grad_output = [t.contiguous() for t in grad_output] ++ mb_split_lengths = a2ai.gNS if a2ai.gNS else [a2ai.lN] * my_size ++ per_rank_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size ++ grad_inputs = [grad_output[0].new_empty([ctx.a2ai.N, ctx.a2ai.E]) for _ in range(a2ai.lS)] ++ req_list = [] ++ ind = 0 ++ for i in range(my_size): ++ for j in range(per_rank_split_lengths[i]): ++ gather_list = list(grad_inputs[j].split(mb_split_lengths, dim = 0)) if i == my_rank else None ++ req = dist.gather(grad_output[ind], gather_list, dst = i, async_op=True) ++ req_list.append(req) ++ ind += 1 ++ myreq.req = req_list ++ myreq.tensor = grad_inputs ++ return tuple(grad_output) ++ ++ ++ ++class All2All_Scatter_Req(Function): ++ @staticmethod ++ def forward(ctx, a2ai, *inputs): ++ global myreq ++ #print("All2All_Scatter_Req:forward") ++ my_rank = dist.get_rank() ++ mb_split_lengths = a2ai.gNS if a2ai.gNS else a2ai.lN ++ emb_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size ++ input = torch.cat(inputs, dim=1) ++ scatter_list = list(input.split(mb_split_lengths, dim=0)) ++ gather_list = [] ++ req_list = [] ++ for i in range(my_size): ++ out_tensor = input.new_empty([a2ai.lN, emb_split_lengths[i] * a2ai.E]) ++ req = dist.scatter(out_tensor, scatter_list if i == my_rank else [], src=i, async_op=True) ++ gather_list.append(out_tensor) ++ req_list.append(req) ++ myreq.req = req_list ++ myreq.tensor = tuple(gather_list) ++ myreq.a2ai = a2ai ++ ctx.a2ai = a2ai ++ return myreq.tensor ++ ++ @staticmethod ++ def backward(ctx, 
*grad_output): ++ global myreq ++ #print("All2All_Scatter_Req:backward") ++ for r in myreq.req: ++ r.wait() ++ myreq.req = None ++ grad_input = myreq.tensor ++ grad_inputs = grad_input.split(ctx.a2ai.E, dim=1) ++ myreq.tensor = None ++ return (None, *grad_inputs) ++ ++ ++class All2All_Scatter_Wait(Function): ++ @staticmethod ++ def forward(ctx, *output): ++ global myreq ++ #print("All2All_Scatter_Wait:forward") ++ ctx.a2ai = myreq.a2ai ++ for r in myreq.req: ++ r.wait() ++ myreq.req = None ++ myreq.tensor = None ++ return output ++ ++ @staticmethod ++ def backward(ctx, *grad_output): ++ global myreq ++ my_rank = dist.get_rank() ++ #print("All2All_Scatter_Wait:backward") ++ assert len(grad_output) == my_size ++ scatter_list = [t.contiguous() for t in grad_output] ++ a2ai = ctx.a2ai ++ mb_split_lengths = a2ai.gNS if a2ai.gNS else a2ai.lN ++ emb_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size ++ grad_input = grad_output[0].new_empty([a2ai.N, a2ai.E*a2ai.lS]) ++ gather_list = list(grad_input.split(mb_split_lengths, dim=0)) ++ req_list = [] ++ for i in range(my_size): ++ #req = dist.scatter(gather_list[i], scatter_list if i == my_rank else [], src=i, async_op=True) ++ req = dist.gather(scatter_list[i], gather_list if i == my_rank else [], dst=i, async_op=True) ++ req_list.append(req) ++ myreq.req = req_list ++ myreq.tensor = grad_input ++ return grad_output ++ ++ ++class All2All_Req(Function): ++ @staticmethod ++ def forward(ctx, a2ai, *inputs): ++ global myreq ++ #print("All2All_Req:forward") ++ mb_split_lengths = a2ai.gNS ++ if mb_split_lengths: mb_split_lengths = [m * a2ai.lS * a2ai.E for m in mb_split_lengths] ++ emb_split_lengths = a2ai.gSS ++ if emb_split_lengths: emb_split_lengths = [a2ai.lN * e * a2ai.E for e in emb_split_lengths] ++ input = torch.cat(inputs, dim=1).view([-1]) ++ output = input.new_empty([a2ai.S*a2ai.lN*a2ai.E]) ++ req = dist.all_to_all_single(output, input, emb_split_lengths, mb_split_lengths, async_op=True) ++ myreq.req = req ++ myreq.tensor = [] ++ myreq.tensor.append(output) ++ myreq.tensor = tuple(myreq.tensor) ++ a2ai.mb_split_lengths = mb_split_lengths ++ a2ai.emb_split_lengths = emb_split_lengths ++ myreq.a2ai = a2ai ++ ctx.a2ai = a2ai ++ return myreq.tensor ++ ++ @staticmethod ++ def backward(ctx, *grad_output): ++ global myreq ++ #print("All2All_Req:backward") ++ a2ai = ctx.a2ai ++ myreq.req.wait() ++ myreq.req = None ++ grad_input = myreq.tensor ++ grad_inputs = grad_input.view([a2ai.N, -1]).split(a2ai.E, dim=1) ++ grad_inputs = [gin.contiguous() for gin in grad_inputs] ++ myreq.tensor = None ++ return (None, *grad_inputs) ++ ++ ++class All2All_Wait(Function): ++ @staticmethod ++ def forward(ctx, *output): ++ global myreq ++ #print("All2All_Wait:forward") ++ a2ai = myreq.a2ai ++ ctx.a2ai = a2ai ++ myreq.req.wait() ++ myreq.req = None ++ myreq.tensor = None ++ emb_split_lengths = a2ai.emb_split_lengths if a2ai.emb_split_lengths else a2ai.lS * a2ai.lN * a2ai.E ++ #print("output[0].shape:", output[0].shape," a2ai.lN:" , a2ai.lN, " emb_split_lengths: ",emb_split_lengths) ++ outputs = output[0].split(emb_split_lengths) ++ outputs = tuple([out.view([a2ai.lN, -1]) for out in outputs]) ++ #print(outputs[0].shape) ++ return outputs ++ ++ @staticmethod ++ def backward(ctx, *grad_outputs): ++ global myreq ++ #print("All2All_Wait:backward") ++ a2ai = ctx.a2ai ++ grad_outputs = [gout for gout in grad_outputs] ++ grad_output = torch.cat(grad_outputs,dim=0).view([-1]).contiguous() ++ grad_input = grad_output.new_empty([a2ai.N * a2ai.lS * a2ai.E]) ++ req = 
dist.all_to_all_single(grad_input, grad_output, a2ai.mb_split_lengths, a2ai.emb_split_lengths, async_op=True) ++ myreq.req = req ++ myreq.tensor = grad_input ++ return (grad_output,) ++ ++class AllGather(Function): ++ ++ @staticmethod ++ def forward(ctx, input, global_lengths, dim=0): ++ if not isinstance(global_lengths, (list, tuple)): ++ global_lengths = [global_lengths] * my_size ++ my_rank = dist.get_rank() ++ assert(len(global_lengths) == my_size) ++ assert(global_lengths[my_rank] == input.size(dim)) ++ local_start = sum(global_lengths[:my_rank]) ++ ++ output_size = list(input.size()) ++ ++ ctx.dim = dim ++ ctx.local_start = local_start ++ ctx.local_length = global_lengths[my_rank] ++ ++ input = input.contiguous() ++ if dim == 0: ++ out_len = sum(global_lengths) ++ output_size[dim] = out_len ++ output = input.new_empty(output_size) ++ gather_list = list(output.split(global_lengths, dim=0)) ++ else: ++ gather_list = [torch.empty_like(input) for _ in range(my_size)] ++ gather_list = [] ++ for l in global_lengths: ++ output_size[dim] = l ++ gather_list.append(input.new_empty(output_size)) ++ ++ dist.all_gather(gather_list, input) ++ ++ if dim != 0: ++ output = torch.cat(gather_list, dim=dim) ++ ++ return output ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ # print("Inside All2AllBackward") ++ dim = ctx.dim ++ start = ctx.local_start ++ length = ctx.local_length ++ ++ grad_input = grad_output.narrow(dim, start, length) ++ ++ return (grad_input, None, None) ++ ++class All2AllInfo(object): ++ pass ++ ++def alltoall(inputs, per_rank_split_lengths): ++ global myreq ++ N, E = inputs[0].size() ++ a2ai = All2AllInfo() ++ a2ai.lS = len(inputs) ++ a2ai.gSS = per_rank_split_lengths ++ a2ai.lN, a2ai.gNS = get_split_lengths(N) ++ a2ai.E = E ++ a2ai.N = N ++ a2ai.S = sum(per_rank_split_lengths) if per_rank_split_lengths else a2ai.lS * my_size ++ if a2a_impl == '' and alltoall_supported or a2a_impl == 'alltoall': ++ output = All2All_Req.apply(a2ai, *inputs) ++ myreq.WaitFunction = All2All_Wait ++ elif a2a_impl == '' or a2a_impl == 'scatter': ++ #print("Using All2All_Scatter_Req") ++ output = All2All_Scatter_Req.apply(a2ai, *inputs) ++ myreq.WaitFunction = All2All_Scatter_Wait ++ elif a2a_impl == 'scatter_list': ++ #print("Using All2All_ScatterList_Req") ++ output = All2All_ScatterList_Req.apply(a2ai, *inputs) ++ myreq.WaitFunction = All2All_ScatterList_Wait ++ else: ++ print("Unknown value set for DLRM_ALLTOALL_IMPL (%s), please use one of [alltoall, scatter, scatter_list]" % a2a_impl) ++ return myreq ++ ++def shuffle_data(inputs): ++ input = torch.cat(inputs) ++ output = input.new_empty(input.size()) ++ req = dist.all_to_all_single(output, input) ++ output = output.reshape(my_size, -1) ++ return output ++ ++ ++def all_gather(input, lengths, dim=0): ++ #print("lengths: ", lengths) ++ if not lengths: lengths = [input.size(0)] * my_size ++ return AllGather.apply(input, lengths, dim) ++ ++def barrier(): ++ with torch.autograd.profiler.record_function("ext_barrier()"): ++ if my_size > 1: ++ dist.barrier() ++ ++ +diff --git a/get_hybridparallel_friendly_dataset.py b/get_hybridparallel_friendly_dataset.py +new file mode 100644 +index 0000000..b221fdf +--- /dev/null ++++ b/get_hybridparallel_friendly_dataset.py +@@ -0,0 +1,64 @@ ++import os ++import numpy as np ++import time ++import math ++from tqdm import tqdm ++import argparse ++ ++ ++def parse(data_file, counts_file, prefix='train', sparse_dense_boundary=2048): ++ tar_fea = 1 # single target ++ den_fea = 13 # 13 dense features ++ spa_fea = 26 
+diff --git a/get_hybridparallel_friendly_dataset.py b/get_hybridparallel_friendly_dataset.py
+new file mode 100644
+index 0000000..b221fdf
+--- /dev/null
++++ b/get_hybridparallel_friendly_dataset.py
+@@ -0,0 +1,64 @@
++import os
++import numpy as np
++import time
++import math
++from tqdm import tqdm
++import argparse
++
++
++def parse(data_file, counts_file, prefix='train', sparse_dense_boundary=2048):
++    tar_fea = 1 # single target
++    den_fea = 13 # 13 dense features
++    spa_fea = 26 # 26 sparse features
++    tad_fea = tar_fea + den_fea
++    tot_fea = tar_fea + den_fea + spa_fea
++    bytes_per_feature = 4
++    bytes_per_sample = bytes_per_feature * tot_fea
++    data_file_size = os.path.getsize(data_file)
++    num_samples = math.ceil(data_file_size / bytes_per_sample)
++
++    print('data file:', data_file, ' counts_file: ', counts_file)
++    dir_name = os.path.dirname(data_file)
++    data_prefix = data_file.split('/')[-1].split('.')[0]
++    counts = []
++    with np.load(counts_file) as data:
++        counts = data["counts"]
++
++    dense_index = []
++    sparse_index = []
++    index = 0
++    for count in counts:
++        if count >= sparse_dense_boundary:
++            sparse_index.append(index)
++        else:
++            dense_index.append(index)
++        index += 1
++    print(dense_index, " ", sparse_index)
++
++    file_str_list = data_file.split('.')
++    sparse_fd_map = dict()
++    for spa_index in sparse_index:
++        out_file_name = "{}/test/{}_sparse_embedding_index_{}.bin".format(dir_name, data_prefix, spa_index)
++        sparse_fd_map[spa_index] = open(out_file_name, 'ab+')
++    out_file = '{}/test/{}_data_parallel.bin'.format(dir_name, data_prefix)
++    out_file_fd = open(out_file, 'ab+')
++
++    with open(data_file, 'rb') as file:
++        for idx in tqdm(range(num_samples)):
++            raw_data = file.read(bytes_per_sample)
++            array = np.frombuffer(raw_data, dtype=np.int32)
++            dp_data = array[:tad_fea]
++            emb_index = array[tad_fea:tot_fea]
++            dp_data = np.append(dp_data, emb_index[dense_index])
++            out_file_fd.write(dp_data.tobytes())
++            for spa_index in sparse_index:
++                sparse_fd_map[spa_index].write(emb_index[spa_index].tobytes())
++
++    out_file_fd.close()
++    for spa_index in sparse_index:
++        sparse_fd_map[spa_index].close()
++
++
++parse(data_file='dlrm_dataset/dlrm/input/terabyte_processed_train.bin',counts_file='dlrm_dataset/dlrm/input/day_fea_count.npz')
++parse(data_file='dlrm_dataset/dlrm/input/terabyte_processed_val.bin',counts_file='dlrm_dataset/dlrm/input/day_fea_count.npz')
++parse(data_file='dlrm_dataset/dlrm/input/terabyte_processed_test.bin',counts_file='dlrm_dataset/dlrm/input/day_fea_count.npz')
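# Illustrative sketch, not part of the patch: reading one record back from the
# *_data_parallel.bin file that parse() above writes. The record layout is
# int32 [label | 13 dense values | categorical indices of the "dense" (small)
# tables], i.e. the tables whose row count is below sparse_dense_boundary; the
# large tables go to separate *_sparse_embedding_index_*.bin files with one
# int32 index per sample. The paths mirror the parse() calls above.
import numpy as np

counts = np.load('dlrm_dataset/dlrm/input/day_fea_count.npz')['counts']
dense_index = [i for i, c in enumerate(counts) if c < 2048]   # small tables kept data-parallel
record_words = 1 + 13 + len(dense_index)                      # label + dense + small-table indices
with open('dlrm_dataset/dlrm/input/test/terabyte_processed_train_data_parallel.bin', 'rb') as f:
    rec = np.frombuffer(f.read(4 * record_words), dtype=np.int32)
label, dense, small_cat = rec[0], rec[1:14], rec[14:]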
+diff --git a/lamb.py b/lamb.py
+new file mode 100644
+index 0000000..10e1d73
+--- /dev/null
++++ b/lamb.py
+@@ -0,0 +1,149 @@
++"""Lamb optimizer."""
++
++import collections
++import math
++
++import torch
++from tensorboardX import SummaryWriter
++from torch.optim import Optimizer
++
++
++def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
++    """Log a histogram of trust ratio scalars across layers."""
++    results = collections.defaultdict(list)
++    for group in optimizer.param_groups:
++        for p in group['params']:
++            state = optimizer.state[p]
++            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
++                if i in state:
++                    results[i].append(state[i])
++
++    for k, v in results.items():
++        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
++
++class Lamb(Optimizer):
++    r"""Implements Lamb algorithm.
++
++    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
++
++    Arguments:
++        params (iterable): iterable of parameters to optimize or dicts defining
++            parameter groups
++        lr (float, optional): learning rate (default: 1e-3)
++        betas (Tuple[float, float], optional): coefficients used for computing
++            running averages of gradient and its square (default: (0.9, 0.999))
++        eps (float, optional): term added to the denominator to improve
++            numerical stability (default: 1e-6)
++        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
++        adam (bool, optional): always use trust ratio = 1, which turns this into
++            Adam. Useful for comparison purposes.
++
++    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
++        https://arxiv.org/abs/1904.00962
++    """
++
++    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
++                 weight_decay=0, adam=False, bf16=False):
++        if not 0.0 <= lr:
++            raise ValueError("Invalid learning rate: {}".format(lr))
++        if not 0.0 <= eps:
++            raise ValueError("Invalid epsilon value: {}".format(eps))
++        if not 0.0 <= betas[0] < 1.0:
++            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
++        if not 0.0 <= betas[1] < 1.0:
++            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
++        defaults = dict(lr=lr, betas=betas, eps=eps,
++                        weight_decay=weight_decay)
++        self.adam = adam
++        self.bf16 = bf16
++        super(Lamb, self).__init__(params, defaults)
++
++    def set_bf16(self, bf16=False):
++        self.bf16 = bf16
++
++    def step(self, closure=None):
++        """Performs a single optimization step.
++
++        Arguments:
++            closure (callable, optional): A closure that reevaluates the model
++                and returns the loss.
++        """
++        loss = None
++        if closure is not None:
++            loss = closure()
++
++        for group in self.param_groups:
++            for p in group['params']:
++                if p.grad is None:
++                    continue
++                grad = p.grad.data
++                if grad.is_sparse:
++                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
++
++                state = self.state[p]
++
++                # State initialization
++                if len(state) == 0:
++                    state['step'] = 0
++                    # Exponential moving average of gradient values
++                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float32)
++                    # Exponential moving average of squared gradient values
++                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float32)
++                    if self.bf16:
++                        # additional fp32 version of master weights
++                        state['data_fp32'] = p.data.to(torch.float32)
++
++                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
++                beta1, beta2 = group['betas']
++                if self.bf16:
++                    grad_fp32 = grad.to(torch.float32)
++                    data_fp32 = state['data_fp32']
++
++                state['step'] += 1
++
++                # Decay the first and second moment running average coefficient
++                if self.bf16:
++                    # m_t
++                    exp_avg.mul_(beta1).add_(1 - beta1, grad_fp32)
++                    # v_t
++                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad_fp32, grad_fp32)
++                else:
++                    # m_t
++                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
++                    # v_t
++                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
++
++                # Paper v3 does not use debiasing.
++                # bias_correction1 = 1 - beta1 ** state['step']
++                # bias_correction2 = 1 - beta2 ** state['step']
++                # Apply bias to lr to avoid broadcast.
++                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
++
++                weight_norm = data_fp32.pow(2).sum().sqrt().clamp(0, 10) if self.bf16 \
++                    else p.data.pow(2).sum().sqrt().clamp(0, 10)
++
++                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
++                if group['weight_decay'] != 0:
++                    if self.bf16:
++                        adam_step.add_(group['weight_decay'], data_fp32)
++                    else:
++                        adam_step.add_(group['weight_decay'], p.data)
++
++                adam_norm = adam_step.pow(2).sum().sqrt()
++                if weight_norm == 0 or adam_norm == 0:
++                    trust_ratio = 1
++                else:
++                    trust_ratio = weight_norm / adam_norm
++                state['weight_norm'] = weight_norm
++                state['adam_norm'] = adam_norm
++                state['trust_ratio'] = trust_ratio
++                if self.adam:
++                    trust_ratio = 1
++
++                if self.bf16:
++                    data_fp32.add_(-step_size * trust_ratio, adam_step)
++                    p.data = data_fp32.to(torch.bfloat16)
++                else:
++                    p.data.add_(-step_size * trust_ratio, adam_step)
++
++        return loss
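# Illustrative sketch, not part of the patch: minimal standalone use of the Lamb
# optimizer added above (requires tensorboardX, which lamb.py imports at module
# level). With bf16=True the optimizer keeps fp32 master weights in
# state['data_fp32'] and writes bfloat16 copies back into p.data after each
# trust-ratio scaled update; bf16=False below keeps everything in fp32.
import torch
from lamb import Lamb

model = torch.nn.Linear(16, 1)
opt = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01, bf16=False)
x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()        # per-tensor update scaled by trust_ratio = ||w|| / ||adam_step||
opt.zero_grad()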
+diff --git a/mlperf_logger.py b/mlperf_logger.py
+index e07e658..ab6b9a3 100644
+--- a/mlperf_logger.py
++++ b/mlperf_logger.py
+@@ -61,9 +61,9 @@ def barrier():
+     Calls all_reduce on dummy tensor and synchronizes with GPU.
+     """
+     if torch.distributed.is_available() and torch.distributed.is_initialized():
+-        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
+-        torch.cuda.synchronize()
+-
++        #torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
++        #torch.cuda.synchronize()
++        torch.distributed.barrier()
+ 
+ def get_rank():
+     """