Fix KeyedOptimizer init state test (#1874)
Summary:
Pull Request resolved: #1874

The `KeyedOptimizer.test_init_state` test started failing after the upstream PyTorch changes pytorch/pytorch#122349 (and the later follow-up pytorch/pytorch#123757), which make SGD initialize state for param groups only when momentum is enabled.

Updating the unit test to enable momentum fixes it. Also adding a new unit test that checks the state when momentum is disabled.
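For context, a minimal sketch (not part of this commit) of the upstream behavior described above, assuming a PyTorch build that includes pytorch/pytorch#122349: plain `torch.optim.SGD` only records per-parameter state such as `momentum_buffer` when momentum is non-zero.

import torch

param = torch.nn.Parameter(torch.ones(2, 3))
param.grad = torch.zeros_like(param)

# With momentum enabled, a step records a momentum_buffer for the param.
opt_with_momentum = torch.optim.SGD([param], lr=0.1, momentum=0.1)
opt_with_momentum.step()
assert "momentum_buffer" in opt_with_momentum.state[param]

# Without momentum, recent PyTorch leaves the optimizer state empty after a
# step, which is why the original momentum_buffer assertion started failing.
opt_no_momentum = torch.optim.SGD([param], lr=0.1)
opt_no_momentum.step()
assert not opt_no_momentum.state_dict()["state"]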

Reviewed By: henrylhtsang

Differential Revision: D56076424

fbshipit-source-id: e3d5ae063a5187d2d8702ad7f0bb4b2791b954fe
sarckk authored and facebook-github-bot committed Apr 12, 2024
1 parent 568e116 commit 7584fbd
Showing 1 changed file with 20 additions and 2 deletions.
--- a/torchrec/optim/tests/test_keyed.py
+++ b/torchrec/optim/tests/test_keyed.py
@@ -191,12 +191,12 @@ def test_non_param_state_key(self) -> None:
             [{"params": [param_1], "param_group_val_0": 3.0}],
         )
 
-    def test_init_state(self) -> None:
+    def test_init_state_with_momentum(self) -> None:
         dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
         sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))
         opt = KeyedOptimizerWrapper(
             {"dense": dense, "sparse": sparse},
-            lambda params: torch.optim.SGD(params, lr=0.1),
+            lambda params: torch.optim.SGD(params, lr=0.1, momentum=0.1),
         )
         opt.init_state({"sparse"})
 
@@ -208,6 +208,24 @@ def test_init_state(self) -> None:
         self.assertTrue(sparse.grad.is_sparse)
         self.assertTrue("momentum_buffer" in opt.state_dict()["state"]["sparse"])
 
+    def test_init_state_no_momentum(self) -> None:
+        dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
+        sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))
+        opt = KeyedOptimizerWrapper(
+            {"dense": dense, "sparse": sparse},
+            lambda params: torch.optim.SGD(params, lr=0.1),
+        )
+        opt.init_state({"sparse"})
+
+        self.assertTrue(dense.grad is not None)
+        self.assertFalse(dense.grad.is_sparse)
+
+        self.assertTrue(sparse.grad is not None)
+        self.assertTrue(sparse.grad.is_sparse)
+
+        self.assertTrue("state" in opt.state_dict())
+        self.assertFalse(opt.state_dict()["state"])
+
     def test_pickle(self) -> None:
         dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
         sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))
