diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index d7a35b7dbbe9..d1ebe4b2f83d 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -407,7 +407,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): max_norm = torch.tensor([float(max_norm)], device=parameters[0].device) clip_coef = max_norm / (total_norm + 1e-6) tmp_tensor = torch.tensor([1.0], device=parameters[0].device) - clip_coef = torch.max(tmp_tensor, clip_coef) + clip_coef = torch.min(tmp_tensor, clip_coef) for p in parameters: p.grad.data.mul_(clip_coef) return total_norm diff --git a/tests/unit/runtime/test_runtime_utils.py b/tests/unit/runtime/test_runtime_utils.py index 5d8478b249be..6fdeb2074246 100644 --- a/tests/unit/runtime/test_runtime_utils.py +++ b/tests/unit/runtime/test_runtime_utils.py @@ -26,10 +26,10 @@ def test_call_to_str(): assert c2s('hello', 1138, val=3) == 'hello(1138, val=3)' -class TestClibGradNorm(DistributedTest): +class TestClipGradNorm(DistributedTest): world_size = 2 - def test(self): + def test_gather(self): param1 = torch.nn.Parameter(torch.Tensor([0])) param1.grad = torch.Tensor([1]) param2 = torch.nn.Parameter(torch.Tensor([0])) @@ -50,6 +50,27 @@ def test(self): assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1" + def test_clipped_val(self): + max_norm = 0.1 + + def test_params(): + param1 = torch.nn.Parameter(torch.Tensor([0])) + param1.grad = torch.Tensor([1]) + param2 = torch.nn.Parameter(torch.Tensor([0])) + param2.grad = torch.Tensor([1]) + return [param1, param2] + + # This assumes gradients are same on all the ranks and doesn't consider multiple ranks + params_expected = test_params() + torch.nn.utils.clip_grad_norm_(params_expected, max_norm) + + params_actual = test_params() + ds_utils.clip_grad_norm_(params_actual, max_norm=max_norm) + + # This can be allclose + assert torch.equal(params_expected[0].grad, params_actual[0].grad) + assert torch.equal(params_expected[1].grad, params_actual[1].grad) + @pytest.mark.parametrize("check_using_norm", [(False), (True)]) class TestCheckOverflow(DistributedTest):