Commit

Merge branch 'main' into torch/contrastive-output
marcromeyn authored Jul 8, 2023
2 parents 51da47a + c5afbd1 commit d121a2a
Showing 4 changed files with 119 additions and 9 deletions.
3 changes: 2 additions & 1 deletion merlin/models/torch/__init__.py
@@ -23,7 +23,7 @@
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.ranking import DLRMModel
from merlin.models.torch.models.ranking import DCNModel, DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import (
BinaryOutput,
@@ -75,4 +75,5 @@
"target_schema",
"DLRMBlock",
"DLRMModel",
"DCNModel",
]
5 changes: 4 additions & 1 deletion merlin/models/torch/blocks/cross.py
@@ -5,6 +5,7 @@
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin

from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block
from merlin.models.torch.transforms.agg import Concat
from merlin.models.utils.doc_utils import docstring_parameter
@@ -127,7 +128,9 @@ def with_low_rank(cls, depth: int, low_rank: nn.Module) -> "CrossBlock":

return cls(*(Block(deepcopy(low_rank), *block) for block in cls.with_depth(depth)))

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> torch.Tensor:
"""Forward-pass of the cross-block.
Parameters
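The cross layers that CrossBlock stacks follow the DCN-v2 recurrence x_{l+1} = x_0 * (W_l x_l + b_l) + x_l, where x_0 is the concatenated input. The snippet below is a minimal plain-PyTorch sketch of that recurrence for intuition only; it is not the library's CrossBlock implementation, and the class and tensor names are illustrative.

import torch
from torch import nn

class CrossLayerSketch(nn.Module):
    """Illustrative DCN-v2-style cross layer: x_next = x0 * (W @ x + b) + x."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Element-wise product of the original input with a linear map of the
        # current state, plus a residual connection back to the current state.
        return x0 * self.linear(x) + x

x0 = torch.randn(8, 16)  # (batch, concatenated feature dim)
x = x0
for layer in [CrossLayerSketch(16) for _ in range(2)]:  # depth = 2
    x = layer(x0, x)
print(x.shape)  # torch.Size([8, 16])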
89 changes: 82 additions & 7 deletions merlin/models/torch/models/ranking.py
@@ -2,13 +2,19 @@

from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.blocks.cross import _DCNV2_REF, CrossBlock
from merlin.models.torch.blocks.dlrm import _DLRM_REF, DLRMBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.outputs.tabular import TabularOutputBlock
from merlin.models.torch.transforms.agg import Concat, MaybeAgg
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMModel(Model):
"""
The Deep Learning Recommendation Model (DLRM) as proposed in Naumov, et al. [1]
@@ -42,15 +48,13 @@ class DLRMModel(Model):
... schema,
... dim=64,
... bottom_block=mm.MLPBlock([256, 64]),
... output_block=BinaryOutput(ColumnSchema("target")))
... output_block=mm.BinaryOutput(ColumnSchema("target")),
... )
>>> trainer = pl.Trainer()
>>> model.initialize(dataloader)
>>> trainer.fit(model, dataloader)
References
----------
[1] Naumov, Maxim, et al. "Deep learning recommendation model for
personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
{dlrm_reference}
"""

def __init__(
@@ -74,3 +78,74 @@ def __init__(
)

super().__init__(dlrm_body, output_block)


@docstring_parameter(dcn_reference=_DCNV2_REF)
class DCNModel(Model):
"""
The Deep & Cross Network (DCN) architecture as proposed in Wang, et al. [1]
Parameters
----------
schema : Schema
The schema to use for selection.
depth : int, optional
Number of cross-layers to be stacked, by default 1
deep_block : Block, optional
The `Block` to use as the deep part of the model (typically a `MLPBlock`)
stacked : bool
Whether to use the stacked version of the model or the parallel version.
input_block : Block, optional
The `Block` to use as the input layer. If None, a default `TabularInputBlock` object
is instantiated, which creates the embedding tables for the categorical features
based on the schema. The embedding dimensions are inferred from the features'
cardinalities. For a custom representation of the input data, you can instantiate
and provide a `TabularInputBlock` instance.
Returns
-------
Model
An instance of Model class representing the fully formed DCN.
Example usage
-------------
>>> model = mm.DCNModel(
... schema,
... depth=2,
... deep_block=mm.MLPBlock([256, 64]),
... output_block=mm.BinaryOutput(ColumnSchema("target")),
... )
>>> trainer = pl.Trainer()
>>> model.initialize(dataloader)
>>> trainer.fit(model, dataloader)
{dcn_reference}
"""

def __init__(
self,
schema: Schema,
depth: int = 1,
deep_block: Optional[Block] = None,
stacked: bool = True,
input_block: Optional[Block] = None,
output_block: Optional[Block] = None,
) -> None:
if input_block is None:
input_block = TabularInputBlock(schema, init="defaults")

if output_block is None:
output_block = TabularOutputBlock(schema, init="defaults")

if deep_block is None:
deep_block = MLPBlock([512, 256])

if stacked:
cross_network = Block(CrossBlock.with_depth(depth), deep_block)
else:
cross_network = Block(
ParallelBlock({"cross": CrossBlock.with_depth(depth), "deep": deep_block}),
MaybeAgg(Concat()),
)

super().__init__(input_block, *cross_network, output_block)
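In the stacked configuration the output of the cross network feeds into the deep block; in the parallel configuration the cross network and the deep block run side by side on the input and their outputs are concatenated. A minimal usage sketch, mirroring the tests below and assuming Loader from merlin.dataloader.torch and a merlin Dataset named music_streaming_data:

import pytorch_lightning as pl
import merlin.models.torch as mm
from merlin.dataloader.torch import Loader  # assumed import path, as used in the tests below

schema = music_streaming_data.schema  # any merlin Dataset with a schema

# Stacked: input -> cross network -> deep block -> output
stacked_model = mm.DCNModel(schema, depth=2, deep_block=mm.MLPBlock([128, 64]))

# Parallel: input -> {cross network, deep block} -> concat -> output
parallel_model = mm.DCNModel(schema, depth=2, deep_block=mm.MLPBlock([128, 64]), stacked=False)

trainer = pl.Trainer(max_epochs=1)
with Loader(music_streaming_data, batch_size=16) as train_loader:
    stacked_model.initialize(train_loader)
    trainer.fit(stacked_model, train_loader)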
31 changes: 31 additions & 0 deletions tests/unit/torch/models/test_ranking.py
@@ -36,3 +36,34 @@ def test_train_dlrm_with_lightning_loader(

batch = sample_batch(music_streaming_data, batch_size)
_ = module_utils.module_test(model, batch)


class TestDCNModel:
@pytest.mark.parametrize("depth", [1, 2])
@pytest.mark.parametrize("stacked", [True, False])
@pytest.mark.parametrize("deep_block", [None, mm.MLPBlock([4, 2])])
def test_train_dcn_with_lightning_trainer(
self,
music_streaming_data,
depth,
stacked,
deep_block,
batch_size=16,
):
schema = music_streaming_data.schema.select_by_name(
["item_id", "user_id", "user_age", "item_genres", "click"]
)
music_streaming_data.schema = schema

model = mm.DCNModel(schema, depth=depth, deep_block=deep_block, stacked=stacked)

trainer = pl.Trainer(max_epochs=1, devices=1)

with Loader(music_streaming_data, batch_size=batch_size) as train_loader:
model.initialize(train_loader)
trainer.fit(model, train_loader)

assert trainer.logged_metrics["train_loss"] > 0.0

batch = sample_batch(music_streaming_data, batch_size)
_ = module_utils.module_test(model, batch)
