Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding CrossBlock (used in DCN-v2) #1172

Merged
merged 4 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions merlin/models/torch/blocks/cross.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin

from merlin.models.torch.block import Block
from merlin.models.torch.transforms.agg import Concat
from merlin.models.utils.doc_utils import docstring_parameter

_DCNV2_REF = """
References
----------
.. [1]. Wang, Ruoxi, et al. "DCN V2: Improved deep & cross network and
practical lessons for web-scale learning to rank systems." Proceedings
of the Web Conference 2021. 2021. https://arxiv.org/pdf/2008.13535.pdf

"""


class LazyMirrorLinear(LazyModuleMixin, nn.Linear):
"""A :class:`torch.nn.Linear` module where both
`in_features` & `out_features` are inferred. (i.e. `out_features` = `in_features`)

Parameters
----------
bias:
If set to ``False``, the layer will not learn an additive bias.
Default: ``True``

"""

cls_to_become = nn.Linear # type: ignore[assignment]
weight: nn.parameter.UninitializedParameter
bias: nn.parameter.UninitializedParameter # type: ignore[assignment]

def __init__(self, bias: bool = True, device=None, dtype=None) -> None:
# This code is taken from torch.nn.LazyLinear.__init__
factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
super().__init__(0, 0, False)
self.weight = nn.parameter.UninitializedParameter(**factory_kwargs)
if bias:
self.bias = nn.parameter.UninitializedParameter(**factory_kwargs)

def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()

def initialize_parameters(self, input) -> None: # type: ignore[override]
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]
self.out_features = self.in_features
self.weight.materialize((self.out_features, self.in_features))
if self.bias is not None:
self.bias.materialize((self.out_features,))
self.reset_parameters()


@docstring_parameter(dcn_reference=_DCNV2_REF)
class CrossBlock(Block):
"""
This block provides a way to create high-order feature interactions
by a number of stacked Cross Layers, from DCN V2: Improved Deep & Cross Network [1].
See Eq. (1) for full-rank and Eq. (2) for low-rank version.

Parameters
----------
depth : int, optional
Number of cross-layers to be stacked, by default 1

{dcn_reference}
"""

def __init__(self, *module, name: Optional[str] = None):
super().__init__(*module, name=name)
self.concat = Concat()

@classmethod
def with_depth(cls, depth: int):
if not depth > 0:
raise ValueError(f"`depth` must be greater than 0, got {depth}")

return cls(*Block(LazyMirrorLinear()).repeat(depth))

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
x = self.concat(inputs)
else:
x = inputs

x0 = x
current = x
for module in self.values:
module_out = module(current)
if not isinstance(module_out, torch.Tensor):
raise RuntimeError("CrossBlock expects a Tensor as output")

current = x0 * module_out + current

return current
86 changes: 86 additions & 0 deletions tests/unit/torch/blocks/test_cross.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Tuple

import pytest
import torch
from torch import nn

from merlin.models.torch.blocks.cross import CrossBlock, LazyMirrorLinear
from merlin.models.torch.utils import module_utils


class TestLazyMirrorLinear:
def test_init(self):
module = LazyMirrorLinear(bias=True)
assert isinstance(module.weight, nn.parameter.UninitializedParameter)
assert isinstance(module.bias, nn.parameter.UninitializedParameter)

def test_no_bias_init(self):
module = LazyMirrorLinear(bias=False)
assert isinstance(module.weight, nn.parameter.UninitializedParameter)
assert module.bias is None

def test_reset_parameters(self):
module = LazyMirrorLinear(bias=True)
input = torch.randn(10, 20)
module.initialize_parameters(input)
assert module.in_features == 20
assert module.out_features == 20
assert module.weight.shape == (20, 20)
assert module.bias.shape == (20,)

def test_forward(self):
module = LazyMirrorLinear(bias=True)
input = torch.randn(10, 20)
output = module_utils.module_test(module, input)
assert output.shape == (10, 20)

def test_no_bias_forward(self):
module = LazyMirrorLinear(bias=False)
input = torch.randn(10, 20)
output = module_utils.module_test(module, input)
assert output.shape == (10, 20)


class TestCrossBlock:
def test_with_depth(self):
crossblock = CrossBlock.with_depth(depth=1)
assert len(crossblock) == 1
assert isinstance(crossblock[0][0], LazyMirrorLinear)

def test_with_multiple_depth(self):
crossblock = CrossBlock.with_depth(depth=3)
assert len(crossblock) == 3
for module in crossblock:
assert isinstance(module[0], LazyMirrorLinear)

def test_crossblock_invalid_depth(self):
with pytest.raises(ValueError):
CrossBlock.with_depth(depth=0)

def test_forward_tensor(self):
crossblock = CrossBlock.with_depth(depth=1)
input = torch.randn(5, 10)
output = module_utils.module_test(crossblock, input)
assert output.shape == (5, 10)

def test_forward_dict(self):
crossblock = CrossBlock.with_depth(depth=1)
inputs = {"a": torch.randn(5, 10), "b": torch.randn(5, 10)}
output = module_utils.module_test(crossblock, inputs)
assert output.shape == (5, 20)

def test_forward_multiple_depth(self):
crossblock = CrossBlock.with_depth(depth=3)
input = torch.randn(5, 10)
output = module_utils.module_test(crossblock, input)
assert output.shape == (5, 10)

def test_exception(self):
class ToTuple(nn.Module):
def forward(self, input) -> Tuple[torch.Tensor, torch.Tensor]:
return input, input

crossblock = CrossBlock(ToTuple())

with pytest.raises(RuntimeError):
module_utils.module_test(crossblock, torch.randn(5, 10))