diff --git a/torchrec/distributed/tests/test_model_parallel_hierarchical.py b/torchrec/distributed/tests/test_model_parallel_hierarchical.py index 5209d479c..d820f85e4 100644 --- a/torchrec/distributed/tests/test_model_parallel_hierarchical.py +++ b/torchrec/distributed/tests/test_model_parallel_hierarchical.py @@ -73,7 +73,7 @@ class ModelParallelHierarchicalTest(ModelParallelTestShared): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -162,7 +162,7 @@ def test_sharding_nccl_twrw( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -286,7 +286,7 @@ def test_sharding_empty_rank( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -355,7 +355,7 @@ def test_embedding_tower_nccl( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index f1f773aa1..aec092354 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -57,7 +57,7 @@ class SequenceModelParallelTest(MultiProcessTestBase): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -150,7 +150,7 @@ def test_sharding_nccl_dp( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] @@ -203,7 +203,7 @@ def test_sharding_nccl_tw( [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ] diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py b/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py index 325c51be4..7ea296a91 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel_hierarchical.py @@ -62,7 +62,7 @@ class SequenceModelParallelHierarchicalTest(MultiProcessTestBase): [ None, { - "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), "embeddings": (torch.optim.SGD, {"lr": 0.2}), }, ]