From e8bc44f5f8d655da0278572a6705bda660cbed56 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 16:58:57 +0200 Subject: [PATCH 1/7] First pass over MMOEBlock & PLEBlock --- merlin/models/torch/block.py | 64 ++++++++++++++- merlin/models/torch/blocks/experts.py | 113 ++++++++++++++++++++++++++ merlin/models/torch/container.py | 19 ++++- 3 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 merlin/models/torch/blocks/experts.py diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 42dede5b9b..262956f931 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -17,7 +17,7 @@ import inspect import textwrap from copy import deepcopy -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Tuple, TypeVar, Union, runtime_checkable import torch from torch import nn @@ -567,6 +567,68 @@ def forward( return to_return +def _validate_n(n: int) -> None: + if not isinstance(n, int): + raise TypeError("n must be an integer") + + if n < 1: + raise ValueError("n must be greater than 0") + + +def repeat(module: nn.Module, n: int = 1, name=None) -> Block: + """ + Creates a new block by repeating the current block `n` times. + Each repetition is a deep copy of the current block. + + Parameters + ---------- + module: nn.Module + The module to be repeated. + n : int + The number of times to repeat the current block. + name : Optional[str], default = None + The name for the new block. If None, no name is assigned. + + Returns + ------- + Block + The new block created by repeating the current block `n` times. + """ + _validate_n(n) + + repeats = [module.copy() if hasattr(module, "copy") else deepcopy(module) for _ in range(n - 1)] + + return Block(module, *repeats, name=name) + + +def repeat_parallel(module: nn.Module, n: int = 1, name=None) -> ParallelBlock: + _validate_n(n) + + branches = {"0": module} + branches.update( + {n: module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n - 1)} + ) + + return ParallelBlock(branches, name=name) + + +@runtime_checkable +class HasKeys: + def keys(self): + ... 
+ + +def repeat_parallel_like(module: nn.Module, like: HasKeys, name=None) -> ParallelBlock: + branches = {} + for i, key in enumerate(like.keys()): + if i == 0: + branches[key] = module + else: + branches[key] = module.copy() if hasattr(module, "copy") else deepcopy(module) + + return ParallelBlock(branches, name=name) + + def get_pre(module: nn.Module) -> BlockContainer: if hasattr(module, "pre"): return module.pre diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py new file mode 100644 index 0000000000..3dd82cd30e --- /dev/null +++ b/merlin/models/torch/blocks/experts.py @@ -0,0 +1,113 @@ +from typing import Dict, Optional + +import torch +from torch import nn + +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ShortcutBlock, + repeat_parallel, + repeat_parallel_like, +) +from merlin.models.torch.transforms.agg import Stack + + +class MMOEBlock(Block): + def __init__( + self, expert: nn.Module, num_experts: int, outputs: Optional[ParallelBlock] = None + ): + super().__init__( + ShortcutBlock(repeat_parallel(expert, num_experts, agg="stack"), output_name="experts") + ) + if isinstance(outputs, ParallelBlock): + self.append(repeat_parallel_like(ExpertGateBlock(len(outputs)), outputs)) + else: + self.append(ExpertGateBlock(1)) + + +class PLEBlock(Block): + def __init__( + self, + expert: nn.Module, + num_shared_experts: int, + num_task_experts: int, + depth: int, + outputs: ParallelBlock, + ): + cgc = CGCBlock( + expert, num_shared_experts, num_task_experts, outputs=outputs, shared_gate=True + ) + super().__init__(*cgc.repeat(depth - 1)) + self.append(CGCBlock(expert, num_shared_experts, num_task_experts, outputs=outputs)) + + +class CGCBlock(Block): + def __init__( + self, + expert: nn.Module, + num_shared_experts: int, + num_task_experts: int, + outputs: ParallelBlock, + shared_gate: bool = False, + ): + super().__init__( + ShortcutBlock( + repeat_parallel(expert, num_shared_experts, agg="stack"), output_name="experts" + ) + ) + + gates = ParallelBlock() + for key in outputs.branches: + gates[key] = PLEExpertGateBlock( + len(outputs), + experts=repeat_parallel(expert, num_task_experts, agg="stack"), + name=key, + ) + if shared_gate: + gates["experts"] = ExpertGateBlock(len(outputs)) + + self.append(gates) + + +class ExpertGateBlock(Block): + def __init__(self, num_outputs: int): + super().__init__(GateBlock(num_outputs)) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + experts = inputs["experts"] + outputs = inputs["shortcut"] + for module in self.values: + outputs = module(outputs) + + gated = outputs.expand_as(experts) + + # Multiply and sum along the experts dimension + return (experts * gated).sum(dim=1) + + +class PLEExpertGateBlock(Block): + def __init__(self, num_outputs: int, experts: nn.Module, name: str): + super().__init__(GateBlock(num_outputs), name=f"PLEExpertGateBlock[{name}]") + self.stack = Stack(dim=1) + self.experts = experts + self.task_name = name + + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + task_experts = self.experts(inputs["shortcut"]) + experts = self.stack({"experts": inputs["experts"], "task_experts": task_experts}) + task = inputs[self.name] if self.name in inputs else inputs["shortcut"] + + return self.output({"experts": experts, "shortcut": task}) + + +class SoftmaxGate(nn.Module): + def forward(self, gate_logits): + return torch.softmax(gate_logits, dim=1).unsqueeze(2) + + +class GateBlock(Block): + def __init__(self, num_outputs: int): + 
super().__init__()
+        self.append(nn.LazyLinear(num_outputs))
+        self.append(SoftmaxGate())
diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py
index e289c694fa..2ee258dcdb 100644
--- a/merlin/models/torch/container.py
+++ b/merlin/models/torch/container.py
@@ -16,7 +16,7 @@

 from copy import deepcopy
 from functools import reduce
-from typing import Dict, Iterator, Optional, Union
+from typing import Dict, Iterator, Optional, Sequence, Union

 from torch import nn
 from torch._jit_internal import _copy_to_script_wrapper
@@ -62,6 +62,23 @@ def append(self, module: nn.Module):

         return self

+    def extend(self, sequence: Sequence[nn.Module]):
+        """Extends the list by appending the modules from the given sequence.
+
+        Parameters
+        ----------
+        sequence : Sequence[nn.Module]
+            The PyTorch modules to be appended, in order.
+
+        Returns
+        -------
+        self
+        """
+        for m in sequence:
+            self.append(m)
+
+        return self
+
     def prepend(self, module: nn.Module):
         """Prepends a given module to the beginning of the list.

From 38321f31393197a02d8b750bbddb51f896352d48 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Mon, 3 Jul 2023 17:57:43 +0200
Subject: [PATCH 2/7] Adding some simple tests for MMOEBlock

---
 merlin/models/torch/__init__.py         | 13 ++++++-
 merlin/models/torch/block.py            | 49 +++++++++++++----------
 merlin/models/torch/blocks/experts.py   | 36 ++++++++++--------
 tests/unit/torch/blocks/test_experts.py | 46 +++++++++++++++++++++++
 4 files changed, 107 insertions(+), 37 deletions(-)
 create mode 100644 tests/unit/torch/blocks/test_experts.py

diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index 025c8ba0dc..45b038d430 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -16,7 +16,15 @@

 from merlin.models.torch import schema
 from merlin.models.torch.batch import Batch, Sequence
-from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
+from merlin.models.torch.block import (
+    Block,
+    ParallelBlock,
+    ResidualBlock,
+    ShortcutBlock,
+    repeat,
+    repeat_parallel,
+    repeat_parallel_like,
+)
 from merlin.models.torch.blocks.dlrm import DLRMBlock
 from merlin.models.torch.blocks.mlp import MLPBlock
 from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
@@ -55,6 +63,9 @@
     "Concat",
     "Stack",
     "schema",
+    "repeat",
+    "repeat_parallel",
+    "repeat_parallel_like",
     "DLRMBlock",
     "DLRMModel",
 ]
diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py
index 262956f931..86a9a9bacc 100644
--- a/merlin/models/torch/block.py
+++ b/merlin/models/torch/block.py
@@ -17,7 +17,7 @@
 import inspect
 import textwrap
 from copy import deepcopy
-from typing import Dict, Optional, Tuple, TypeVar, Union, runtime_checkable
+from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

 import torch
 from torch import nn
@@ -31,6 +31,12 @@
 from merlin.schema import Schema


+@runtime_checkable
+class HasKeys(Protocol):
+    def keys(self):
+        ...
+
+
 class Block(BlockContainer, RegistryMixin, TraversableMixin):
     """A base-class that calls it's modules sequentially.

@@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block":
         Block
             The new block created by repeating the current block `n` times.
""" - if not isinstance(n, int): - raise TypeError("n must be an integer") - - if n < 1: - raise ValueError("n must be greater than 0") + return repeat(self, n, name=name) - repeats = [self.copy() for _ in range(n - 1)] + def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock": + return repeat_parallel(self, n, name=name) - return Block(self, *repeats, name=name) + def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock": + return repeat_parallel_like(self, like, name=name) def copy(self) -> "Block": """ @@ -342,6 +346,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock": return output + def keys(self): + return self.branches.keys() + def leaf(self) -> nn.Module: if self.pre: raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage") @@ -601,32 +608,34 @@ def repeat(module: nn.Module, n: int = 1, name=None) -> Block: return Block(module, *repeats, name=name) -def repeat_parallel(module: nn.Module, n: int = 1, name=None) -> ParallelBlock: +def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock: _validate_n(n) branches = {"0": module} branches.update( - {n: module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n - 1)} + {str(n): module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n)} ) - return ParallelBlock(branches, name=name) + output = ParallelBlock(branches) + if agg: + output.append(Block.parse(agg)) - -@runtime_checkable -class HasKeys: - def keys(self): - ... + return output -def repeat_parallel_like(module: nn.Module, like: HasKeys, name=None) -> ParallelBlock: +def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock: branches = {} for i, key in enumerate(like.keys()): if i == 0: - branches[key] = module + branches[str(key)] = module else: - branches[key] = module.copy() if hasattr(module, "copy") else deepcopy(module) + branches[str(key)] = module.copy() if hasattr(module, "copy") else deepcopy(module) + + output = ParallelBlock(branches) + if agg: + output.append(Block.parse(agg)) - return ParallelBlock(branches, name=name) + return output def get_pre(module: nn.Module) -> BlockContainer: diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index 3dd82cd30e..e08b19c44f 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -1,8 +1,9 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch from torch import nn +from merlin.models.torch.batch import Batch from merlin.models.torch.block import ( Block, ParallelBlock, @@ -17,9 +18,8 @@ class MMOEBlock(Block): def __init__( self, expert: nn.Module, num_experts: int, outputs: Optional[ParallelBlock] = None ): - super().__init__( - ShortcutBlock(repeat_parallel(expert, num_experts, agg="stack"), output_name="experts") - ) + experts = repeat_parallel(expert, num_experts, agg=Stack(dim=1)) + super().__init__(ShortcutBlock(experts, output_name="experts")) if isinstance(outputs, ParallelBlock): self.append(repeat_parallel_like(ExpertGateBlock(len(outputs)), outputs)) else: @@ -51,18 +51,15 @@ def __init__( outputs: ParallelBlock, shared_gate: bool = False, ): - super().__init__( - ShortcutBlock( - repeat_parallel(expert, num_shared_experts, agg="stack"), output_name="experts" - ) - ) + shared_experts = repeat_parallel(expert, num_shared_experts, agg="stack") + super().__init__(ShortcutBlock(shared_experts, output_name="experts")) gates = 
ParallelBlock() - for key in outputs.branches: - gates[key] = PLEExpertGateBlock( + for name in outputs.branches: + gates[name] = PLEExpertGateBlock( len(outputs), experts=repeat_parallel(expert, num_task_experts, agg="stack"), - name=key, + name=name, ) if shared_gate: gates["experts"] = ExpertGateBlock(len(outputs)) @@ -74,11 +71,16 @@ class ExpertGateBlock(Block): def __init__(self, num_outputs: int): super().__init__(GateBlock(num_outputs)) - def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: + if torch.jit.isinstance(inputs, torch.Tensor): + raise RuntimeError("ExpertGateBlock requires a dictionary input") + experts = inputs["experts"] outputs = inputs["shortcut"] for module in self.values: - outputs = module(outputs) + outputs = module(outputs, batch=batch) gated = outputs.expand_as(experts) @@ -93,8 +95,10 @@ def __init__(self, num_outputs: int, experts: nn.Module, name: str): self.experts = experts self.task_name = name - def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: - task_experts = self.experts(inputs["shortcut"]) + def forward( + self, inputs: Dict[str, torch.Tensor], batch: Optional[Batch] = None + ) -> torch.Tensor: + task_experts = self.experts(inputs["shortcut"], batch=batch) experts = self.stack({"experts": inputs["experts"], "task_experts": task_experts}) task = inputs[self.name] if self.name in inputs else inputs["shortcut"] diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py new file mode 100644 index 0000000000..f3738d4e0f --- /dev/null +++ b/tests/unit/torch/blocks/test_experts.py @@ -0,0 +1,46 @@ +import torch + +import merlin.models.torch as mm +from merlin.models.torch.blocks.experts import ( # CGCBlock,; PLEBlock,; PLEExpertGateBlock, + ExpertGateBlock, + MMOEBlock, +) +from merlin.models.torch.utils import module_utils + + +class TestMMOEBlock: + def test_init(self): + mmoe = MMOEBlock(mm.MLPBlock([10, 10]), 2) + + assert isinstance(mmoe, MMOEBlock) + assert isinstance(mmoe[0], mm.ShortcutBlock) + assert len(mmoe[0][0].branches) == 2 + for i in range(2): + assert mmoe[0][0][str(i)][1].out_features == 10 + assert mmoe[0][0][str(i)][3].out_features == 10 + assert isinstance(mmoe[0][0].post[0], mm.Stack) + assert isinstance(mmoe[1], ExpertGateBlock) + assert mmoe[1][0][0].out_features == 1 + + def test_init_with_outputs(self): + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([10, 10])) + outputs.prepend(MMOEBlock(mm.MLPBlock([10, 10]), 2, outputs)) + + assert isinstance(outputs.pre[0], MMOEBlock) + assert list(outputs.pre[0][1].keys()) == ["a", "b"] + + def test_forward(self): + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) + + outputs = module_utils.module_test(mmoe, torch.rand(5, 5)) + assert outputs.shape == (5, 10) + + def test_forward_with_outputs(self): + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([2, 2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) + + outputs = module_utils.module_test(outputs, torch.rand(5, 5)) + assert outputs["a"].shape == (5, 1) + assert outputs["b"].shape == (5, 1) From f3d578e27a0449ac42f6eeb66071b4dab1ef2ee3 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 18:19:12 +0200 Subject: [PATCH 3/7] Adding some doc-strings --- merlin/models/torch/blocks/experts.py | 93 
+++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index e08b19c44f..89758c2584 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -12,9 +12,38 @@ repeat_parallel_like, ) from merlin.models.torch.transforms.agg import Stack +from merlin.models.utils.doc_utils import docstring_parameter + +_PLE_REFERENCE = """ + References + ---------- + .. [1] Tang, Hongyan, et al. "Progressive layered extraction (ple): A novel multi-task + learning (mtl) model for personalized recommendations." + Fourteenth ACM Conference on Recommender Systems. 2020. +""" class MMOEBlock(Block): + """ + Multi-gate Mixture-of-Experts (MMoE) Block introduced in [1]. + + References + ---------- + [1] Ma, Jiaqi, et al. "Modeling task relationships in multi-task learning with + multi-gate mixture-of-experts." Proceedings of the 24th ACM SIGKDD international + conference on knowledge discovery & data mining. 2018. + + Parameters + ---------- + expert : nn.Module + The base expert model to be used. + num_experts : int + The number of experts to be used. + outputs : Optional[ParallelBlock] + The output block. If it is an instance of ParallelBlock, + repeat it for each expert, otherwise use a single ExpertGateBlock. + """ + def __init__( self, expert: nn.Module, num_experts: int, outputs: Optional[ParallelBlock] = None ): @@ -26,7 +55,27 @@ def __init__( self.append(ExpertGateBlock(1)) +@docstring_parameter(ple_reference=_PLE_REFERENCE) class PLEBlock(Block): + """ + Progressive Layered Extraction (PLE) Block proposed in [1]. + + {ple_reference} + + Parameters + ---------- + expert : nn.Module + The base expert model to be used. + num_shared_experts : int + The number of shared experts. + num_task_experts : int + The number of task-specific experts. + depth : int + The depth of the network. + outputs : ParallelBlock + The output block. + """ + def __init__( self, expert: nn.Module, @@ -43,6 +92,25 @@ def __init__( class CGCBlock(Block): + """ + Implements the Customized Gate Control (CGC) proposed in [1]. + + {ple_reference} + + Parameters + ---------- + expert : nn.Module + The base expert model to be used. + num_shared_experts : int + The number of shared experts. + num_task_experts : int + The number of task-specific experts. + outputs : ParallelBlock + The output block. + shared_gate : bool, optional + If true, use a shared gate for all tasks. Defaults to False. + """ + def __init__( self, expert: nn.Module, @@ -68,6 +136,14 @@ def __init__( class ExpertGateBlock(Block): + """Expert Gate Block. + + Parameters + ---------- + num_outputs : int + The number of output channels. + """ + def __init__(self, num_outputs: int): super().__init__(GateBlock(num_outputs)) @@ -89,6 +165,19 @@ def forward( class PLEExpertGateBlock(Block): + """ + Progressive Layered Extraction (PLE) Expert Gate Block. + + Parameters + ---------- + num_outputs : int + The number of output channels. + experts : nn.Module + The expert module. + name : str + The name of the task. 
+ """ + def __init__(self, num_outputs: int, experts: nn.Module, name: str): super().__init__(GateBlock(num_outputs), name=f"PLEExpertGateBlock[{name}]") self.stack = Stack(dim=1) @@ -106,11 +195,15 @@ def forward( class SoftmaxGate(nn.Module): + """Softmax Gate for gating mechanism.""" + def forward(self, gate_logits): return torch.softmax(gate_logits, dim=1).unsqueeze(2) class GateBlock(Block): + """Gate Block for gating mechanism.""" + def __init__(self, num_outputs: int): super().__init__() self.append(nn.LazyLinear(num_outputs)) From ebea1dfe7fcffdf94202496dd3ce2c6c9de5e0db Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 18:30:33 +0200 Subject: [PATCH 4/7] Fixing failing tests --- merlin/models/torch/blocks/experts.py | 13 ++++++++----- tests/unit/torch/blocks/test_experts.py | 12 ++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index 89758c2584..9977dd4e97 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -84,11 +84,14 @@ def __init__( depth: int, outputs: ParallelBlock, ): - cgc = CGCBlock( - expert, num_shared_experts, num_task_experts, outputs=outputs, shared_gate=True - ) - super().__init__(*cgc.repeat(depth - 1)) - self.append(CGCBlock(expert, num_shared_experts, num_task_experts, outputs=outputs)) + cgc_kwargs = { + "expert": expert, + "num_shared_experts": num_shared_experts, + "num_task_experts": num_task_experts, + "outputs": outputs, + } + super().__init__(*CGCBlock(shared_gate=True, **cgc_kwargs).repeat(depth - 1)) + self.append(CGCBlock(**cgc_kwargs)) class CGCBlock(Block): diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py index f3738d4e0f..2452c4cd82 100644 --- a/tests/unit/torch/blocks/test_experts.py +++ b/tests/unit/torch/blocks/test_experts.py @@ -10,22 +10,22 @@ class TestMMOEBlock: def test_init(self): - mmoe = MMOEBlock(mm.MLPBlock([10, 10]), 2) + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) assert isinstance(mmoe, MMOEBlock) assert isinstance(mmoe[0], mm.ShortcutBlock) assert len(mmoe[0][0].branches) == 2 for i in range(2): - assert mmoe[0][0][str(i)][1].out_features == 10 - assert mmoe[0][0][str(i)][3].out_features == 10 + assert mmoe[0][0][str(i)][1].out_features == 2 + assert mmoe[0][0][str(i)][3].out_features == 2 assert isinstance(mmoe[0][0].post[0], mm.Stack) assert isinstance(mmoe[1], ExpertGateBlock) assert mmoe[1][0][0].out_features == 1 def test_init_with_outputs(self): outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) - outputs.prepend_for_each(mm.MLPBlock([10, 10])) - outputs.prepend(MMOEBlock(mm.MLPBlock([10, 10]), 2, outputs)) + outputs.prepend_for_each(mm.MLPBlock([2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) assert isinstance(outputs.pre[0], MMOEBlock) assert list(outputs.pre[0][1].keys()) == ["a", "b"] @@ -34,7 +34,7 @@ def test_forward(self): mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) outputs = module_utils.module_test(mmoe, torch.rand(5, 5)) - assert outputs.shape == (5, 10) + assert outputs.shape == (5, 2) def test_forward_with_outputs(self): outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) From e78f94e899b30aac2fdf63ad9ea8c7bcfc70258d Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Fri, 7 Jul 2023 11:05:26 +0200 Subject: [PATCH 5/7] Increase test-coverage --- merlin/models/torch/blocks/experts.py | 79 ++++++++++++++++-------- 
tests/unit/torch/blocks/test_experts.py | 81 ++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 26 deletions(-) diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index 9977dd4e97..a301286980 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -1,3 +1,5 @@ +import textwrap +from functools import partial from typing import Dict, Optional, Union import torch @@ -11,7 +13,7 @@ repeat_parallel, repeat_parallel_like, ) -from merlin.models.torch.transforms.agg import Stack +from merlin.models.torch.transforms.agg import Concat, Stack from merlin.models.utils.doc_utils import docstring_parameter _PLE_REFERENCE = """ @@ -50,9 +52,9 @@ def __init__( experts = repeat_parallel(expert, num_experts, agg=Stack(dim=1)) super().__init__(ShortcutBlock(experts, output_name="experts")) if isinstance(outputs, ParallelBlock): - self.append(repeat_parallel_like(ExpertGateBlock(len(outputs)), outputs)) + self.append(repeat_parallel_like(ExpertGateBlock(num_experts), outputs)) else: - self.append(ExpertGateBlock(1)) + self.append(ExpertGateBlock(num_experts)) @docstring_parameter(ple_reference=_PLE_REFERENCE) @@ -79,6 +81,7 @@ class PLEBlock(Block): def __init__( self, expert: nn.Module, + *, num_shared_experts: int, num_task_experts: int, depth: int, @@ -117,23 +120,27 @@ class CGCBlock(Block): def __init__( self, expert: nn.Module, + *, num_shared_experts: int, num_task_experts: int, outputs: ParallelBlock, shared_gate: bool = False, ): - shared_experts = repeat_parallel(expert, num_shared_experts, agg="stack") - super().__init__(ShortcutBlock(shared_experts, output_name="experts")) + shared_experts = repeat_parallel(expert, num_shared_experts, agg=Stack(dim=1)) + expert_shortcut = partial(ShortcutBlock, output_name="experts") + super().__init__(expert_shortcut(shared_experts)) gates = ParallelBlock() for name in outputs.branches: - gates[name] = PLEExpertGateBlock( - len(outputs), - experts=repeat_parallel(expert, num_task_experts, agg="stack"), + gates.branches[name] = PLEExpertGateBlock( + num_shared_experts + num_task_experts, + experts=repeat_parallel(expert, num_task_experts, agg=Stack(dim=1)), name=name, ) if shared_gate: - gates["experts"] = ExpertGateBlock(len(outputs)) + gates.branches["experts"] = expert_shortcut( + ExpertGateBlock(num_shared_experts), propagate_shortcut=True + ) self.append(gates) @@ -141,14 +148,16 @@ def __init__( class ExpertGateBlock(Block): """Expert Gate Block. + # TODO: Add initialize_from_schema to remove the need to pass in num_experts + Parameters ---------- - num_outputs : int - The number of output channels. + num_experts : int + The number of experts used. """ - def __init__(self, num_outputs: int): - super().__init__(GateBlock(num_outputs)) + def __init__(self, num_experts: int): + super().__init__(GateBlock(num_experts)) def forward( self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None @@ -161,6 +170,7 @@ def forward( for module in self.values: outputs = module(outputs, batch=batch) + # return torch.sum(experts * outputs, dim=1, keepdim=False) gated = outputs.expand_as(experts) # Multiply and sum along the experts dimension @@ -173,41 +183,62 @@ class PLEExpertGateBlock(Block): Parameters ---------- - num_outputs : int - The number of output channels. + num_experts : int + The number of experts used. experts : nn.Module The expert module. name : str The name of the task. 
""" - def __init__(self, num_outputs: int, experts: nn.Module, name: str): - super().__init__(GateBlock(num_outputs), name=f"PLEExpertGateBlock[{name}]") + def __init__(self, num_experts: int, experts: nn.Module, name: str): + super().__init__(ExpertGateBlock(num_experts), name=f"PLEExpertGateBlock[{name}]") self.stack = Stack(dim=1) + self.concat = Concat(dim=1) self.experts = experts self.task_name = name def forward( - self, inputs: Dict[str, torch.Tensor], batch: Optional[Batch] = None + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None ) -> torch.Tensor: + if torch.jit.isinstance(inputs, torch.Tensor): + raise RuntimeError("ExpertGateBlock requires a dictionary input") + task_experts = self.experts(inputs["shortcut"], batch=batch) - experts = self.stack({"experts": inputs["experts"], "task_experts": task_experts}) - task = inputs[self.name] if self.name in inputs else inputs["shortcut"] + if torch.jit.isinstance(task_experts, torch.Tensor): + _task = task_experts + elif torch.jit.isinstance(task_experts, Dict[str, torch.Tensor]): + _task = self.stack(task_experts) + else: + raise RuntimeError("PLEExpertGateBlock requires a dictionary input") + experts = self.concat({"experts": inputs["experts"], "task_experts": _task}) + task = inputs[self.task_name] if self.task_name in inputs else inputs["shortcut"] + + outputs = {"experts": experts, "shortcut": task} + for block in self.values: + outputs = block(outputs, batch=batch) + + return outputs + + def __repr__(self) -> str: + indent_str = " " + output = textwrap.indent("\n(experts): " + repr(self.experts), indent_str) + output += textwrap.indent("\n(gate): " + repr(self.values[0]), indent_str) - return self.output({"experts": experts, "shortcut": task}) + return f"{self._get_name()}({output}\n)" class SoftmaxGate(nn.Module): """Softmax Gate for gating mechanism.""" def forward(self, gate_logits): - return torch.softmax(gate_logits, dim=1).unsqueeze(2) + return torch.softmax(gate_logits, dim=-1).unsqueeze(-1) class GateBlock(Block): """Gate Block for gating mechanism.""" - def __init__(self, num_outputs: int): + def __init__(self, num_experts: int): super().__init__() - self.append(nn.LazyLinear(num_outputs)) + self.append(nn.LazyLinear(num_experts)) self.append(SoftmaxGate()) diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py index 2452c4cd82..525c1e7b66 100644 --- a/tests/unit/torch/blocks/test_experts.py +++ b/tests/unit/torch/blocks/test_experts.py @@ -1,12 +1,32 @@ +import pytest import torch import merlin.models.torch as mm -from merlin.models.torch.blocks.experts import ( # CGCBlock,; PLEBlock,; PLEExpertGateBlock, +from merlin.models.torch.blocks.experts import ( + CGCBlock, ExpertGateBlock, MMOEBlock, + PLEBlock, + PLEExpertGateBlock, ) from merlin.models.torch.utils import module_utils +dict_inputs = {"experts": torch.rand((10, 4, 5)), "shortcut": torch.rand((10, 8))} + + +class TestExpertGateBlock: + @pytest.fixture + def expert_gate(self): + return ExpertGateBlock(num_experts=4) + + def test_requires_dict_input(self, expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + expert_gate(torch.rand((10, 5))) + + def test_forward_pass(self, expert_gate): + result = module_utils.module_test(expert_gate, dict_inputs) + assert result.shape == (10, 5) + class TestMMOEBlock: def test_init(self): @@ -20,7 +40,7 @@ def test_init(self): assert mmoe[0][0][str(i)][3].out_features == 2 assert 
isinstance(mmoe[0][0].post[0], mm.Stack) assert isinstance(mmoe[1], ExpertGateBlock) - assert mmoe[1][0][0].out_features == 1 + assert mmoe[1][0][0].out_features == 2 def test_init_with_outputs(self): outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) @@ -44,3 +64,60 @@ def test_forward_with_outputs(self): outputs = module_utils.module_test(outputs, torch.rand(5, 5)) assert outputs["a"].shape == (5, 1) assert outputs["b"].shape == (5, 1) + + +class TestPLEExpertGateBlock: + @pytest.fixture + def ple_expert_gate(self): + return PLEExpertGateBlock( + num_experts=6, experts=mm.repeat_parallel(mm.MLPBlock([5, 5]), 2), name="a" + ) + + def test_repr(self, ple_expert_gate): + assert "(experts)" in str(ple_expert_gate) + assert "(gate)" in str(ple_expert_gate) + + def test_requires_dict_input(self, ple_expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + ple_expert_gate(torch.rand((10, 5))) + + def test_ple_forward(self, ple_expert_gate): + result = module_utils.module_test(ple_expert_gate, dict_inputs) + assert result.shape == (10, 5) + + +class TestCGCBlock: + @pytest.mark.parametrize("shared_gate", [True, False]) + def test_forward(self, music_streaming_data, shared_gate): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + cgc = CGCBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + outputs=output_block, + shared_gate=shared_gate, + ) + + outputs = module_utils.module_test(cgc, torch.rand(5, 5)) + assert len(outputs) == len(output_block) + (2 if shared_gate else 0) + + +class TestPLEBlock: + def test_forward(self, music_streaming_data): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + ple = PLEBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + depth=2, + outputs=output_block, + ) + + assert isinstance(ple[0], CGCBlock) + assert len(ple[0][1]) == len(output_block) + 1 + assert isinstance(ple[0][1]["experts"][0], ExpertGateBlock) + assert isinstance(ple[1], CGCBlock) + assert list(ple[1][1].branches.keys()) == list(ple[1][1].branches.keys()) + + outputs = module_utils.module_test(ple, torch.rand(5, 5)) + assert len(outputs) == len(output_block) From ee8c419dd63bd6106842608b4806aa0a13806c9b Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 11 Jul 2023 10:06:40 +0200 Subject: [PATCH 6/7] Improving doc-strings --- merlin/models/torch/__init__.py | 4 ++ merlin/models/torch/blocks/experts.py | 81 +++++++++++++++++++++------ 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index ba6a4b422d..bcbc77c5d5 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -26,6 +26,7 @@ repeat_parallel_like, ) from merlin.models.torch.blocks.dlrm import DLRMBlock +from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock from merlin.models.torch.blocks.mlp import MLPBlock from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys @@ -88,4 +89,7 @@ "DLRMBlock", "DLRMModel", "DCNModel", + "MMOEBlock", + "PLEBlock", + "CGCBlock", ] diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index a301286980..d23a912085 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -29,6 +29,22 @@ class MMOEBlock(Block): """ Multi-gate 
Mixture-of-Experts (MMoE) Block introduced in [1]. + The MMoE model builds upon the concept of an expert model by using a mixture + of experts for decision making. Each expert contributes independently to the + final decision, allowing for increased model complexity and performance + in multi-task learning scenarios. + + Example usage for multi-task learning:: + >>> outputs = mm.TabularOutputBlock(schema, init="defaults") + >>> mmoe = mm.MMOEBlock( + expert=mm.MLPBlock([5]), + num_experts=2, + outputs=outputs, + ) + >>> outputs.prepend_for_each(mm.MLPBlock([64])) # Add task-towers + >>> outputs.prepend(mmoe) + + References ---------- [1] Ma, Jiaqi, et al. "Modeling task relationships in multi-task learning with @@ -38,12 +54,14 @@ class MMOEBlock(Block): Parameters ---------- expert : nn.Module - The base expert model to be used. + The base expert model that serves as the foundation for the MMoE structure. num_experts : int - The number of experts to be used. + The total number of experts in the MMoE model. Each expert operates independently + in the decision-making process. outputs : Optional[ParallelBlock] - The output block. If it is an instance of ParallelBlock, - repeat it for each expert, otherwise use a single ExpertGateBlock. + The output block of the model. + If it is an instance of ParallelBlock, the block is repeated for each expert. + Otherwise, a single ExpertGateBlock is used. """ def __init__( @@ -62,20 +80,40 @@ class PLEBlock(Block): """ Progressive Layered Extraction (PLE) Block proposed in [1]. + The PLE model enhances the architecture of a typical expert model by organizing + shared and task-specific experts in a layered format. This layered structure + allows the extraction of increasingly complex features at each level and can + improve performance in multi-task settings. + + Example usage for multi-task learning:: + >>> outputs = mm.TabularOutputBlock(schema, init="defaults") + >>> ple = mm.PLEBlock( + expert=mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + depth=2, + outputs=outputs, + ) + >>> outputs.prepend_for_each(mm.MLPBlock([64])) # Add task-towers + >>> outputs.prepend(ple) + {ple_reference} Parameters ---------- expert : nn.Module - The base expert model to be used. + The base expert model that forms the basis of the PLE structure. num_shared_experts : int - The number of shared experts. + The total count of shared experts. These experts contribute to the + decision process in all tasks. num_task_experts : int - The number of task-specific experts. + The total count of task-specific experts. These experts contribute + only to their specific tasks. depth : int - The depth of the network. + The depth of the layered structure. Each layer comprises a set of experts + and the depth determines the number of such layers. outputs : ParallelBlock - The output block. + The output block, which encapsulates the final output from the model. """ def __init__( @@ -101,20 +139,27 @@ class CGCBlock(Block): """ Implements the Customized Gate Control (CGC) proposed in [1]. + The CGC model extends the capability of a typical expert model by introducing + shared and task-specific experts, thereby customizing the gating control per task, + which may lead to improved performance in multi-task settings. + {ple_reference} Parameters ---------- expert : nn.Module - The base expert model to be used. + The base expert model that is used as the foundation for the gating mechanism. num_shared_experts : int - The number of shared experts. 
+ The total count of shared experts. These experts contribute to the decision + process in all tasks. num_task_experts : int - The number of task-specific experts. + The total count of task-specific experts. These experts contribute only + to their specific tasks. outputs : ParallelBlock - The output block. + The output block, which encapsulates the final output from the model. shared_gate : bool, optional - If true, use a shared gate for all tasks. Defaults to False. + Defines whether a shared gate is used across all tasks or not. + If set to True, a shared gate is used. Defaults to False. """ def __init__( @@ -185,17 +230,17 @@ class PLEExpertGateBlock(Block): ---------- num_experts : int The number of experts used. - experts : nn.Module + task_experts : nn.Module The expert module. name : str The name of the task. """ - def __init__(self, num_experts: int, experts: nn.Module, name: str): + def __init__(self, num_experts: int, task_experts: nn.Module, name: str): super().__init__(ExpertGateBlock(num_experts), name=f"PLEExpertGateBlock[{name}]") self.stack = Stack(dim=1) self.concat = Concat(dim=1) - self.experts = experts + self.task_experts = task_experts self.task_name = name def forward( @@ -204,7 +249,7 @@ def forward( if torch.jit.isinstance(inputs, torch.Tensor): raise RuntimeError("ExpertGateBlock requires a dictionary input") - task_experts = self.experts(inputs["shortcut"], batch=batch) + task_experts = self.task_experts(inputs["shortcut"], batch=batch) if torch.jit.isinstance(task_experts, torch.Tensor): _task = task_experts elif torch.jit.isinstance(task_experts, Dict[str, torch.Tensor]): From 85575cf813a1add1b7906c88ae0cae1017f5ff9e Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 11 Jul 2023 10:18:46 +0200 Subject: [PATCH 7/7] Fixing failing tests --- merlin/models/torch/blocks/experts.py | 4 ++-- tests/unit/torch/blocks/test_experts.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py index d23a912085..627b96a1fd 100644 --- a/merlin/models/torch/blocks/experts.py +++ b/merlin/models/torch/blocks/experts.py @@ -179,7 +179,7 @@ def __init__( for name in outputs.branches: gates.branches[name] = PLEExpertGateBlock( num_shared_experts + num_task_experts, - experts=repeat_parallel(expert, num_task_experts, agg=Stack(dim=1)), + task_experts=repeat_parallel(expert, num_task_experts, agg=Stack(dim=1)), name=name, ) if shared_gate: @@ -267,7 +267,7 @@ def forward( def __repr__(self) -> str: indent_str = " " - output = textwrap.indent("\n(experts): " + repr(self.experts), indent_str) + output = textwrap.indent("\n(task_experts): " + repr(self.task_experts), indent_str) output += textwrap.indent("\n(gate): " + repr(self.values[0]), indent_str) return f"{self._get_name()}({output}\n)" diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py index 525c1e7b66..5f32bf3485 100644 --- a/tests/unit/torch/blocks/test_experts.py +++ b/tests/unit/torch/blocks/test_experts.py @@ -70,11 +70,11 @@ class TestPLEExpertGateBlock: @pytest.fixture def ple_expert_gate(self): return PLEExpertGateBlock( - num_experts=6, experts=mm.repeat_parallel(mm.MLPBlock([5, 5]), 2), name="a" + num_experts=6, task_experts=mm.repeat_parallel(mm.MLPBlock([5, 5]), 2), name="a" ) def test_repr(self, ple_expert_gate): - assert "(experts)" in str(ple_expert_gate) + assert "(task_experts)" in str(ple_expert_gate) assert "(gate)" in str(ple_expert_gate) def test_requires_dict_input(self, 
ple_expert_gate):
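
A note for reviewers on the `repeat_parallel` helper, since the whole series builds on it: the sketch below shows how it combines with the `Stack` aggregation to produce the `[batch, num_experts, dim]` tensor that the gates consume. It assumes the final state of this series (top-level exports of `repeat_parallel`, `MLPBlock`, and `Stack`); the layer sizes are arbitrary.

    import torch
    import merlin.models.torch as mm

    # Three copies of one expert run as parallel branches "0", "1", "2",
    # then stacked along dim=1 -- the same construction MMOEBlock uses
    # internally (the Stack ends up in the ParallelBlock's `post` stage).
    experts = mm.repeat_parallel(mm.MLPBlock([4]), 3, agg=mm.Stack(dim=1))

    stacked = experts(torch.rand(10, 8))
    assert stacked.shape == (10, 3, 4)  # [batch, num_experts, dim]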
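
The gating math inside `ExpertGateBlock`/`GateBlock`/`SoftmaxGate` can be reproduced in a few lines of plain PyTorch. This is a minimal stand-alone sketch of the same computation, independent of the Merlin block machinery: the `SimpleExpertGate` name is illustrative only, a plain `nn.Linear` stands in for the `nn.LazyLinear` used by `GateBlock`, and the shapes match `dict_inputs` in the unit tests.

    import torch
    from torch import nn


    class SimpleExpertGate(nn.Module):
        # Gate the stacked expert outputs with a softmax over per-expert
        # scores computed from the original ("shortcut") input.
        def __init__(self, in_features: int, num_experts: int):
            super().__init__()
            self.gate = nn.Linear(in_features, num_experts)

        def forward(self, experts: torch.Tensor, shortcut: torch.Tensor) -> torch.Tensor:
            # experts:  [batch, num_experts, dim], shortcut: [batch, in_features]
            weights = torch.softmax(self.gate(shortcut), dim=-1)  # SoftmaxGate
            weights = weights.unsqueeze(-1)      # [batch, num_experts, 1]
            # Broadcast the per-expert weights and sum over the expert
            # dimension, mirroring (experts * gated).sum(dim=1) in
            # ExpertGateBlock.forward.
            return (experts * weights).sum(dim=1)  # [batch, dim]


    gate = SimpleExpertGate(in_features=8, num_experts=4)
    out = gate(torch.rand(10, 4, 5), torch.rand(10, 8))
    assert out.shape == (10, 5)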
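
End-to-end, the multi-task wiring these patches enable looks as follows. It mirrors `test_forward_with_outputs` above; the task names "a"/"b" and the tiny layer sizes are illustrative only.

    import torch
    import merlin.models.torch as mm

    # Two binary tasks sharing a mixture of two experts; because `outputs`
    # is a ParallelBlock, MMOEBlock creates one gate per task.
    outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()})
    outputs.prepend_for_each(mm.MLPBlock([2, 2]))                   # task towers
    outputs.prepend(mm.MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs))  # shared experts

    predictions = outputs(torch.rand(5, 5))
    assert predictions["a"].shape == predictions["b"].shape == (5, 1)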
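
`PLEBlock` follows the same wiring with depth-stacked CGC layers: every layer but the last carries a shared gate that forwards the expert mix to the next layer. The sketch below combines the PLEBlock docstring example with the runnable two-task setup above; note the unit tests exercise `PLEBlock` against a `TabularOutputBlock` rather than this exact configuration, so treat it as a starting point.

    import torch
    import merlin.models.torch as mm

    outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()})
    outputs.prepend_for_each(mm.MLPBlock([2, 2]))  # task towers
    outputs.prepend(
        mm.PLEBlock(
            mm.MLPBlock([2, 2]),
            num_shared_experts=2,  # experts contributing to every task
            num_task_experts=1,    # extra experts behind each task gate
            depth=2,               # two stacked CGC layers
            outputs=outputs,
        )
    )

    predictions = outputs(torch.rand(5, 5))  # {"a": (5, 1), "b": (5, 1)}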