Fix broken two tower test (#1425)
Summary:
Pull Request resolved: #1425

Fixing test_two_tower_retrieval: thread a dtype parameter through MLP and TwoTowerRetrieval so the projection layers can run in fp16 alongside the quantized embeddings, cast query embeddings back to float32 before the FAISS search, and make world_size configurable in infer().

Reviewed By: zainhuda

Differential Revision: D49957751

fbshipit-source-id: 1f54a16d24602ee0570d7bcbb0e2fa78b9a1e519
henrylhtsang authored and facebook-github-bot committed Oct 6, 2023
1 parent b496641 commit 8beba7d
Showing 4 changed files with 30 additions and 7 deletions.
12 changes: 10 additions & 2 deletions examples/retrieval/modules/two_tower.py
@@ -174,6 +174,7 @@ def __init__(
         layer_sizes: List[int],
         k: int,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         self.embedding_dim: int = query_ebc.embedding_bag_configs()[0].embedding_dim
@@ -186,10 +187,16 @@ def __init__(
         self.query_ebc = query_ebc
         self.candidate_ebc = candidate_ebc
         self.query_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.candidate_proj = MLP(
-            in_size=self.embedding_dim, layer_sizes=layer_sizes, device=device
+            in_size=self.embedding_dim,
+            layer_sizes=layer_sizes,
+            device=device,
+            dtype=dtype,
         )
         self.faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ] = faiss_index
         self.k = k
@@ -212,6 +219,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
         candidates = torch.empty(
             (batch_size, self.k), device=self.device, dtype=torch.int64
         )
+        query_embedding = query_embedding.to(torch.float32)  # required by faiss
         self.faiss_index.search(query_embedding, self.k, distances, candidates)

         # candidate lookup
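
The cast added to forward reflects a FAISS constraint: FAISS indexes store and search float32 vectors only, so a half-precision query embedding has to be converted before the search call. A minimal standalone sketch of that constraint (assuming faiss and numpy are installed; it uses the plain numpy search API rather than the torch-tensor path in the diff):

import faiss
import numpy as np

# FAISS operates on float32; build a toy 16-dimensional index.
index = faiss.IndexFlatL2(16)
index.add(np.random.rand(100, 16).astype(np.float32))

# A half-precision query (e.g. the output of an fp16 tower) must be cast
# back to float32 before searching, as the diff above does.
query = np.random.rand(1, 16).astype(np.float16)
distances, candidates = index.search(query.astype(np.float32), 5)
print(distances.shape, candidates.shape)  # (1, 5) (1, 5)
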
1 change: 1 addition & 0 deletions examples/retrieval/tests/test_two_tower_retrieval.py
@@ -26,4 +26,5 @@ def test_infer_function(self) -> None:
         infer(
             embedding_dim=16,
             layer_sizes=[16],
+            world_size=2,
         )
15 changes: 11 additions & 4 deletions examples/retrieval/two_tower_retrieval.py
@@ -18,7 +18,7 @@
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner.types import ParameterConstraints
 from torchrec.distributed.types import ShardingEnv, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -78,6 +78,7 @@ def infer(
     faiss_device_idx: int = 0,
     batch_size: int = 32,
     load_dir: Optional[str] = None,
+    world_size: int = 2,
 ) -> None:
     """
     Loads the serialized model and FAISS index from `two_tower_train.py`.
@@ -116,6 +117,7 @@ def infer(
             embedding_dim=embedding_dim,
             num_embeddings=num_embeddings,
             feature_names=[feature_name],
+            data_type=DataType.FP16,
         )
         ebcs.append(
             EmbeddingBagCollection(
@@ -156,7 +158,9 @@ def infer(
     index.train(embeddings)
     index.add(embeddings)

-    retrieval_model = TwoTowerRetrieval(index, ebcs[0], ebcs[1], layer_sizes, k, device)
+    retrieval_model = TwoTowerRetrieval(
+        index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+    )

     constraints = {}
     for feature_name in two_tower_column_names:
@@ -166,13 +170,16 @@ def infer(
         )

     quant_model = trec_infer.modules.quantize_embeddings(
-        retrieval_model, dtype=torch.qint8, inplace=True
+        retrieval_model,
+        dtype=torch.qint8,
+        inplace=True,
+        output_dtype=torch.float16,
     )

     dmp = DistributedModelParallel(
         module=quant_model,
         device=device,
-        env=ShardingEnv.from_local(world_size=2, rank=model_device_idx),
+        env=ShardingEnv.from_local(world_size=world_size, rank=model_device_idx),
         init_data_parallel=False,
     )
     if retrieval_sd is not None:
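
Read together, the changes in this file hint at the failure the commit fixes: quantize_embeddings with output_dtype=torch.float16 makes embedding lookups return fp16 tensors, which float32 projection layers cannot consume, hence the new dtype plumbing. A toy reproduction of that dtype-mismatch failure mode, independent of torchrec (layer sizes are illustrative):

import torch

linear_fp32 = torch.nn.Linear(16, 16)        # float32 weights by default
x = torch.randn(4, 16, dtype=torch.float16)  # stand-in for an fp16 embedding lookup
try:
    linear_fp32(x)
except RuntimeError as err:
    print(err)  # dtype mismatch between input and weight

# Building the layer in fp16, as the new dtype parameter allows, avoids it.
linear_fp16 = torch.nn.Linear(16, 16, dtype=torch.float16)
print(linear_fp16(x).dtype)  # torch.float16
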
9 changes: 8 additions & 1 deletion torchrec/modules/mlp.py
@@ -50,13 +50,18 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
         self._out_size = out_size
         self._in_size = in_size
         self._linear: nn.Linear = nn.Linear(
-            self._in_size, self._out_size, bias=bias, device=device
+            self._in_size,
+            self._out_size,
+            bias=bias,
+            device=device,
+            dtype=dtype,
         )
         self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation
@@ -120,6 +125,7 @@ def __init__(
             Callable[[torch.Tensor], torch.Tensor],
         ] = torch.relu,
         device: Optional[torch.device] = None,
+        dtype: torch.dtype = torch.float32,
     ) -> None:
         super().__init__()
@@ -137,6 +143,7 @@ def __init__(
                 bias=bias,
                 activation=extract_module_or_tensor_callable(activation),
                 device=device,
+                dtype=dtype,
             )
             for i in range(len(layer_sizes))
         ]
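
A short usage sketch of the new parameter (sizes are illustrative): constructing torchrec's MLP in half precision so each underlying Perceptron layer matches fp16 inputs end to end.

import torch
from torchrec.modules.mlp import MLP

mlp = MLP(in_size=16, layer_sizes=[16, 8], dtype=torch.float16)
out = mlp(torch.randn(4, 16, dtype=torch.float16))
print(out.dtype)  # torch.float16
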
