diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 535a1d6..83c593f 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -23,9 +23,9 @@ class AllScoresPipeline(torch.nn.Module): """ Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities - in the KG, and related prediction metrics. - It supports filtering out the scores of specific completions that appear in a given - set of triples. + in the KG (or a given subset of entities), and related prediction metrics. + It supports filtering out, for each query, the scores of specific completions that + appear in a given set of triples. To be used in combination with a batch sampler based on a "h_shard"/"t_shard"-partitioned triple set. @@ -38,6 +38,7 @@ def __init__( score_fn: BaseScoreFunction, evaluation: Optional[Evaluation] = None, filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None, + candidate_ents: Optional[Union[torch.Tensor, NDArray[np.int32]]] = None, return_scores: bool = False, return_topk: bool = False, k: int = 10, @@ -62,6 +63,12 @@ def __init__( The set of all triples whose scores need to be filtered. The triples passed here must have GLOBAL IDs for head/tail entities. Default: None. + :param candidate_ents: + If specified, score queries only against a given set of entities. + This array needs to contain the global IDs of the + candidate entities to be used for completion. All other entities + will then be ignored when scoring queries. + Default: None (i.e. score queries against all entities). :param return_scores: If True, store and return scores of all queries' completions (with filters applied, if specified). 
@@ -165,6 +172,13 @@ def __init__( ], dim=0, ) + self.candidate_mask: Optional[torch.Tensor] = None + if candidate_ents is not None: + self.candidate_mask = torch.from_numpy( + np.setdiff1d( + np.arange(self.bess_module.sharding.n_entity), candidate_ents + ) + ) def forward(self) -> Dict[str, Any]: """ @@ -231,6 +245,10 @@ def forward(self) -> Dict[str, Any]: batch_scores_filt = batch_scores[triple_mask.flatten()][ :, np.unique(np.concatenate(batch_idx), return_index=True)[1] ][:, : self.bess_module.sharding.n_entity] + if self.candidate_mask is not None: + # Filter scores for entities that are not in + # the given set of candidates + batch_scores_filt[:, self.candidate_mask] = -torch.inf if ground_truth is not None: # Scores of positive triples true_scores = batch_scores_filt[ diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cd54962..b230585 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -36,13 +36,19 @@ test_triples_r = np.random.randint(n_relation_type, size=n_test_triple) triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)} +compl_candidates = np.arange(0, n_entity - 1, step=5) + @pytest.mark.parametrize("corruption_scheme", ["h", "t"]) @pytest.mark.parametrize( "filter_scores, extra_only", [(True, True), (True, False), (False, False)] ) +@pytest.mark.parametrize("filter_candidates", [True, False]) def test_all_scores_pipeline( - corruption_scheme: str, filter_scores: bool, extra_only: bool + corruption_scheme: str, + filter_scores: bool, + extra_only: bool, + filter_candidates: bool, ) -> None: ds = KGDataset( n_entity=n_entity, @@ -104,6 +110,7 @@ def test_all_scores_pipeline( score_fn, evaluation, filter_triples=triples_to_filter, # type: ignore + candidate_ents=compl_candidates if filter_candidates else None, return_scores=True, return_topk=True, k=10, @@ -136,6 +143,14 @@ def test_all_scores_pipeline( triple_reordered[:, 1], unsharded_entity_table[triple_reordered[:, 2]], ).flatten() + 
if filter_candidates: + # positive score -inf if ground truth not in candidate list + pos_scores[ + torch.from_numpy( + ~np.in1d(triple_reordered[:, ground_truth_col], compl_candidates) + ) + ] = -torch.inf + # mask positive scores to compute metrics cpu_scores[ torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] @@ -161,19 +176,31 @@ def test_all_scores_pipeline( assert torch.all( tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1 ) + if filter_candidates: + cand_mask = np.setdiff1d(np.arange(cpu_scores.shape[-1]), compl_candidates) + cpu_scores[:, cand_mask] = -torch.inf cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores) - # we allow for a off-by-one rank difference on at most 1% of triples, - # due to rounding differences in CPU vs IPU score computations - assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1) - assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100 # restore positive scores cpu_scores[ torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] ] = pos_scores + cpu_preds = torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices + assert torch.all(cpu_preds == out["topk_global_id"]) + + if filter_candidates: + # check that all predictions are in set of candidates + assert np.all(np.in1d(out["topk_global_id"], compl_candidates)) + assert np.all(np.in1d(cpu_preds, compl_candidates)) + + cpu_scores = cpu_scores[:, compl_candidates] + out["scores"] = out["scores"][:, compl_candidates] + assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4) - assert torch.all( - torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices == out["topk_global_id"] - ) + + # we allow for an off-by-one rank difference on at most 1% of triples, + # due to rounding differences in CPU vs IPU score computations + assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1) + assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100