From 8beba7dc16edaf66ac5baa94b9d8e2bafc61d8f9 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Fri, 6 Oct 2023 09:51:52 -0700
Subject: [PATCH] Fix broken two tower test (#1425)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1425

Fixing test_two_tower_retrieval.

Reviewed By: zainhuda

Differential Revision: D49957751

fbshipit-source-id: 1f54a16d24602ee0570d7bcbb0e2fa78b9a1e519
---
 examples/retrieval/modules/two_tower.py         | 12 ++++++++++--
 .../retrieval/tests/test_two_tower_retrieval.py |  1 +
 examples/retrieval/two_tower_retrieval.py       | 15 +++++++++++----
 torchrec/modules/mlp.py                         |  9 ++++++++-
 4 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/examples/retrieval/modules/two_tower.py b/examples/retrieval/modules/two_tower.py
index cb1ac1954..224bcbfc7 100644
--- a/examples/retrieval/modules/two_tower.py
+++ b/examples/retrieval/modules/two_tower.py
@@ -174,6 +174,7 @@ def __init__(
         layer_sizes: List[int],
         k: int,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
@@ -186,10 +187,16 @@ def __init__(
         self.query_ebc = query_ebc
         self.candidate_ebc = candidate_ebc
         self.query_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.candidate_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
         self.k = k
@@ -212,6 +219,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
         candidates = torch.empty(
             (batch_size, self.k), device=self.device, dtype=torch.int64
         )
+        query_embedding = query_embedding.to(torch.float32)  # required by faiss
         self.faiss_index.search(query_embedding, self.k, distances, candidates)
 
         # candidate lookup
diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py
index eef1ef455..9b101b269 100644
--- a/examples/retrieval/tests/test_two_tower_retrieval.py
+++ b/examples/retrieval/tests/test_two_tower_retrieval.py
@@ -26,4 +26,5 @@ def test_infer_function(self) -> None:
         infer(
             embedding_dim=16,
             layer_sizes=[16],
+            world_size=2,
         )
diff --git a/examples/retrieval/two_tower_retrieval.py b/examples/retrieval/two_tower_retrieval.py
index b1b4ccb49..7c11e75ba 100644
--- a/examples/retrieval/two_tower_retrieval.py
+++ b/examples/retrieval/two_tower_retrieval.py
@@ -18,7 +18,7 @@
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner.types import ParameterConstraints
 from torchrec.distributed.types import ShardingEnv, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
@@ -78,6 +78,7 @@ def infer(
     faiss_device_idx: int = 0,
     batch_size: int = 32,
     load_dir: Optional[str] = None,
+    world_size: int = 2,
 ) -> None:
     """
     Loads the serialized model and FAISS index from `two_tower_train.py`.
@@ -116,6 +117,7 @@ def infer(
             embedding_dim=embedding_dim,
             num_embeddings=num_embeddings,
             feature_names=[feature_name],
+            data_type=DataType.FP16,
         )
         ebcs.append(
             EmbeddingBagCollection(
@@ -156,7 +158,9 @@ def infer(
         index.train(embeddings)
         index.add(embeddings)
 
-    retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
+    retrieval_model = TwoTowerRetrieval(
+        index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+    )
 
     constraints = {}
     for feature_name in two_tower_column_names:
@@ -166,13 +170,16 @@ def infer(
         )
 
     quant_model = trec_infer.modules.quantize_embeddings(
-        retrieval_model, dtype=torch.qint8, inplace=True
+        retrieval_model,
+        dtype=torch.qint8,
+        inplace=True,
+        output_dtype=torch.float16,
     )
 
     dmp = DistributedModelParallel(
         module=quant_model,
         device=device,
-        env=ShardingEnv.from_local(world_size=2, rank=model_device_idx),
+        env=ShardingEnv.from_local(world_size=world_size, rank=model_device_idx),
         init_data_parallel=False,
     )
     if retrieval_sd is not None:
diff --git a/torchrec/modules/mlp.py b/torchrec/modules/mlp.py
index c369b24c3..41685b341 100644
--- a/torchrec/modules/mlp.py
+++ b/torchrec/modules/mlp.py
@@ -50,13 +50,18 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
         self._out_size = out_size
         self._in_size = in_size
         self._linear: nn.Linear = nn.Linear(
-            self._in_size, self._out_size, bias=bias, device=device
+            self._in_size,
+            self._out_size,
+            bias=bias,
+            device=device,
+            dtype=dtype,
         )
         self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation
 
@@ -120,6 +125,7 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
 
@@ -137,6 +143,7 @@ def __init__(
                 bias=bias,
                 activation=extract_module_or_tensor_callable(activation),
                 device=device,
+                dtype=dtype,
             )
             for i in range(len(layer_sizes))
         ]
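
Note (not part of the patch): a minimal sketch of the float32 requirement that
motivates the cast added in TwoTowerRetrieval.forward(). IndexFlatL2 and the
random tensors below are stand-ins for the trained IVFPQ index and the
EBC/MLP outputs; the sketch uses the simple two-argument numpy search API
rather than the in-place four-argument form used in the patch. Assumes torch
and faiss-cpu are installed.

    import faiss
    import torch

    embedding_dim, k = 16, 5

    # Stand-in for the fp16 query-tower output produced after this patch.
    query_embedding = torch.rand(4, embedding_dim, dtype=torch.float16)

    # Stand-in for the trained index; torch.rand defaults to float32.
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(torch.rand(100, embedding_dim).numpy())

    # faiss only accepts float32 input, hence the explicit cast before search.
    distances, candidates = index.search(
        query_embedding.to(torch.float32).numpy(), k
    )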