type conversions for fp8
jomayeri committed Mar 6, 2024
1 parent aa06091 commit cf9fabf
Showing 1 changed file with 66 additions and 10 deletions.
76 changes: 66 additions & 10 deletions deepspeed/runtime/bf16_optimizer.py
@@ -23,10 +23,36 @@
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
PARAM_SLICE_MAPPINGS)
from transformer_engine.pytorch.float8_tensor import Float8Tensor

from transformer_engine.pytorch.float8_tensor import Float8Tensor, _FromFloat8Func, _ToFloat8Func
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine_extensions as tex

setattr(sys.modules[__name__], 'fragment_address', fragment_address)

def fp8_to_fp32(fp8_param, fp8_group_flat):
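    # Dequantize: reinterpret the flat uint8 storage as fp8 values and cast them
    # to fp32 using the scale_inv and fp8 dtype stored on the Float8Tensor param.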
scale_inv = fp8_param._scale_inv
fp8_dtype = fp8_param._fp8_dtype

data = fp8_group_flat.contiguous().view(1,-1).detach()
out = tex.cast_from_fp8(data, scale_inv, fp8_dtype, TE_DType[torch.float32])
out = out.view(fp8_group_flat.size())
return out

def fp32_to_fp8(fp8_param, fp8_partition_size, fp32_partition):
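    # Re-quantize: cast the updated fp32 partition back to fp8 (uint8 storage)
    # with the parameter's scaling factors; amax is only an output buffer for the
    # cast kernel here, so it is left uninitialized.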
scale_inv = fp8_param._scale_inv
scale = scale_inv.reciprocal()
fp8_dtype = fp8_param._fp8_dtype
amax = torch.empty_like(scale)

out = tex.cast_to_fp8(fp32_partition.view(1,-1),
scale,
amax,
scale_inv,
fp8_dtype,
).view(fp8_partition_size)
return out


class BF16_Optimizer(ZeROOptimizer):

@@ -42,11 +68,32 @@ def __init__(self,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
self.optimizer = init_optimizer
self.param_names = param_names

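        # Split each param group into bf16 params and Float8Tensor params so the
        # fp8 params end up in their own optimizer param group.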
new_groups = []
for param_group in self.optimizer.param_groups:
float8_params = []
bf16_params = []
for param in param_group['params']:
if isinstance(param, Float8Tensor):
float8_params.append(param)
else:
bf16_params.append(param)
# make copy of param group for float8
if float8_params:
f8_group = param_group.copy()
f8_group['params'] = float8_params
new_groups.append(f8_group)
param_group['params'] = bf16_params

for group in new_groups:
self.optimizer.add_param_group(group)

super().__init__()

self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

assert grad_acc_dtype in [torch.float32, torch.bfloat16
@@ -88,7 +135,6 @@ def __init__(self,
self.graph_harvesting = graph_harvesting
if self.using_real_optimizer:
self._setup_for_real_optimizer()

see_memory_usage('end bf16_optimizer', force=True)

def _setup_for_real_optimizer(self):
@@ -100,7 +146,6 @@ def _setup_for_real_optimizer(self):

partition_id = dist.get_rank(group=self.real_dp_process_group[i])

# grab the original list
trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
self.bf16_groups.append(trainable_parameters)

@@ -122,7 +167,12 @@ def _setup_for_real_optimizer(self):
self.bf16_partitioned_groups.append(bf16_dp_partitions)

# create fp32 params partition
self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
# TODO: uint8 --> fp32 --> uint8 is the issue
if bf16_dp_partitions[partition_id].dtype == torch.uint8:
self.fp32_groups_flat_partition.append(fp8_to_fp32(self.bf16_groups[i][0],
bf16_dp_partitions[partition_id]))
else:
self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
self.fp32_groups_flat_partition[i].requires_grad = True

num_elem_list = [t.numel() for t in self.bf16_groups[i]]
@@ -254,13 +304,13 @@ def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
p.data = q.data

def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
new_tensor_list = []
p_list = []
for t in tensor_list:
if isinstance(t, Float8Tensor):
new_tensor_list.append(t._data)
p_list.append(t._data)
else:
new_tensor_list.append(t)
return self.flatten(align_dense_tensors(new_tensor_list, alignment))
p_list.append(t)
return self.flatten(align_dense_tensors(p_list, alignment))

@torch.no_grad()
def step(self, closure=None):
@@ -370,10 +420,16 @@ def get_grads_for_norm(self, for_clipping=False):

@torch.no_grad()
def update_lp_params(self):
for i, (bf16_partitions,
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
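            # fp8 params are stored as uint8, so re-quantize the fp32 master copy
            # back into the fp8 partition instead of doing a plain dtype copy.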
if bf16_partitions[partition_id].dtype == torch.uint8:
bf16_partitions[partition_id].data.copy_(fp32_to_fp8(self.bf16_groups[i][0],
bf16_partitions[partition_id].size(),
fp32_partition.data))
else:
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True)

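Taken together, the new helpers give the optimizer a uint8 --> fp32 --> uint8 round trip for fp8 parameters: partitions are dequantized into fp32 master weights during setup and re-quantized after each optimizer step. The sketch below illustrates that flow using the two helpers added in this commit; it assumes Transformer Engine is installed, `param` is a Float8Tensor whose uint8 `_data` backs the flat partition, `flat_uint8_partition` is that flat partition, and `optimizer_step_on` is a hypothetical stand-in for the real fp32 optimizer update.

# Minimal sketch of the uint8 -> fp32 -> uint8 round trip, using the
# fp8_to_fp32 / fp32_to_fp8 helpers introduced in this commit.
def roundtrip_fp8_partition(param, flat_uint8_partition, optimizer_step_on):
    # 1. Dequantize the fp8 (uint8) partition into a trainable fp32 master copy,
    #    as _setup_for_real_optimizer() does for uint8 partitions.
    master_fp32 = fp8_to_fp32(param, flat_uint8_partition)
    master_fp32.requires_grad = True

    # 2. Run the fp32 optimizer update on the master copy
    #    (optimizer_step_on is a placeholder for that step).
    optimizer_step_on(master_fp32)

    # 3. Re-quantize the updated fp32 values back into the fp8 storage,
    #    mirroring what update_lp_params() does for uint8 partitions.
    flat_uint8_partition.copy_(
        fp32_to_fp8(param, flat_uint8_partition.size(), master_fp32.data))
    return flat_uint8_partition

In the optimizer itself, step 1 corresponds to the uint8 branch in _setup_for_real_optimizer() and step 3 to the uint8 branch in update_lp_params().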