From 253b3206ebb7f390d12fe0801878aec18dc7c653 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo <84471416+AlCatt91@users.noreply.github.com> Date: Wed, 8 Nov 2023 15:41:44 +0000 Subject: [PATCH] Add ConvE scoring function (#35) * ConvE implementation * address PR comments --- besskge/embedding.py | 18 ++++ besskge/scoring.py | 219 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 228 insertions(+), 9 deletions(-) diff --git a/besskge/embedding.py b/besskge/embedding.py index ce04c54..86244da 100644 --- a/besskge/embedding.py +++ b/besskge/embedding.py @@ -26,6 +26,24 @@ def init_uniform_norm(embedding_table: torch.Tensor) -> torch.Tensor: return torch.nn.functional.normalize(torch.nn.init.uniform(embedding_table), dim=-1) +def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch.Tensor: + """ + Initialize embeddings according to Xavier normal scheme, with + `fan_in = 0`, `fan_out=row_size`. + + :param embedding_table: + Tensor of embedding parameters to initialize. + :param gain: + Scaling factor for standard deviation. Default: 1.0. + + :return: + Initialized tensor. + """ + return torch.nn.init.normal_( + embedding_table, std=gain * np.sqrt(2.0 / embedding_table.shape[-1]) + ) + + def init_KGE_uniform( embedding_table: torch.Tensor, b: float = 1.0, divide_by_embedding_size: bool = True ) -> torch.Tensor: diff --git a/besskge/scoring.py b/besskge/scoring.py index cf4173b..7c0b985 100644 --- a/besskge/scoring.py +++ b/besskge/scoring.py @@ -15,6 +15,7 @@ init_KGE_normal, init_KGE_uniform, init_uniform_norm, + init_xavier_norm, initialize_entity_embedding, initialize_relation_embedding, refactor_embedding_sharding, @@ -291,7 +292,7 @@ def __init__( :param relation_initializer: Initialization function or table for relation embeddings. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(TransE, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -391,7 +392,7 @@ def __init__( :param relation_initializer: Initialization function or table for relation embeddings. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(RotatE, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -502,7 +503,7 @@ def __init__( If True, L2-normalize head and tail entity embeddings before projecting, as in :cite:p:`PairRE`. Default: True. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(PairRE, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -637,7 +638,7 @@ def __init__( Offset factor for head/tail relation projections, as in TripleREv2. Default: 0.0 (no offset). :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(TripleRE, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -776,7 +777,7 @@ def __init__( :param relation_initializer: Initialization function or table for relation embeddings. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(DistMult, self).__init__(negative_sample_sharing=negative_sample_sharing) @@ -870,7 +871,7 @@ def __init__( :param relation_initializer: Initialization function or table for relation embeddings. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(ComplEx, self).__init__(negative_sample_sharing=negative_sample_sharing) @@ -944,6 +945,206 @@ def score_tails( ) +class ConvE(MatrixDecompositionScoreFunction): + """ + ConvE scoring function :cite:p:`ConvE`. + + Note that, differently from :cite:p:`ConvE`, the scores returned by this class + have not been passed through a final sigmoid layer, as we assume that this is + included in the loss function. + + By design, this scoring function should be used in combination with a + negative/candidate sampler that only corrupts tails (possibly after + including all inverse triples in the dataset, see the `add_inverse_triples` + argument in :func:`besskge.sharding.PartitionedTripleSet.create_from_dataset`). + """ + + def __init__( + self, + negative_sample_sharing: bool, + sharding: Sharding, + n_relation_type: int, + embedding_size: int, + embedding_height: int, + embedding_width: int, + entity_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [ + init_xavier_norm, + torch.nn.init.zeros_, + ], + relation_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [ + init_xavier_norm, + ], + inverse_relations: bool = True, + input_channels: int = 1, + output_channels: int = 32, + kernel_height: int = 3, + kernel_width: int = 3, + input_dropout: float = 0.2, + feature_map_dropout: float = 0.2, + hidden_dropout: float = 0.3, + batch_normalization: bool = True, + ) -> None: + """ + Initialize ConvE model. + + :param negative_sample_sharing: + see :meth:`DistanceBasedScoreFunction.__init__` + :param sharding: + Entity sharding. + :param n_relation_type: + Number of relation types in the knowledge graph. + :param embedding_size: + Size of entity and relation embeddings. + :param embedding_height: + Height of the 2D-reshaping of the concatenation of + head and relation embeddings. + :param embedding_width: + Width of the 2D-reshaping of the concatenation of + head and relation embeddings. + :param entity_initializer: + Initialization functions or table for entity embeddings. + If not passing a table, two functions are needed: the initializer + for entity embeddings and initializer for (scalar) tail biases. + :param relation_initializer: + Initialization function or table for relation embeddings. + :param inverse_relations: + If True, learn embeddings for inverse relations. Default: True. + :param input_channels: + Number of input channels of the Conv2D operator. Default: 1. + :param output_channels: + Number of output channels of the Conv2D operator. Default: 32. + :param kernel_height: + Height of the Conv2D kernel. Default: 3. + :param kernel_width: + Width of the Conv2D kernel. Default: 3. + :param input_dropout: + Rate of Dropout applied before the convolution. Default: 0.2. + :param feature_map_dropout: + Rate of Dropout applied after the convolution. Default: 0.2. + :param hidden_dropout: + Rate of Dropout applied after the Linear layer. Default: 0.3. + :param batch_normalization: + If True, apply batch normalization before and after the + convolution and after the Linear layer. Default: True. + """ + super(ConvE, self).__init__(negative_sample_sharing=negative_sample_sharing) + + self.sharding = sharding + + if input_channels * embedding_width * embedding_height != embedding_size: + raise ValueError( + "`embedding_size` needs to be equal to" + " `input_channels * embedding_width * embedding_height`" + ) + + # self.entity_embedding[..., :embedding_size] entity_embeddings + # self.entity_embedding[..., -1] tail biases + self.entity_embedding = initialize_entity_embedding( + self.sharding, entity_initializer, [embedding_size, 1] + ) + self.relation_embedding = initialize_relation_embedding( + n_relation_type, inverse_relations, relation_initializer, [embedding_size] + ) + assert ( + self.entity_embedding.shape[-1] - 1 + == self.relation_embedding.shape[-1] + == embedding_size + ), ( + "ConvE requires `embedding_size + 1` embedding parameters for each entity" + " and `embedding_size` embedding parameters for each relation" + ) + self.embedding_size = embedding_size + + self.inp_channels = input_channels + self.emb_h = embedding_height + self.emb_w = embedding_width + conv_layers = [ + torch.nn.Dropout(input_dropout), + torch.nn.Conv2d( + in_channels=self.inp_channels, + out_channels=output_channels, + kernel_size=(kernel_height, kernel_width), + ), + torch.nn.ReLU(), + torch.nn.Dropout2d(feature_map_dropout), + ] + fc_layers = [ + torch.nn.Linear( + output_channels + * (2 * self.emb_h - kernel_height + 1) + * (self.emb_w - kernel_width + 1), + embedding_size, + ), + torch.nn.Dropout(hidden_dropout), + torch.nn.ReLU(), + ] + if batch_normalization: + conv_layers.insert(0, torch.nn.BatchNorm2d(input_channels)) + conv_layers.insert(3, torch.nn.BatchNorm2d(output_channels)) + fc_layers.insert(2, torch.nn.BatchNorm1d(embedding_size)) + self.conv_layers = torch.nn.Sequential(*conv_layers) + self.fc_layers = torch.nn.Sequential(*fc_layers) + + # docstr-coverage: inherited + def score_triple( + self, + head_emb: torch.Tensor, + relation_id: torch.Tensor, + tail_emb: torch.Tensor, + ) -> torch.Tensor: + relation_emb = torch.index_select( + self.relation_embedding, index=relation_id, dim=0 + ) + # Discard bias for heads + head_emb = head_emb[..., :-1] + tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1) + hr_cat = torch.cat( + [ + head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w), + relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w), + ], + dim=-2, + ) + hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1)) + return self.reduce_embedding(hr_conv * tail_emb) + tail_bias.squeeze(-1) + + # docstr-coverage: inherited + def score_heads( + self, + head_emb: torch.Tensor, + relation_id: torch.Tensor, + tail_emb: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError("ConvE should not be used with head corruption") + + # docstr-coverage: inherited + def score_tails( + self, + head_emb: torch.Tensor, + relation_id: torch.Tensor, + tail_emb: torch.Tensor, + ) -> torch.Tensor: + relation_emb = torch.index_select( + self.relation_embedding, index=relation_id, dim=0 + ) + # Discard bias for heads + head_emb = head_emb[..., :-1] + tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1) + if self.negative_sample_sharing: + tail_bias = tail_bias.view(1, -1) + else: + tail_bias = tail_bias.squeeze(-1) + hr_cat = torch.cat( + [ + head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w), + relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w), + ], + dim=-2, + ) + hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1)) + return self.broadcasted_dot_product(hr_conv, tail_emb) + tail_bias + + class BoxE(DistanceBasedScoreFunction): """ BoxE scoring function :cite:p:`BoxE`. @@ -1000,7 +1201,7 @@ def __init__( Softening parameter for geometric normalization of box widths. Default: 1e-6. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(BoxE, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -1258,7 +1459,7 @@ def __init__( :param offset: Offset applied to auxiliary entity embeddings. Default: 1.0. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(InterHT, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm @@ -1415,7 +1616,7 @@ def __init__( :param offset: Offset applied to tilde entity embeddings. Default: 1.0. :param inverse_relations: - If True, learn embeddings for inverse relations. Default: False + If True, learn embeddings for inverse relations. Default: False. """ super(TranS, self).__init__( negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm