diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index c98caae31534..49b846633d6e 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -248,6 +248,11 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
 
 
+def _has_callable(obj, fn):
+    attr = getattr(obj, fn, None)
+    return callable(attr)
+
+
 def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     param_shapes = zero_model_states[0].param_shapes
 
@@ -287,7 +292,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
         avail_numel = full_single_fp32_vector.numel()
         for name, shape in shapes.items():
 
-            unpartitioned_numel = shape.numel()
+            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
             total_numel += unpartitioned_numel
             total_params += 1
 