From 4921858aff7572a57c0ec1239b42719632df8fb4 Mon Sep 17 00:00:00 2001
From: Lucas
Date: Wed, 6 Mar 2024 05:09:08 +0000
Subject: [PATCH] Switch fp32_to_fp8 to tex.cast_to_fp8_noalloc

---
 deepspeed/runtime/bf16_optimizer.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index fa6c41bb1607..d9ffad9c18fb 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -39,19 +39,21 @@ def fp8_to_fp32(fp8_param, fp8_group_flat):
     out = out.view(fp8_group_flat.size())
     return out
 
-def fp32_to_fp8(fp8_param, fp8_partition_size, fp32_partition):
+def fp32_to_fp8(fp8_param, fp8_partition_size, fp32_partition, out):
     scale_inv = fp8_param._scale_inv
     scale = scale_inv.reciprocal()
     fp8_dtype = fp8_param._fp8_dtype
-    amax = torch.empty_like(scale)
-
-    out = tex.cast_to_fp8(fp32_partition.view(1,-1),
-                          scale,
-                          amax,
-                          scale_inv,
-                          fp8_dtype,
-                          ).view(fp8_partition_size)
-    return out
+    amax = torch.ones_like(scale)
+
+    tex.cast_to_fp8_noalloc(
+        fp32_partition.view(1, -1),
+        scale,
+        out.view(1, -1),
+        amax,
+        scale_inv,
+        fp8_dtype,
+    )
+    return None
 
 
 class BF16_Optimizer(ZeROOptimizer):
@@ -425,9 +427,11 @@ def update_lp_params(self):
                 fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if bf16_partitions[partition_id].dtype == torch.uint8:
-                bf16_partitions[partition_id].data.copy_(fp32_to_fp8(self.bf16_groups[i][0],
-                                                                     bf16_partitions[partition_id].size(),
-                                                                     fp32_partition.data))
+                fp32_to_fp8(self.bf16_groups[i][0],
+                            bf16_partitions[partition_id].size(),
+                            fp32_partition.data,
+                            out=bf16_partitions[partition_id].data)
+
             else:
                 bf16_partitions[partition_id].data.copy_(fp32_partition.data)
             # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
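
Note for reviewers: the point of the change is that tex.cast_to_fp8 allocated a
fresh FP8 tensor on every call, which update_lp_params then copied into the
persistent partition buffer; tex.cast_to_fp8_noalloc instead writes through a
caller-provided `out`, so the quantized values land directly in
bf16_partitions[partition_id].data with no temporary and no second copy. A
minimal pure-PyTorch sketch of the pattern follows; cast_alloc/cast_noalloc are
hypothetical stand-ins for the TransformerEngine kernels, not real APIs, and
uint8 stands in for the FP8 storage dtype.

    import torch

    def cast_alloc(src: torch.Tensor) -> torch.Tensor:
        # Old pattern: the kernel allocates and returns a fresh tensor,
        # which the caller must then copy into its persistent buffer.
        return src.to(torch.uint8)

    def cast_noalloc(src: torch.Tensor, out: torch.Tensor) -> None:
        # New pattern: the kernel writes into a caller-provided buffer;
        # copy_ converts dtype element-wise, so no temporary is created.
        out.copy_(src)

    partition = torch.empty(4, dtype=torch.uint8)  # persistent lp partition
    fp32 = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # Before: cast into a temporary, then copy it into the partition
    # (one transient allocation plus two full passes over the data).
    partition.data.copy_(cast_alloc(fp32))

    # After: cast straight into the partition (no temporary, one pass).
    cast_noalloc(fp32, out=partition.data)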