From 2b88065170fba6ed377292f16da0ff5810eefdd9 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Wed, 28 Jun 2023 17:01:11 +0200
Subject: [PATCH] First pass over proposed API for weight-tying

---
 merlin/models/torch/inputs/embedding.py       |  8 ++
 merlin/models/torch/models/base.py            |  8 ++
 merlin/models/torch/outputs/classification.py | 98 ++++++++++++-------
 3 files changed, 79 insertions(+), 35 deletions(-)

diff --git a/merlin/models/torch/inputs/embedding.py b/merlin/models/torch/inputs/embedding.py
index fe8cf28170..ef339ae781 100644
--- a/merlin/models/torch/inputs/embedding.py
+++ b/merlin/models/torch/inputs/embedding.py
@@ -374,6 +374,14 @@ def update_feature(self, col_schema: ColumnSchema) -> "EmbeddingTable":
 
         return self
 
+    def feature_weights(self, name: str):
+        if name not in self.domains:
+            raise ValueError(f"Unknown feature: {name}")
+
+        domain = self.domains[name]
+
+        return self.table.weight[domain.min : domain.max + 1]
+
     def select(self, selection: Selection) -> Selectable:
         selected = select(self.input_schema, selection)
 
diff --git a/merlin/models/torch/models/base.py b/merlin/models/torch/models/base.py
index 56851d285a..7ebd3226b8 100644
--- a/merlin/models/torch/models/base.py
+++ b/merlin/models/torch/models/base.py
@@ -62,6 +62,8 @@ def __init__(
         self,
         *blocks: nn.Module,
         optimizer=torch.optim.Adam,
+        batch_block: Optional[nn.Module] = None,
     ):
         super().__init__()
+        self.batch_block = batch_block
 
@@ -79,6 +81,12 @@ def initialize(self, data: Union[Dataset, Loader, Batch]):
     def forward(
         self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
     ):
         """Performs a forward pass through the model."""
+        if batch is None:
+            batch = Batch(inputs, None)
+
+        if self.batch_block is not None:
+            batch = self.batch_block(batch)
+
         outputs = inputs
         for block in self.values:
diff --git a/merlin/models/torch/outputs/classification.py b/merlin/models/torch/outputs/classification.py
index e39ccf1e98..8772b4f018 100644
--- a/merlin/models/torch/outputs/classification.py
+++ b/merlin/models/torch/outputs/classification.py
@@ -20,6 +20,7 @@
 from torchmetrics import AUROC, Accuracy, AveragePrecision, Metric, Precision, Recall
 
 import merlin.dtypes as md
+from merlin.models.torch import schema
 from merlin.models.torch.inputs.embedding import EmbeddingTable
 from merlin.models.torch.outputs.base import ModelOutput
 from merlin.schema import ColumnSchema, Schema, Tags
@@ -127,43 +128,32 @@ def __init__(
         loss: nn.Module = nn.CrossEntropyLoss(),
         metrics: Optional[Sequence[Metric]] = None,
         logits_temperature: float = 1.0,
-        weight_tying: Optional[EmbeddingTable] = None,
     ):
         super().__init__(
             loss=loss,
             metrics=metrics,
             logits_temperature=logits_temperature,
         )
-        self.weight_tying = weight_tying
 
         if schema:
             self.setup_schema(schema)
 
-    # def setup(
-    #     self,
-    #     to_call: Optional[
-    #         Union[
-    #             Schema,
-    #             ColumnSchema,
-    #             EmbeddingTable,
-    #             "CategoricalTarget",
-    #             "EmbeddingTablePrediction",
-    #         ]
-    #     ] = None
-    # ):
-    #     if isinstance(to_call, (Schema, ColumnSchema)):
-    #         self.setup_schema(to_call)
-    #     elif isinstance(to_call, (EmbeddingTable)):
-    #         self.prepend(EmbeddingTablePrediction(to_call))
-    #     elif isinstance(to_call, (CategoricalTarget, EmbeddingTablePrediction)):
-    #         self.prepend(to_call)
-    #     else:
-    #         raise ValueError(f"Invalid to_call type: {type(to_call)}")
-
-    #     self.num_classes = self[0].num_classes
-
-    #     if not self.metrics:
-    #         self.metrics = self.default_metrics()
+    @classmethod
+    def with_weight_tying(
+        cls,
+        selection: schema.Selection,
+        block: nn.Module,
+        loss: nn.Module = nn.CrossEntropyLoss(),
+        metrics: Optional[Sequence[Metric]] = None,
+        logits_temperature: float = 1.0,
+    ) -> "CategoricalOutput":
+        self = cls(loss=loss, metrics=metrics, logits_temperature=logits_temperature)
+        self = self.tie_weights(selection, block)
+        self.output_schema = categorical_output_schema(self[0].col_schema, self.num_classes)
+        if not self.metrics:
+            self.metrics = self.default_metrics(self.num_classes)
+
+        return self
 
     def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
         """Set up the schema for the output.
@@ -178,15 +168,32 @@ def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
                 raise ValueError("Schema must contain exactly one column.")
 
             target = target.first
 
         to_call = CategoricalTarget(target)
         self.num_classes = to_call.num_classes
         self.prepend(to_call)
 
-        self.output_schema = categorical_output_schema(target, self.num_classes)
+        if not self.metrics:
+            self.metrics = self.default_metrics(self.num_classes)
+
+    def tie_weights(self, selection: schema.Selection, block: nn.Module) -> "CategoricalOutput":
+        if isinstance(block, EmbeddingTable):
+            table = block
+        else:
+            try:
+                selected = schema.select(block, selection)
+                table = selected.leaf()
+            except Exception as e:
+                raise ValueError("Could not find embedding table in block.") from e
+
+        prediction = EmbeddingTablePrediction(table, selection)
+        self.num_classes = prediction.num_classes
+        if len(self) > 0:
+            self[0] = prediction
+        else:
+            self.prepend(prediction)
 
-    def tie_weights(self, embedding_table: EmbeddingTable):
-        self.weight_tying = embedding_table
+        return self
 
     @classmethod
     def default_metrics(cls, num_classes: int) -> List[Metric]:
@@ -313,10 +320,20 @@ class EmbeddingTablePrediction(nn.Module):
     arXiv:1611.01462 (2016).
     """
 
-    def __init__(self, table: EmbeddingTable):
+    def __init__(self, table: EmbeddingTable, selection: Optional[schema.Selection] = None):
         super().__init__()
         self.table = table
-        self.num_classes = table.num_embeddings
+        if len(table.domains) > 1:
+            if not selection:
+                raise ValueError(
+                    f"Table {table} has multiple columns. "
+                    "Must specify selection to choose column."
+                )
+            self.add_selection(selection)
+        else:
+            self.num_classes = table.num_embeddings
+            self.col_schema = table.input_schema.first
+            self.col_name = self.col_schema.name
         self.bias = nn.Parameter(
             torch.zeros(self.num_classes, dtype=torch.float32, device=self.embeddings().device)
         )
@@ -336,6 +353,16 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         return nn.functional.linear(inputs, self.embeddings(), self.bias)
 
+    def add_selection(self, selection: schema.Selection):
+        selected = schema.select(self.table.input_schema, selection)
+        if len(selected) != 1:
+            raise ValueError(f"Schema must contain exactly one column, got: {selected}")
+        self.col_schema = selected.first
+        self.col_name = self.col_schema.name
+        self.num_classes = self.col_schema.int_domain.max + 1
+
+        return self
+
     def embeddings(self) -> nn.Parameter:
         """Fetch the weight matrix from the embedding table.
 
@@ -344,6 +371,9 @@ def embeddings(self) -> nn.Parameter:
         nn.Parameter
             Weight matrix from the embedding table.
         """
+        if len(self.table.domains) > 1:
+            return self.table.feature_weights(self.col_name)
+
         return self.table.table.weight
 
     def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
@@ -359,9 +389,7 @@ def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             The corresponding embeddings.
         """
-        # TODO: Make sure that we check if the table holds multiple features
-        # If so, we need to add domain.min to the inputs
-        return self.table.table(inputs)
+        return self.table({self.col_name: inputs})[self.col_name]
 
 
 def categorical_output_schema(target: ColumnSchema, num_classes: int) -> Schema:
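
Example usage of the proposed API, as a sketch: `CategoricalOutput.with_weight_tying`
and the `(selection, block)` argument order follow this patch, while the surrounding
blocks (`TabularInputBlock`, `MLPBlock`, `Model`) and the `init="defaults"` option
are assumed from the existing merlin.models.torch API and are not exercised by this
diff:

    import merlin.models.torch as mm
    from merlin.schema import Tags

    # `schema` is assumed to be a merlin Schema containing an ITEM_ID column.
    input_block = mm.TabularInputBlock(schema, init="defaults")

    # Reuse the item-id embedding table as the output projection: the
    # selection picks the feature, and `input_block` is searched for the
    # matching EmbeddingTable.
    output = mm.CategoricalOutput.with_weight_tying(Tags.ITEM_ID, input_block)

    model = mm.Model(input_block, mm.MLPBlock([64]), output)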
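
The tied projection itself is plain PyTorch. A standalone illustration of what
`EmbeddingTablePrediction.forward` computes, independent of merlin and following
the weight-tying scheme of arXiv:1611.01462 cited in the class docstring (names
here are illustrative only):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    num_items, dim = 1000, 64
    table = nn.Embedding(num_items, dim)         # input-side embedding table
    bias = nn.Parameter(torch.zeros(num_items))  # per-class output bias

    hidden = torch.randn(32, dim)                # output of the model body
    # The output layer reuses the embedding matrix as its weight, which is
    # what nn.functional.linear(inputs, self.embeddings(), self.bias) does
    # in EmbeddingTablePrediction.forward.
    logits = F.linear(hidden, table.weight, bias)
    assert logits.shape == (32, num_items)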