Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DLRM Model #1171

Merged
merged 4 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.ranking import DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import BinaryOutput
from merlin.models.torch.outputs.regression import RegressionOutput
Expand Down Expand Up @@ -55,4 +56,5 @@
"Stack",
"schema",
"DLRMBlock",
"DLRMModel",
]
5 changes: 5 additions & 0 deletions merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def compute_loss(
else:
raise ValueError(f"Unknown 'predictions' type: {type(predictions)}")

if _targets.size() != _predictions.size():
_targets = _targets.view(_predictions.size())
if _targets.type() != _predictions.type():
_targets = _targets.type_as(_predictions)

results["loss"] = results["loss"] + model_out.loss(_predictions, _targets) / len(
model_outputs
)
Expand Down
76 changes: 76 additions & 0 deletions merlin/models/torch/models/ranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional

from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.outputs.tabular import TabularOutputBlock
from merlin.schema import Schema


class DLRMModel(Model):
    """Deep Learning Recommendation Model (DLRM), after Naumov et al. [1].

    The constructor wraps a ``DLRMBlock`` built from the given arguments and
    hands it, followed by the output block, to the ``Model`` base class.

    Parameters
    ----------
    schema : Schema
        Schema used for feature selection. It is also used to build the
        default output block when ``output_block`` is not supplied.
    dim : int
        Dimensionality of the output vectors.
    bottom_block : Block
        Block that receives the continuous features. Its output
        dimensionality has to match ``dim``.
    top_block : Block, optional
        Optional upper-level block of the model.
    interaction : nn.Module, optional
        Interaction module for DLRM. When omitted, ``DLRMInteraction`` is
        used.
    output_block : Block, optional
        Output block of the model. When ``None`` (the default), a
        ``TabularOutputBlock(schema, init="defaults")`` is created.

    Examples
    --------
    >>> model = mm.DLRMModel(
    ...    schema,
    ...    dim=64,
    ...    bottom_block=mm.MLPBlock([256, 64]),
    ...    output_block=BinaryOutput(ColumnSchema("target")))
    >>> trainer = pl.Trainer()
    >>> model.initialize(dataloader)
    >>> trainer.fit(model, dataloader)

    References
    ----------
    [1] Naumov, Maxim, et al. "Deep learning recommendation model for
        personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
    """

    def __init__(
        self,
        schema: Schema,
        dim: int,
        bottom_block: Block,
        top_block: Optional[Block] = None,
        interaction: Optional[nn.Module] = None,
        output_block: Optional[Block] = None,
    ) -> None:
        # Resolve the output head first (before the body is built) so the
        # order in which sub-modules are constructed stays the same.
        head = (
            output_block
            if output_block is not None
            else TabularOutputBlock(schema, init="defaults")
        )

        body_kwargs = {"top_block": top_block, "interaction": interaction}
        body = DLRMBlock(schema, dim, bottom_block, **body_kwargs)

        super().__init__(body, head)
2 changes: 1 addition & 1 deletion merlin/models/torch/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def initialize(module, data: Union[Dataset, Loader, Batch], dtype=torch.float32)
if hasattr(module, "model_outputs"):
for model_out in module.model_outputs():
for metric in model_out.metrics:
metric.to(batch.device())
metric.to(device=batch.device())

from merlin.models.torch import schema

Expand Down
36 changes: 22 additions & 14 deletions tests/unit/torch/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
#
import pandas as pd
import pytest
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import AUROC, Accuracy, Precision, Recall

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch.batch import Batch
from merlin.models.torch.batch import Batch, sample_batch
from merlin.models.torch.models.base import compute_loss
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema
Expand Down Expand Up @@ -200,22 +201,29 @@ def test_no_output_schema(self):
with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"):
mm.schema.output(model)

# def test_train_classification(self, music_streaming_data):
# schema = music_streaming_data.schema.without(["user_genres", "like", "item_genres"])
# music_streaming_data.schema = schema
def test_train_classification_with_lightning_trainer(
    self, music_streaming_data, batch_size=16
):
    """Fit a small binary-classification model for one epoch with Lightning.

    The dataset schema is narrowed to five columns, a three-stage model
    (tabular inputs, MLP, binary output on ``click``) is initialized from the
    loader and trained, and the logged ``train_loss`` and the number of
    training batches are checked. Finally the model and a sampled batch are
    passed to ``module_utils.module_test``.
    """
    columns = ["item_id", "user_id", "user_age", "item_genres", "click"]
    schema = music_streaming_data.schema.select_by_name(columns)
    music_streaming_data.schema = schema

    # Build the stages in the same order they are handed to the model.
    inputs = mm.TabularInputBlock(schema, init="defaults")
    hidden = mm.MLPBlock([4, 2])
    click_output = mm.BinaryOutput(schema.select_by_name("click").first)
    model = mm.Model(inputs, hidden, click_output)

    trainer = pl.Trainer(max_epochs=1, devices=1)

    with Loader(music_streaming_data, batch_size=batch_size) as train_loader:
        model.initialize(train_loader)
        trainer.fit(model, train_loader)

    assert trainer.logged_metrics["train_loss"] > 0.0
    # 100 rows at 16 rows per batch: six full batches plus one for the remainder.
    assert trainer.num_training_batches == 7

    sampled = sample_batch(music_streaming_data, batch_size)
    module_utils.module_test(model, sampled)


class TestComputeLoss:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/torch/models/test_ranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema


@pytest.mark.parametrize("output_block", [None, mm.BinaryOutput(ColumnSchema("click"))])
class TestDLRMModel:
    """Training smoke tests for ``mm.DLRMModel``.

    Every test runs twice: once with ``output_block=None`` and once with an
    explicit ``BinaryOutput`` on the ``click`` column.
    """

    def test_train_dlrm_with_lightning_loader(
        self, music_streaming_data, output_block, dim=2, batch_size=16
    ):
        """Train a DLRM model for one epoch and check that a loss was logged.

        After fitting, the model and a sampled batch are passed to
        ``module_utils.module_test``.
        """
        columns = ["item_id", "user_id", "user_age", "item_genres", "click"]
        schema = music_streaming_data.schema.select_by_name(columns)
        music_streaming_data.schema = schema

        # Bottom block is created before the top block, as in the call order.
        bottom = mm.MLPBlock([4, 2])
        top = mm.MLPBlock([4, 2])
        model = mm.DLRMModel(
            schema,
            dim=dim,
            bottom_block=bottom,
            top_block=top,
            output_block=output_block,
        )

        trainer = pl.Trainer(max_epochs=1, devices=1)

        with Loader(music_streaming_data, batch_size=batch_size) as loader:
            model.initialize(loader)
            trainer.fit(model, loader)

        assert trainer.logged_metrics["train_loss"] > 0.0

        sampled = sample_batch(music_streaming_data, batch_size)
        module_utils.module_test(model, sampled)