Add ConvE scoring function (#35)
* ConvE implementation

* address PR comments
AlCatt91 authored Nov 8, 2023
1 parent 5d2ec26 commit 253b320
Showing 2 changed files with 228 additions and 9 deletions.
18 changes: 18 additions & 0 deletions besskge/embedding.py
@@ -26,6 +26,24 @@ def init_uniform_norm(embedding_table: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.normalize(torch.nn.init.uniform_(embedding_table), dim=-1)


def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
"""
    Initialize embeddings according to the Xavier normal scheme, with
    `fan_in = 0`, `fan_out = row_size`.

:param embedding_table:
Tensor of embedding parameters to initialize.
:param gain:
Scaling factor for standard deviation. Default: 1.0.
:return:
Initialized tensor.
"""
return torch.nn.init.normal_(
embedding_table, std=gain * np.sqrt(2.0 / embedding_table.shape[-1])
)
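
# Illustrative check (hypothetical, not part of this diff): with the default
# gain=1.0 and a row size of 128, entries are drawn in place with
# std = sqrt(2 / 128) = 0.125:
#
#     emb = torch.empty(500, 128)
#     init_xavier_norm(emb)
#     print(emb.std())  # ~0.125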


def init_KGE_uniform(
embedding_table: torch.Tensor, b: float = 1.0, divide_by_embedding_size: bool = True
) -> torch.Tensor:
219 changes: 210 additions & 9 deletions besskge/scoring.py
@@ -15,6 +15,7 @@
init_KGE_normal,
init_KGE_uniform,
init_uniform_norm,
init_xavier_norm,
initialize_entity_embedding,
initialize_relation_embedding,
refactor_embedding_sharding,
@@ -291,7 +292,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TransE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -391,7 +392,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(RotatE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -502,7 +503,7 @@ def __init__(
If True, L2-normalize head and tail entity embeddings before projecting,
as in :cite:p:`PairRE`. Default: True.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(PairRE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -637,7 +638,7 @@ def __init__(
Offset factor for head/tail relation projections, as in TripleREv2.
Default: 0.0 (no offset).
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TripleRE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -776,7 +777,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(DistMult, self).__init__(negative_sample_sharing=negative_sample_sharing)

@@ -870,7 +871,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(ComplEx, self).__init__(negative_sample_sharing=negative_sample_sharing)

@@ -944,6 +945,206 @@ def score_tails(
)


class ConvE(MatrixDecompositionScoreFunction):
"""
ConvE scoring function :cite:p:`ConvE`.

    Note that, unlike in :cite:p:`ConvE`, the scores returned by this class are
    not passed through a final sigmoid layer, as we assume this is included in
    the loss function.

    By design, this scoring function should be used in combination with a
    negative/candidate sampler that only corrupts tails (possibly after
    including all inverse triples in the dataset; see the `add_inverse_triples`
    argument in :func:`besskge.sharding.PartitionedTripleSet.create_from_dataset`).
"""

def __init__(
self,
negative_sample_sharing: bool,
sharding: Sharding,
n_relation_type: int,
embedding_size: int,
embedding_height: int,
embedding_width: int,
entity_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
init_xavier_norm,
torch.nn.init.zeros_,
],
relation_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
init_xavier_norm,
],
inverse_relations: bool = True,
input_channels: int = 1,
output_channels: int = 32,
kernel_height: int = 3,
kernel_width: int = 3,
input_dropout: float = 0.2,
feature_map_dropout: float = 0.2,
hidden_dropout: float = 0.3,
batch_normalization: bool = True,
) -> None:
"""
Initialize ConvE model.
:param negative_sample_sharing:
see :meth:`DistanceBasedScoreFunction.__init__`
:param sharding:
Entity sharding.
:param n_relation_type:
Number of relation types in the knowledge graph.
:param embedding_size:
Size of entity and relation embeddings.
:param embedding_height:
Height of the 2D-reshaping of the concatenation of
head and relation embeddings.
:param embedding_width:
Width of the 2D-reshaping of the concatenation of
head and relation embeddings.
:param entity_initializer:
Initialization functions or table for entity embeddings.
        If not passing a table, two functions are needed: the initializer
        for the entity embeddings and the initializer for the (scalar) tail biases.
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
If True, learn embeddings for inverse relations. Default: True.
:param input_channels:
Number of input channels of the Conv2D operator. Default: 1.
:param output_channels:
Number of output channels of the Conv2D operator. Default: 32.
:param kernel_height:
Height of the Conv2D kernel. Default: 3.
:param kernel_width:
Width of the Conv2D kernel. Default: 3.
:param input_dropout:
Rate of Dropout applied before the convolution. Default: 0.2.
:param feature_map_dropout:
Rate of Dropout applied after the convolution. Default: 0.2.
:param hidden_dropout:
Rate of Dropout applied after the Linear layer. Default: 0.3.
:param batch_normalization:
If True, apply batch normalization before and after the
convolution and after the Linear layer. Default: True.
"""
super(ConvE, self).__init__(negative_sample_sharing=negative_sample_sharing)

self.sharding = sharding

if input_channels * embedding_width * embedding_height != embedding_size:
raise ValueError(
"`embedding_size` needs to be equal to"
" `input_channels * embedding_width * embedding_height`"
)
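        # For example (illustrative values): with input_channels=1 and
        # embedding_size=200, a valid 2D-reshaping is embedding_height=10,
        # embedding_width=20, since 1 * 10 * 20 == 200.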

        # self.entity_embedding[..., :embedding_size] stores the entity embeddings;
        # self.entity_embedding[..., -1] stores the (scalar) tail biases
self.entity_embedding = initialize_entity_embedding(
self.sharding, entity_initializer, [embedding_size, 1]
)
self.relation_embedding = initialize_relation_embedding(
n_relation_type, inverse_relations, relation_initializer, [embedding_size]
)
assert (
self.entity_embedding.shape[-1] - 1
== self.relation_embedding.shape[-1]
== embedding_size
), (
"ConvE requires `embedding_size + 1` embedding parameters for each entity"
" and `embedding_size` embedding parameters for each relation"
)
self.embedding_size = embedding_size

self.inp_channels = input_channels
self.emb_h = embedding_height
self.emb_w = embedding_width
conv_layers = [
torch.nn.Dropout(input_dropout),
torch.nn.Conv2d(
in_channels=self.inp_channels,
out_channels=output_channels,
kernel_size=(kernel_height, kernel_width),
),
torch.nn.ReLU(),
torch.nn.Dropout2d(feature_map_dropout),
]
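        # With stride 1 and no padding, the Conv2d output has spatial size
        # (2 * emb_h - kernel_height + 1) x (emb_w - kernel_width + 1); this
        # fixes the number of input features of the Linear layer below.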
fc_layers = [
torch.nn.Linear(
output_channels
* (2 * self.emb_h - kernel_height + 1)
* (self.emb_w - kernel_width + 1),
embedding_size,
),
torch.nn.Dropout(hidden_dropout),
torch.nn.ReLU(),
]
if batch_normalization:
conv_layers.insert(0, torch.nn.BatchNorm2d(input_channels))
conv_layers.insert(3, torch.nn.BatchNorm2d(output_channels))
fc_layers.insert(2, torch.nn.BatchNorm1d(embedding_size))
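            # Resulting order: BatchNorm2d -> Dropout -> Conv2d -> BatchNorm2d
            # -> ReLU -> Dropout2d and Linear -> Dropout -> BatchNorm1d -> ReLU
            # (assumed here to mirror the layer order of :cite:p:`ConvE`).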
self.conv_layers = torch.nn.Sequential(*conv_layers)
self.fc_layers = torch.nn.Sequential(*fc_layers)

# docstr-coverage: inherited
def score_triple(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
relation_emb = torch.index_select(
self.relation_embedding, index=relation_id, dim=0
)
# Discard bias for heads
head_emb = head_emb[..., :-1]
tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
hr_cat = torch.cat(
[
head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
],
dim=-2,
)
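        # hr_cat: (batch, input_channels, 2 * emb_h, emb_w), i.e. the 2D-reshaped
        # head and relation embeddings stacked along the height dimension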
hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1))
return self.reduce_embedding(hr_conv * tail_emb) + tail_bias.squeeze(-1)

# docstr-coverage: inherited
def score_heads(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError("ConvE should not be used with head corruption")

# docstr-coverage: inherited
def score_tails(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
relation_emb = torch.index_select(
self.relation_embedding, index=relation_id, dim=0
)
# Discard bias for heads
head_emb = head_emb[..., :-1]
tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
if self.negative_sample_sharing:
tail_bias = tail_bias.view(1, -1)
else:
tail_bias = tail_bias.squeeze(-1)
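        # With negative sample sharing, the single row of biases broadcasts
        # against the score matrix returned by broadcasted_dot_product below.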
hr_cat = torch.cat(
[
head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
],
dim=-2,
)
hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1))
return self.broadcasted_dot_product(hr_conv, tail_emb) + tail_bias
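
# A minimal usage sketch (hypothetical, not part of this diff; `sharding` and
# `n_relation_type` are assumed to come from a sharded BessKGE dataset, and
# the entity/relation shapes follow the constructor defaults above):
#
#     score_fn = ConvE(
#         negative_sample_sharing=False,
#         sharding=sharding,
#         n_relation_type=n_relation_type,
#         embedding_size=200,
#         embedding_height=10,
#         embedding_width=20,
#     )
#     # entity rows store embedding_size + 1 parameters (embedding + tail bias)
#     scores = score_fn.score_tails(head_emb, relation_id, tail_emb)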


class BoxE(DistanceBasedScoreFunction):
"""
BoxE scoring function :cite:p:`BoxE`.
@@ -1000,7 +1201,7 @@ def __init__(
Softening parameter for geometric normalization of box widths.
Default: 1e-6.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(BoxE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -1258,7 +1459,7 @@ def __init__(
:param offset:
Offset applied to auxiliary entity embeddings. Default: 1.0.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(InterHT, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -1415,7 +1616,7 @@ def __init__(
:param offset:
Offset applied to tilde entity embeddings. Default: 1.0.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TranS, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
