diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index fba57be948..bcbc77c5d5 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -16,8 +16,17 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch, Sequence -from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ResidualBlock, + ShortcutBlock, + repeat, + repeat_parallel, + repeat_parallel_like, +) from merlin.models.torch.blocks.dlrm import DLRMBlock +from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock from merlin.models.torch.blocks.mlp import MLPBlock from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys @@ -67,6 +76,9 @@ "Concat", "Stack", "schema", + "repeat", + "repeat_parallel", + "repeat_parallel_like", "CategoricalOutput", "CategoricalTarget", "EmbeddingTablePrediction", @@ -77,4 +89,7 @@ "DLRMBlock", "DLRMModel", "DCNModel", + "MMOEBlock", + "PLEBlock", + "CGCBlock", ] diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index cf5bea6f29..9f81e6da99 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -17,7 +17,7 @@ import inspect import textwrap from copy import deepcopy -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable import torch from torch import nn @@ -31,6 +31,12 @@ from merlin.schema import Schema +@runtime_checkable +class HasKeys(Protocol): + def keys(self): + ... + + class Block(BlockContainer, RegistryMixin, TraversableMixin): """A base-class that calls it's modules sequentially. @@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block": Block The new block created by repeating the current block `n` times. """ - if not isinstance(n, int): - raise TypeError("n must be an integer") + return repeat(self, n, name=name) - if n < 1: - raise ValueError("n must be greater than 0") + def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock": + return repeat_parallel(self, n, name=name) - repeats = [self.copy() for _ in range(n - 1)] - - return Block(self, *repeats, name=name) + def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock": + return repeat_parallel_like(self, like, name=name) def copy(self) -> "Block": """ @@ -342,6 +346,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock": return output + def keys(self): + return self.branches.keys() + def leaf(self) -> nn.Module: if self.pre: raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage") @@ -567,6 +574,70 @@ def forward( return to_return +def _validate_n(n: int) -> None: + if not isinstance(n, int): + raise TypeError("n must be an integer") + + if n < 1: + raise ValueError("n must be greater than 0") + + +def repeat(module: nn.Module, n: int = 1, name=None) -> Block: + """ + Creates a new block by repeating the current block `n` times. + Each repetition is a deep copy of the current block. + + Parameters + ---------- + module: nn.Module + The module to be repeated. + n : int + The number of times to repeat the current block. + name : Optional[str], default = None + The name for the new block. If None, no name is assigned. + + Returns + ------- + Block + The new block created by repeating the current block `n` times. + """ + _validate_n(n) + + repeats = [module.copy() if hasattr(module, "copy") else deepcopy(module) for _ in range(n - 1)] + + return Block(module, *repeats, name=name) + + +def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock: + _validate_n(n) + + branches = {"0": module} + branches.update( + {str(n): module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n)} + ) + + output = ParallelBlock(branches) + if agg: + output.append(Block.parse(agg)) + + return output + + +def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock: + branches = {} + for i, key in enumerate(like.keys()): + if i == 0: + branches[str(key)] = module + else: + branches[str(key)] = module.copy() if hasattr(module, "copy") else deepcopy(module) + + output = ParallelBlock(branches) + if agg: + output.append(Block.parse(agg)) + + return output + + def get_pre(module: nn.Module) -> BlockContainer: if hasattr(module, "pre"): return module.pre diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py new file mode 100644 index 0000000000..627b96a1fd --- /dev/null +++ b/merlin/models/torch/blocks/experts.py @@ -0,0 +1,289 @@ +import textwrap +from functools import partial +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from merlin.models.torch.batch import Batch +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ShortcutBlock, + repeat_parallel, + repeat_parallel_like, +) +from merlin.models.torch.transforms.agg import Concat, Stack +from merlin.models.utils.doc_utils import docstring_parameter + +_PLE_REFERENCE = """ + References + ---------- + .. [1] Tang, Hongyan, et al. "Progressive layered extraction (ple): A novel multi-task + learning (mtl) model for personalized recommendations." + Fourteenth ACM Conference on Recommender Systems. 2020. +""" + + +class MMOEBlock(Block): + """ + Multi-gate Mixture-of-Experts (MMoE) Block introduced in [1]. + + The MMoE model builds upon the concept of an expert model by using a mixture + of experts for decision making. Each expert contributes independently to the + final decision, allowing for increased model complexity and performance + in multi-task learning scenarios. + + Example usage for multi-task learning:: + >>> outputs = mm.TabularOutputBlock(schema, init="defaults") + >>> mmoe = mm.MMOEBlock( + expert=mm.MLPBlock([5]), + num_experts=2, + outputs=outputs, + ) + >>> outputs.prepend_for_each(mm.MLPBlock([64])) # Add task-towers + >>> outputs.prepend(mmoe) + + + References + ---------- + [1] Ma, Jiaqi, et al. "Modeling task relationships in multi-task learning with + multi-gate mixture-of-experts." Proceedings of the 24th ACM SIGKDD international + conference on knowledge discovery & data mining. 2018. + + Parameters + ---------- + expert : nn.Module + The base expert model that serves as the foundation for the MMoE structure. + num_experts : int + The total number of experts in the MMoE model. Each expert operates independently + in the decision-making process. + outputs : Optional[ParallelBlock] + The output block of the model. + If it is an instance of ParallelBlock, the block is repeated for each expert. + Otherwise, a single ExpertGateBlock is used. + """ + + def __init__( + self, expert: nn.Module, num_experts: int, outputs: Optional[ParallelBlock] = None + ): + experts = repeat_parallel(expert, num_experts, agg=Stack(dim=1)) + super().__init__(ShortcutBlock(experts, output_name="experts")) + if isinstance(outputs, ParallelBlock): + self.append(repeat_parallel_like(ExpertGateBlock(num_experts), outputs)) + else: + self.append(ExpertGateBlock(num_experts)) + + +@docstring_parameter(ple_reference=_PLE_REFERENCE) +class PLEBlock(Block): + """ + Progressive Layered Extraction (PLE) Block proposed in [1]. + + The PLE model enhances the architecture of a typical expert model by organizing + shared and task-specific experts in a layered format. This layered structure + allows the extraction of increasingly complex features at each level and can + improve performance in multi-task settings. + + Example usage for multi-task learning:: + >>> outputs = mm.TabularOutputBlock(schema, init="defaults") + >>> ple = mm.PLEBlock( + expert=mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + depth=2, + outputs=outputs, + ) + >>> outputs.prepend_for_each(mm.MLPBlock([64])) # Add task-towers + >>> outputs.prepend(ple) + + {ple_reference} + + Parameters + ---------- + expert : nn.Module + The base expert model that forms the basis of the PLE structure. + num_shared_experts : int + The total count of shared experts. These experts contribute to the + decision process in all tasks. + num_task_experts : int + The total count of task-specific experts. These experts contribute + only to their specific tasks. + depth : int + The depth of the layered structure. Each layer comprises a set of experts + and the depth determines the number of such layers. + outputs : ParallelBlock + The output block, which encapsulates the final output from the model. + """ + + def __init__( + self, + expert: nn.Module, + *, + num_shared_experts: int, + num_task_experts: int, + depth: int, + outputs: ParallelBlock, + ): + cgc_kwargs = { + "expert": expert, + "num_shared_experts": num_shared_experts, + "num_task_experts": num_task_experts, + "outputs": outputs, + } + super().__init__(*CGCBlock(shared_gate=True, **cgc_kwargs).repeat(depth - 1)) + self.append(CGCBlock(**cgc_kwargs)) + + +class CGCBlock(Block): + """ + Implements the Customized Gate Control (CGC) proposed in [1]. + + The CGC model extends the capability of a typical expert model by introducing + shared and task-specific experts, thereby customizing the gating control per task, + which may lead to improved performance in multi-task settings. + + {ple_reference} + + Parameters + ---------- + expert : nn.Module + The base expert model that is used as the foundation for the gating mechanism. + num_shared_experts : int + The total count of shared experts. These experts contribute to the decision + process in all tasks. + num_task_experts : int + The total count of task-specific experts. These experts contribute only + to their specific tasks. + outputs : ParallelBlock + The output block, which encapsulates the final output from the model. + shared_gate : bool, optional + Defines whether a shared gate is used across all tasks or not. + If set to True, a shared gate is used. Defaults to False. + """ + + def __init__( + self, + expert: nn.Module, + *, + num_shared_experts: int, + num_task_experts: int, + outputs: ParallelBlock, + shared_gate: bool = False, + ): + shared_experts = repeat_parallel(expert, num_shared_experts, agg=Stack(dim=1)) + expert_shortcut = partial(ShortcutBlock, output_name="experts") + super().__init__(expert_shortcut(shared_experts)) + + gates = ParallelBlock() + for name in outputs.branches: + gates.branches[name] = PLEExpertGateBlock( + num_shared_experts + num_task_experts, + task_experts=repeat_parallel(expert, num_task_experts, agg=Stack(dim=1)), + name=name, + ) + if shared_gate: + gates.branches["experts"] = expert_shortcut( + ExpertGateBlock(num_shared_experts), propagate_shortcut=True + ) + + self.append(gates) + + +class ExpertGateBlock(Block): + """Expert Gate Block. + + # TODO: Add initialize_from_schema to remove the need to pass in num_experts + + Parameters + ---------- + num_experts : int + The number of experts used. + """ + + def __init__(self, num_experts: int): + super().__init__(GateBlock(num_experts)) + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: + if torch.jit.isinstance(inputs, torch.Tensor): + raise RuntimeError("ExpertGateBlock requires a dictionary input") + + experts = inputs["experts"] + outputs = inputs["shortcut"] + for module in self.values: + outputs = module(outputs, batch=batch) + + # return torch.sum(experts * outputs, dim=1, keepdim=False) + gated = outputs.expand_as(experts) + + # Multiply and sum along the experts dimension + return (experts * gated).sum(dim=1) + + +class PLEExpertGateBlock(Block): + """ + Progressive Layered Extraction (PLE) Expert Gate Block. + + Parameters + ---------- + num_experts : int + The number of experts used. + task_experts : nn.Module + The expert module. + name : str + The name of the task. + """ + + def __init__(self, num_experts: int, task_experts: nn.Module, name: str): + super().__init__(ExpertGateBlock(num_experts), name=f"PLEExpertGateBlock[{name}]") + self.stack = Stack(dim=1) + self.concat = Concat(dim=1) + self.task_experts = task_experts + self.task_name = name + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: + if torch.jit.isinstance(inputs, torch.Tensor): + raise RuntimeError("ExpertGateBlock requires a dictionary input") + + task_experts = self.task_experts(inputs["shortcut"], batch=batch) + if torch.jit.isinstance(task_experts, torch.Tensor): + _task = task_experts + elif torch.jit.isinstance(task_experts, Dict[str, torch.Tensor]): + _task = self.stack(task_experts) + else: + raise RuntimeError("PLEExpertGateBlock requires a dictionary input") + experts = self.concat({"experts": inputs["experts"], "task_experts": _task}) + task = inputs[self.task_name] if self.task_name in inputs else inputs["shortcut"] + + outputs = {"experts": experts, "shortcut": task} + for block in self.values: + outputs = block(outputs, batch=batch) + + return outputs + + def __repr__(self) -> str: + indent_str = " " + output = textwrap.indent("\n(task_experts): " + repr(self.task_experts), indent_str) + output += textwrap.indent("\n(gate): " + repr(self.values[0]), indent_str) + + return f"{self._get_name()}({output}\n)" + + +class SoftmaxGate(nn.Module): + """Softmax Gate for gating mechanism.""" + + def forward(self, gate_logits): + return torch.softmax(gate_logits, dim=-1).unsqueeze(-1) + + +class GateBlock(Block): + """Gate Block for gating mechanism.""" + + def __init__(self, num_experts: int): + super().__init__() + self.append(nn.LazyLinear(num_experts)) + self.append(SoftmaxGate()) diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py index e289c694fa..2ee258dcdb 100644 --- a/merlin/models/torch/container.py +++ b/merlin/models/torch/container.py @@ -16,7 +16,7 @@ from copy import deepcopy from functools import reduce -from typing import Dict, Iterator, Optional, Union +from typing import Dict, Iterator, Optional, Sequence, Union from torch import nn from torch._jit_internal import _copy_to_script_wrapper @@ -62,6 +62,23 @@ def append(self, module: nn.Module): return self + def extend(self, sequence: Sequence[nn.Module]): + """Extends the list by appending elements from the iterable. + + Parameters + ---------- + module : nn.Module + The PyTorch module to be appended. + + Returns + ------- + self + """ + for m in sequence: + self.append(m) + + return self + def prepend(self, module: nn.Module): """Prepends a given module to the beginning of the list. diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py new file mode 100644 index 0000000000..5f32bf3485 --- /dev/null +++ b/tests/unit/torch/blocks/test_experts.py @@ -0,0 +1,123 @@ +import pytest +import torch + +import merlin.models.torch as mm +from merlin.models.torch.blocks.experts import ( + CGCBlock, + ExpertGateBlock, + MMOEBlock, + PLEBlock, + PLEExpertGateBlock, +) +from merlin.models.torch.utils import module_utils + +dict_inputs = {"experts": torch.rand((10, 4, 5)), "shortcut": torch.rand((10, 8))} + + +class TestExpertGateBlock: + @pytest.fixture + def expert_gate(self): + return ExpertGateBlock(num_experts=4) + + def test_requires_dict_input(self, expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + expert_gate(torch.rand((10, 5))) + + def test_forward_pass(self, expert_gate): + result = module_utils.module_test(expert_gate, dict_inputs) + assert result.shape == (10, 5) + + +class TestMMOEBlock: + def test_init(self): + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) + + assert isinstance(mmoe, MMOEBlock) + assert isinstance(mmoe[0], mm.ShortcutBlock) + assert len(mmoe[0][0].branches) == 2 + for i in range(2): + assert mmoe[0][0][str(i)][1].out_features == 2 + assert mmoe[0][0][str(i)][3].out_features == 2 + assert isinstance(mmoe[0][0].post[0], mm.Stack) + assert isinstance(mmoe[1], ExpertGateBlock) + assert mmoe[1][0][0].out_features == 2 + + def test_init_with_outputs(self): + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) + + assert isinstance(outputs.pre[0], MMOEBlock) + assert list(outputs.pre[0][1].keys()) == ["a", "b"] + + def test_forward(self): + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) + + outputs = module_utils.module_test(mmoe, torch.rand(5, 5)) + assert outputs.shape == (5, 2) + + def test_forward_with_outputs(self): + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([2, 2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) + + outputs = module_utils.module_test(outputs, torch.rand(5, 5)) + assert outputs["a"].shape == (5, 1) + assert outputs["b"].shape == (5, 1) + + +class TestPLEExpertGateBlock: + @pytest.fixture + def ple_expert_gate(self): + return PLEExpertGateBlock( + num_experts=6, task_experts=mm.repeat_parallel(mm.MLPBlock([5, 5]), 2), name="a" + ) + + def test_repr(self, ple_expert_gate): + assert "(task_experts)" in str(ple_expert_gate) + assert "(gate)" in str(ple_expert_gate) + + def test_requires_dict_input(self, ple_expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + ple_expert_gate(torch.rand((10, 5))) + + def test_ple_forward(self, ple_expert_gate): + result = module_utils.module_test(ple_expert_gate, dict_inputs) + assert result.shape == (10, 5) + + +class TestCGCBlock: + @pytest.mark.parametrize("shared_gate", [True, False]) + def test_forward(self, music_streaming_data, shared_gate): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + cgc = CGCBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + outputs=output_block, + shared_gate=shared_gate, + ) + + outputs = module_utils.module_test(cgc, torch.rand(5, 5)) + assert len(outputs) == len(output_block) + (2 if shared_gate else 0) + + +class TestPLEBlock: + def test_forward(self, music_streaming_data): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + ple = PLEBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + depth=2, + outputs=output_block, + ) + + assert isinstance(ple[0], CGCBlock) + assert len(ple[0][1]) == len(output_block) + 1 + assert isinstance(ple[0][1]["experts"][0], ExpertGateBlock) + assert isinstance(ple[1], CGCBlock) + assert list(ple[1][1].branches.keys()) == list(ple[1][1].branches.keys()) + + outputs = module_utils.module_test(ple, torch.rand(5, 5)) + assert len(outputs) == len(output_block)