Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding MMOE & PLE #1173

Merged
merged 13 commits into from
Jul 11, 2023
17 changes: 16 additions & 1 deletion merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
from merlin.models.torch.block import (
Block,
ParallelBlock,
ResidualBlock,
ShortcutBlock,
repeat,
repeat_parallel,
repeat_parallel_like,
)
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
Expand Down Expand Up @@ -67,6 +76,9 @@
"Concat",
"Stack",
"schema",
"repeat",
"repeat_parallel",
"repeat_parallel_like",
"CategoricalOutput",
"CategoricalTarget",
"EmbeddingTablePrediction",
Expand All @@ -77,4 +89,7 @@
"DLRMBlock",
"DLRMModel",
"DCNModel",
"MMOEBlock",
"PLEBlock",
"CGCBlock",
]
87 changes: 79 additions & 8 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import textwrap
from copy import deepcopy
from typing import Dict, Optional, Tuple, TypeVar, Union
from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

import torch
from torch import nn
Expand All @@ -31,6 +31,12 @@
from merlin.schema import Schema


@runtime_checkable
class HasKeys(Protocol):
def keys(self):
...


class Block(BlockContainer, RegistryMixin, TraversableMixin):
"""A base-class that calls it's modules sequentially.

Expand Down Expand Up @@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block":
Block
The new block created by repeating the current block `n` times.
"""
if not isinstance(n, int):
raise TypeError("n must be an integer")
return repeat(self, n, name=name)

if n < 1:
raise ValueError("n must be greater than 0")
def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock":
return repeat_parallel(self, n, name=name)

repeats = [self.copy() for _ in range(n - 1)]

return Block(self, *repeats, name=name)
def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock":
return repeat_parallel_like(self, like, name=name)

def copy(self) -> "Block":
"""
Expand Down Expand Up @@ -342,6 +346,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock":

return output

def keys(self):
return self.branches.keys()

def leaf(self) -> nn.Module:
if self.pre:
raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage")
Expand Down Expand Up @@ -567,6 +574,70 @@ def forward(
return to_return


def _validate_n(n: int) -> None:
if not isinstance(n, int):
raise TypeError("n must be an integer")

if n < 1:
raise ValueError("n must be greater than 0")


def repeat(module: nn.Module, n: int = 1, name=None) -> Block:
"""
Creates a new block by repeating the current block `n` times.
Each repetition is a deep copy of the current block.

Parameters
----------
module: nn.Module
The module to be repeated.
n : int
The number of times to repeat the current block.
name : Optional[str], default = None
The name for the new block. If None, no name is assigned.

Returns
-------
Block
The new block created by repeating the current block `n` times.
"""
_validate_n(n)

repeats = [module.copy() if hasattr(module, "copy") else deepcopy(module) for _ in range(n - 1)]

return Block(module, *repeats, name=name)


def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock:
_validate_n(n)

branches = {"0": module}
branches.update(
{str(n): module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n)}
)

output = ParallelBlock(branches)
if agg:
output.append(Block.parse(agg))

return output


def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock:
branches = {}
for i, key in enumerate(like.keys()):
if i == 0:
branches[str(key)] = module
else:
branches[str(key)] = module.copy() if hasattr(module, "copy") else deepcopy(module)

output = ParallelBlock(branches)
if agg:
output.append(Block.parse(agg))

return output


def get_pre(module: nn.Module) -> BlockContainer:
if hasattr(module, "pre"):
return module.pre
Expand Down
Loading