Skip to content

Commit

Permalink
Candidate entities for scoring in pipeline (#40)
Browse files Browse the repository at this point in the history
* add feature: restrict scoring to candidate entities in pipeline

* fix test

* CI fix
  • Loading branch information
AlCatt91 authored Mar 21, 2024
1 parent 604351d commit 1fedeac
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
24 changes: 21 additions & 3 deletions besskge/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
class AllScoresPipeline(torch.nn.Module):
"""
Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities
in the KG, and related prediction metrics.
It supports filtering out the scores of specific completions that appear in a given
set of triples.
in the KG (or a given subset of entities), and related prediction metrics.
It supports filtering out, for each query, the scores of specific completions that
appear in a given set of triples.
To be used in combination with a batch sampler based on a
"h_shard"/"t_shard"-partitioned triple set.
Expand All @@ -38,6 +38,7 @@ def __init__(
score_fn: BaseScoreFunction,
evaluation: Optional[Evaluation] = None,
filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None,
candidate_ents: Optional[Union[torch.Tensor, NDArray[np.int32]]] = None,
return_scores: bool = False,
return_topk: bool = False,
k: int = 10,
Expand All @@ -62,6 +63,12 @@ def __init__(
The set of all triples whose scores need to be filtered.
The triples passed here must have GLOBAL IDs for head/tail
entities. Default: None.
:param candidate_ents:
If specified, score queries only against a given set of entities.
This array needs to contain the global IDs of the
candidate entities to be used for completion. All other entities
will then be ignored when scoring queries.
Default: None (i.e. score queries against all entities).
:param return_scores:
If True, store and return scores of all queries' completions
(with filters applied, if specified).
Expand Down Expand Up @@ -165,6 +172,13 @@ def __init__(
],
dim=0,
)
self.candidate_mask: Optional[torch.Tensor] = None
if candidate_ents is not None:
self.candidate_mask = torch.from_numpy(
np.setdiff1d(
np.arange(self.bess_module.sharding.n_entity), candidate_ents
)
)

def forward(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -231,6 +245,10 @@ def forward(self) -> Dict[str, Any]:
batch_scores_filt = batch_scores[triple_mask.flatten()][
:, np.unique(np.concatenate(batch_idx), return_index=True)[1]
][:, : self.bess_module.sharding.n_entity]
if self.candidate_mask is not None:
# Filter scores for entities that are not in
# the given set of candidates
batch_scores_filt[:, self.candidate_mask] = -torch.inf
if ground_truth is not None:
# Scores of positive triples
true_scores = batch_scores_filt[
Expand Down
43 changes: 35 additions & 8 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@
test_triples_r = np.random.randint(n_relation_type, size=n_test_triple)
triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)}

compl_candidates = np.arange(0, n_entity - 1, step=5)


@pytest.mark.parametrize("corruption_scheme", ["h", "t"])
@pytest.mark.parametrize(
"filter_scores, extra_only", [(True, True), (True, False), (False, False)]
)
@pytest.mark.parametrize("filter_candidates", [True, False])
def test_all_scores_pipeline(
corruption_scheme: str, filter_scores: bool, extra_only: bool
corruption_scheme: str,
filter_scores: bool,
extra_only: bool,
filter_candidates: bool,
) -> None:
ds = KGDataset(
n_entity=n_entity,
Expand Down Expand Up @@ -104,6 +110,7 @@ def test_all_scores_pipeline(
score_fn,
evaluation,
filter_triples=triples_to_filter, # type: ignore
candidate_ents=compl_candidates if filter_candidates else None,
return_scores=True,
return_topk=True,
k=10,
Expand Down Expand Up @@ -136,6 +143,14 @@ def test_all_scores_pipeline(
triple_reordered[:, 1],
unsharded_entity_table[triple_reordered[:, 2]],
).flatten()
if filter_candidates:
# set the positive score to -inf if the ground truth is not in the candidate list
pos_scores[
torch.from_numpy(
~np.in1d(triple_reordered[:, ground_truth_col], compl_candidates)
)
] = -torch.inf

# mask positive scores to compute metrics
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
Expand All @@ -161,19 +176,31 @@ def test_all_scores_pipeline(
assert torch.all(
tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1
)
if filter_candidates:
cand_mask = np.setdiff1d(np.arange(cpu_scores.shape[-1]), compl_candidates)
cpu_scores[:, cand_mask] = -torch.inf

cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores)
# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1)
assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100

# restore positive scores
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
] = pos_scores

cpu_preds = torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices
assert torch.all(cpu_preds == out["topk_global_id"])

if filter_candidates:
# check that all predictions are in set of candidates
assert np.all(np.in1d(out["topk_global_id"], compl_candidates))
assert np.all(np.in1d(cpu_preds, compl_candidates))

cpu_scores = cpu_scores[:, compl_candidates]
out["scores"] = out["scores"][:, compl_candidates]

assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4)
assert torch.all(
torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices == out["topk_global_id"]
)

# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1)
assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100

0 comments on commit 1fedeac

Please sign in to comment.