First pass over proposed-API for weight-tying
marcromeyn committed Jun 28, 2023
1 parent 89df557 commit 2b88065
Showing 3 changed files with 72 additions and 36 deletions.
8 changes: 8 additions & 0 deletions merlin/models/torch/inputs/embedding.py
@@ -374,6 +374,14 @@ def update_feature(self, col_schema: ColumnSchema) -> "EmbeddingTable":

        return self

    def feature_weights(self, name: str):
        if name not in self.domains:
            raise ValueError(f"Unknown feature: {name!r}")

        domain = self.domains[name]

        return self.table.weight[domain.min : domain.max + 1]

    def select(self, selection: Selection) -> Selectable:
        selected = select(self.input_schema, selection)

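The new feature_weights helper slices a single feature's rows out of a table that packs several categorical features into one nn.Embedding. Below is a minimal runnable sketch of that slicing logic, using a plain nn.Embedding and an illustrative domains dict with inclusive (min, max) bounds as stand-ins for the real EmbeddingTable internals:

import torch
import torch.nn as nn

# Stand-in for an EmbeddingTable that packs two features into one weight
# matrix: rows 0-9 belong to "user_id", rows 10-24 to "item_id".
table = nn.Embedding(num_embeddings=25, embedding_dim=8)
domains = {"user_id": (0, 9), "item_id": (10, 24)}  # inclusive (min, max)

def feature_weights(name: str) -> torch.Tensor:
    if name not in domains:
        raise ValueError(f"Unknown feature: {name!r}")
    lo, hi = domains[name]
    return table.weight[lo : hi + 1]  # only this feature's rows

print(feature_weights("item_id").shape)  # torch.Size([15, 8])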
6 changes: 6 additions & 0 deletions merlin/models/torch/models/base.py
@@ -62,6 +62,7 @@ def __init__(
        self,
        *blocks: nn.Module,
        optimizer=torch.optim.Adam,
        batch_block: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.batch_block = batch_block

@@ -79,6 +80,11 @@ def initialize(self, data: Union[Dataset, Loader, Batch]):
    def forward(
        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
    ):
        """Performs a forward pass through the model."""
        if batch is None:
            batch = Batch(inputs, None)

        if self.batch_block is not None:
            batch = self.batch_block(batch)

        outputs = inputs
        for block in self.values:
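The new batch_block hook hands the Batch to a module before any of the model's blocks run, e.g. for feature preprocessing. A toy sketch of the idea; the Batch dataclass and NormalizeFeatures module below are illustrative stand-ins, not the merlin API:

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

@dataclass
class Batch:  # stand-in for merlin.models.torch.Batch
    features: torch.Tensor
    targets: Optional[torch.Tensor]

class NormalizeFeatures(nn.Module):
    """A batch_block that L2-normalizes features before the blocks run."""

    def forward(self, batch: Batch) -> Batch:
        scale = batch.features.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return Batch(batch.features / scale, batch.targets)

batch = NormalizeFeatures()(Batch(torch.randn(4, 16), None))
print(batch.features.norm(dim=-1))  # every row now has unit norm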
94 changes: 58 additions & 36 deletions merlin/models/torch/outputs/classification.py
@@ -20,6 +20,7 @@
from torchmetrics import AUROC, Accuracy, AveragePrecision, Metric, Precision, Recall

import merlin.dtypes as md
from merlin.models.torch import schema
from merlin.models.torch.inputs.embedding import EmbeddingTable
from merlin.models.torch.outputs.base import ModelOutput
from merlin.schema import ColumnSchema, Schema, Tags
@@ -127,43 +128,32 @@ def __init__(
        loss: nn.Module = nn.CrossEntropyLoss(),
        metrics: Optional[Sequence[Metric]] = None,
        logits_temperature: float = 1.0,
        weight_tying: Optional[EmbeddingTable] = None,
    ):
        super().__init__(
            loss=loss,
            metrics=metrics,
            logits_temperature=logits_temperature,
        )
        self.weight_tying = weight_tying

        if schema:
            self.setup_schema(schema)

    # def setup(
    #     self,
    #     to_call: Optional[
    #         Union[
    #             Schema,
    #             ColumnSchema,
    #             EmbeddingTable,
    #             "CategoricalTarget",
    #             "EmbeddingTablePrediction",
    #         ]
    #     ] = None
    # ):
    #     if isinstance(to_call, (Schema, ColumnSchema)):
    #         self.setup_schema(to_call)
    #     elif isinstance(to_call, (EmbeddingTable)):
    #         self.prepend(EmbeddingTablePrediction(to_call))
    #     elif isinstance(to_call, (CategoricalTarget, EmbeddingTablePrediction)):
    #         self.prepend(to_call)
    #     else:
    #         raise ValueError(f"Invalid to_call type: {type(to_call)}")

    #     self.num_classes = self[0].num_classes

    #     if not self.metrics:
    #         self.metrics = self.default_metrics()
    @classmethod
    def with_weight_tying(
        cls,
        selection: schema.Selection,
        block: nn.Module,
        loss: nn.Module = nn.CrossEntropyLoss(),
        metrics: Optional[Sequence[Metric]] = None,
        logits_temperature: float = 1.0,
    ) -> "CategoricalOutput":
        self = cls(loss=loss, metrics=metrics, logits_temperature=logits_temperature)
        self = self.tie_weights(selection, block)
        self.output_schema = categorical_output_schema(self[0].col_schema, self.num_classes)
        if not self.metrics:
            self.metrics = self.default_metrics(self.num_classes)

        return self
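Under the hood, weight tying means the output projection reuses the input embedding matrix instead of allocating a second (num_classes, dim) parameter block, following Press & Wolf (arXiv:1611.01462), cited in the EmbeddingTablePrediction docstring below. A plain-PyTorch sketch of the computation that EmbeddingTablePrediction.forward performs:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, dim = 1000, 64
emb = nn.Embedding(vocab_size, dim)           # input-side embedding table
bias = nn.Parameter(torch.zeros(vocab_size))  # only the bias is new capacity

hidden = torch.randn(32, dim)                # model outputs for a batch of 32
logits = F.linear(hidden, emb.weight, bias)  # tied projection: (32, 1000)
print(logits.shape)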

    def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
        """Set up the schema for the output.
@@ -178,15 +168,27 @@ def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
            raise ValueError("Schema must contain exactly one column.")

        target = target.first

        to_call = CategoricalTarget(target)
        self.num_classes = to_call.num_classes

        self.prepend(to_call)

        self.output_schema = categorical_output_schema(target, self.num_classes)
        if not self.metrics:
            self.metrics = self.default_metrics(self.num_classes)

    def tie_weights(self, selection: schema.Selection, block: nn.Module) -> "CategoricalOutput":
        if isinstance(block, EmbeddingTable):
            table = block
        else:
            try:
                selected = schema.select(block, selection)
                table = selected.leaf()
            except Exception as e:
                raise ValueError("Could not find embedding table in block.") from e

        self[0] = EmbeddingTablePrediction(table, selection)
        self.num_classes = self[0].num_classes

        return self
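When block is not itself an EmbeddingTable, tie_weights relies on schema.select plus leaf() to locate the table inside the block. As a rough stand-in for that discovery step (the real selection is schema-driven, not type-driven), one can walk the module tree:

import torch.nn as nn

def find_embedding(block: nn.Module) -> nn.Embedding:
    """Illustrative stand-in: grab the first nn.Embedding inside a block."""
    for module in block.modules():
        if isinstance(module, nn.Embedding):
            return module
    raise ValueError("Could not find embedding table in block.")

block = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
print(find_embedding(block))  # Embedding(100, 16)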

    @classmethod
    def default_metrics(cls, num_classes: int) -> List[Metric]:
@@ -313,10 +315,19 @@ class EmbeddingTablePrediction(nn.Module):
    arXiv:1611.01462 (2016).
    """

    def __init__(self, table: EmbeddingTable, selection: Optional[schema.Selection] = None):
        super().__init__()
        self.table = table
        if len(table.domains) > 1:
            if not selection:
                raise ValueError(
                    f"Table {table} has multiple columns. "
                    "Must specify selection to choose column."
                )
            self.add_selection(selection)
        else:
            self.num_classes = table.num_embeddings
            self.col_schema = table.input_schema.first
            self.col_name = self.col_schema.name
        self.bias = nn.Parameter(
            torch.zeros(self.num_classes, dtype=torch.float32, device=self.embeddings().device)
        )
@@ -336,6 +347,16 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
return nn.functional.linear(inputs, self.embeddings(), self.bias)

def add_selection(self, selection: schema.Selection):
selected = schema.select(self.table.input_schema, selection)
if not len(selected) == 1:
raise ValueError("Schema must contain exactly one column. ", f"got: {selected}")
self.col_schema = selected.first
self.col_name = self.col_schema.name
self.num_classes = self.col_schema.int_domain.max + 1

return self

def embeddings(self) -> nn.Parameter:
"""Fetch the weight matrix from the embedding table.
@@ -344,6 +365,9 @@ def embeddings(self) -> nn.Parameter:
        nn.Parameter
            Weight matrix from the embedding table.
        """
        if len(self.table.domains) > 1:
            return self.table.feature_weights(self.col_name)

        return self.table.table.weight

    def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
@@ -359,9 +383,7 @@ def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
        torch.Tensor
            The corresponding embeddings.
        """
        # Route the lookup through the EmbeddingTable rather than the raw
        # nn.Embedding so that per-feature domain offsets can be applied when
        # the table holds multiple features.
        return self.table({self.col_name: inputs})[self.col_name]
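The lookup is keyed by column name because a shared table must shift raw ids by the feature's domain start before indexing the common weight matrix. A small sketch of that offset arithmetic, with illustrative names and bounds:

import torch
import torch.nn as nn

table = nn.Embedding(25, 8)                  # shared rows for two features
domain_min = {"user_id": 0, "item_id": 10}   # illustrative domain starts

def embedding_lookup(feature: str, ids: torch.Tensor) -> torch.Tensor:
    # Shift ids into the feature's row range before the shared lookup.
    return table(ids + domain_min[feature])

print(embedding_lookup("item_id", torch.tensor([0, 3])).shape)  # (2, 8)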


def categorical_output_schema(target: ColumnSchema, num_classes: int) -> Schema: