From 91774db207c29a6e687de99b8bb93dd8d5cbebb1 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Sat, 1 Jul 2023 15:57:41 +0200 Subject: [PATCH 1/6] Removing Link in favour of some new Blocks like: ResidualBlock & ShortcutBlock --- merlin/models/torch/__init__.py | 4 +- merlin/models/torch/block.py | 112 ++++++++++++++++++++++++----- merlin/models/torch/blocks/dlrm.py | 40 +++++++---- merlin/models/torch/container.py | 97 ++++--------------------- merlin/models/torch/link.py | 83 --------------------- tests/unit/torch/test_block.py | 58 +++++++++++---- tests/unit/torch/test_container.py | 48 +------------ tests/unit/torch/test_link.py | 76 -------------------- 8 files changed, 183 insertions(+), 335 deletions(-) delete mode 100644 merlin/models/torch/link.py delete mode 100644 tests/unit/torch/test_link.py diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index d2326af5e9..988897ef44 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -16,7 +16,7 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch, Sequence -from merlin.models.torch.block import Block, ParallelBlock +from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock from merlin.models.torch.blocks.dlrm import DLRMBlock from merlin.models.torch.blocks.mlp import MLPBlock from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables @@ -45,9 +45,11 @@ "ParallelBlock", "Sequence", "RegressionOutput", + "ResidualBlock", "RouterBlock", "SelectKeys", "SelectFeatures", + "ShortcutBlock", "TabularInputBlock", "Concat", "Stack", diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 44d6212909..77b9844464 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -25,7 +25,6 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch from merlin.models.torch.container import BlockContainer, BlockContainerDict -from merlin.models.torch.link import Link, LinkType from merlin.models.torch.registry import registry from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf from merlin.models.utils.registry import RegistryMixin @@ -41,8 +40,6 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin): Variable length argument list of PyTorch modules to be contained in the block. name : Optional[str], default = None The name of the block. If None, no name is assigned. - track_schema : bool, default = True - If True, the schema of the output tensors are tracked. """ registry = registry @@ -73,7 +70,7 @@ def forward( return inputs - def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block": + def repeat(self, n: int = 1, name=None) -> "Block": """ Creates a new block by repeating the current block `n` times. Each repetition is a deep copy of the current block. @@ -97,9 +94,6 @@ def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Blo raise ValueError("n must be greater than 0") repeats = [self.copy() for _ in range(n - 1)] - if link: - parsed_link = Link.parse(link) - repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats] return Block(self, *repeats, name=name) @@ -221,7 +215,7 @@ def forward( return outputs - def append(self, module: nn.Module, link: Optional[LinkType] = None): + def append(self, module: nn.Module): """Appends a module to the post-processing stage. 
Parameters @@ -235,7 +229,7 @@ def append(self, module: nn.Module, link: Optional[LinkType] = None): The current object itself. """ - self.post.append(module, link=link) + self.post.append(module) return self @@ -244,7 +238,7 @@ def prepend(self, module: nn.Module): return self - def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None): + def append_to(self, name: str, module: nn.Module): """Appends a module to a specified branch. Parameters @@ -260,11 +254,11 @@ def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = Non The current object itself. """ - self.branches[name].append(module, link=link) + self.branches[name].append(module) return self - def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None): + def prepend_to(self, name: str, module: nn.Module): """Prepends a module to a specified branch. Parameters @@ -279,11 +273,11 @@ def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = No ParallelBlock The current object itself. """ - self.branches[name].prepend(module, link=link) + self.branches[name].prepend(module) return self - def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None): + def append_for_each(self, module: nn.Module, shared=False): """Appends a module to each branch. Parameters @@ -300,11 +294,11 @@ def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkTy The current object itself. """ - self.branches.append_for_each(module, shared=shared, link=link) + self.branches.append_for_each(module, shared=shared) return self - def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None): + def prepend_for_each(self, module: nn.Module, shared=False): """Prepends a module to each branch. Parameters @@ -321,7 +315,7 @@ def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkT The current object itself. """ - self.branches.prepend_for_each(module, shared=shared, link=link) + self.branches.prepend_for_each(module, shared=shared) return self @@ -415,6 +409,90 @@ def __repr__(self) -> str: return self._get_name() + branches +class ResidualBlock(Block): + """ + A block that applies each contained module sequentially on the input + and performs a residual connection after each module. + + Parameters + ---------- + *module : nn.Module + Variable length argument list of PyTorch modules to be contained in the block. + name : Optional[str], default = None + The name of the block. If None, no name is assigned. + + """ + + def forward(self, inputs: torch.Tensor, batch: Optional[Batch] = None): + """ + Forward pass through the block. Applies each contained module sequentially on the input. + + Parameters + ---------- + inputs : Union[torch.Tensor, Dict[str, torch.Tensor]] + The input data as a tensor or a dictionary of tensors. + batch : Optional[Batch], default = None + Optional batch of data. If provided, it is used by the `module`s. + + Returns + ------- + torch.Tensor or Dict[str, torch.Tensor] + The output of the block after processing the input. 
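+
+        Examples
+        --------
+        With an identity module the block doubles its input, since the
+        shortcut is added back onto the module output:
+
+        >>> residual = mm.ResidualBlock(nn.Identity())
+        >>> residual(torch.ones(1, 1))
+        tensor([[2.]])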
+ """ + shortcut, outputs = inputs, inputs + for module in self.values: + outputs = shortcut + module(outputs, batch=batch) + + return outputs + + +class ShortcutBlock(Block): + def __init__( + self, + *module: nn.Module, + name: Optional[str] = None, + shortcut_name: str = "shortcut", + output_name: str = "output", + ): + super().__init__(*module, name=name) + self.shortcut_name = shortcut_name + self.output_name = output_name + + def forward( + self, inputs: torch.Tensor, batch: Optional[Batch] = None + ) -> Dict[str, torch.Tensor]: + shortcut, output = inputs, inputs + for module in self.values: + if getattr(module, "accepts_dict", False): + module_output = module(self._create_dict(shortcut, output), batch=batch) + if torch.jit.isinstance(module_output, torch.Tensor): + output = module_output + elif isinstance(module_output, Dict[str, torch.Tensor]): + output = module_output[self.output_name] + else: + raise ValueError( + f"Module {module} must return a tensor or a dict ", + f"with key {self.output_name}", + ) + else: + output = module(output, batch=batch) + + return self._create_dict(shortcut, output) + + def _create_dict(self, shortcut: torch.Tensor, output: torch.Tensor) -> Dict[str, torch.Tensor]: + return {self.shortcut_name: shortcut, self.output_name: output} + + +class CrossBlock(Block): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + x0 = inputs + current = inputs + for module in self.values: + current = x0 * module(current) + current + + return current + + def get_pre(module: nn.Module) -> BlockContainer: if hasattr(module, "pre"): return module.pre diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py index a24e4d1f71..3b638ada08 100644 --- a/merlin/models/torch/blocks/dlrm.py +++ b/merlin/models/torch/blocks/dlrm.py @@ -1,13 +1,13 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch from torch import nn +from merlin.models.torch.batch import Batch from merlin.models.torch.block import Block from merlin.models.torch.inputs.embedding import EmbeddingTables from merlin.models.torch.inputs.tabular import TabularInputBlock -from merlin.models.torch.link import Link -from merlin.models.torch.transforms.agg import MaybeAgg, Stack +from merlin.models.torch.transforms.agg import Stack from merlin.models.utils.doc_utils import docstring_parameter from merlin.schema import Schema, Tags @@ -77,7 +77,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return interactions_flat -class ShortcutConcatContinuous(Link): +class InteractionBlock(Block): """ A shortcut connection that concatenates continuous input features and intermediate outputs. @@ -85,13 +85,28 @@ class ShortcutConcatContinuous(Link): When there's no continuous input, the intermediate output is returned. 
""" - def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: - intermediate_output = self.output(inputs) + def __init__( + self, + *module: nn.Module, + name: Optional[str] = None, + prepend_agg: bool = True, + ): + if prepend_agg: + module = (Stack(dim=1),) + module + super().__init__(*module, name=name) - if "continuous" in inputs: - return torch.cat((inputs["continuous"], intermediate_output), dim=1) + def forward( + self, inputs: Union[Dict[str, torch.Tensor], torch.Tensor], batch: Optional[Batch] = None + ) -> torch.Tensor: + outputs = inputs + for module in self.values: + outputs = module(outputs, batch) - return intermediate_output + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]): + if "continuous" in inputs: + return torch.cat((inputs["continuous"], outputs), dim=1) + + return outputs @docstring_parameter(dlrm_reference=_DLRM_REF) @@ -131,11 +146,6 @@ def __init__( interaction: Optional[nn.Module] = None, ): super().__init__(DLRMInputBlock(schema, dim, bottom_block)) - - self.append( - Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()), - link=ShortcutConcatContinuous(), - ) - + self.append(InteractionBlock(interaction or DLRMInteraction())) if top_block: self.append(top_block) diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py index ed185092ac..e289c694fa 100644 --- a/merlin/models/torch/container.py +++ b/merlin/models/torch/container.py @@ -21,7 +21,6 @@ from torch import nn from torch._jit_internal import _copy_to_script_wrapper -from merlin.models.torch.link import Link, LinkType from merlin.models.torch.utils import torchscript_utils @@ -47,52 +46,37 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None): self._name: str = name - def append(self, module: nn.Module, link: Optional[Link] = None): + def append(self, module: nn.Module): """Appends a given module to the end of the list. Parameters ---------- module : nn.Module The PyTorch module to be appended. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - _module = self._check_link(module, link=link) - self.values.append(self.wrap_module(_module)) + self.values.append(self.wrap_module(module)) return self - def prepend(self, module: nn.Module, link: Optional[Link] = None): + def prepend(self, module: nn.Module): """Prepends a given module to the beginning of the list. Parameters ---------- module : nn.Module The PyTorch module to be prepended. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - return self.insert(0, module, link=link) + return self.insert(0, module) - def insert(self, index: int, module: nn.Module, link: Optional[Link] = None): + def insert(self, index: int, module: nn.Module): """Inserts a given module at the specified index. Parameters @@ -101,20 +85,12 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None): The index at which the module is to be inserted. 
module : nn.Module The PyTorch module to be inserted. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - _module = self._check_link(module, link=link) - self.values.insert(index, self.wrap_module(_module)) + self.values.insert(index, self.wrap_module(module)) return self @@ -193,15 +169,6 @@ def __repr__(self) -> str: def _get_name(self) -> str: return super()._get_name() if self._name is None else self._name - def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module: - if link: - linked_module: Link = Link.parse(link) - linked_module.setup_link(module) - - return linked_module - - return module - class BlockContainerDict(nn.ModuleDict): """A container class for PyTorch `nn.Module` that allows for manipulation and traversal @@ -232,9 +199,7 @@ def __init__( super().__init__(modules) self._name: str = name - def append_to( - self, name: str, module: nn.Module, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict": """Appends a module to a specified name. Parameters @@ -243,13 +208,6 @@ def append_to( The name of the branch. module : nn.Module The module to append. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -257,13 +215,11 @@ def append_to( The current object itself. """ - self._modules[name].append(module, link=link) + self._modules[name].append(module) return self - def prepend_to( - self, name: str, module: nn.Module, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict": """Prepends a module to a specified name. Parameters @@ -272,13 +228,6 @@ def prepend_to( The name of the branch. module : nn.Module The module to prepend. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -286,11 +235,9 @@ def prepend_to( The current object itself. """ - self._modules[name].prepend(module, link=link) + self._modules[name].prepend(module) - def append_for_each( - self, module: nn.Module, shared=False, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": """Appends a module to each branch. Parameters @@ -300,13 +247,6 @@ def append_for_each( shared : bool, default=False If True, the same module is shared across all elements. Otherwise a deep copy of the module is used in each element. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. 
- This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -316,13 +256,11 @@ def append_for_each( for branch in self.values(): _module = module if shared else deepcopy(module) - branch.append(_module, link=link) + branch.append(_module) return self - def prepend_for_each( - self, module: nn.Module, shared=False, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": """Prepends a module to each branch. Parameters @@ -332,13 +270,6 @@ def prepend_for_each( shared : bool, default=False If True, the same module is shared across all elements. Otherwise a deep copy of the module is used in each element. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -347,7 +278,7 @@ def prepend_for_each( """ for branch in self.values(): _module = module if shared else deepcopy(module) - branch.prepend(_module, link=link) + branch.prepend(_module) return self diff --git a/merlin/models/torch/link.py b/merlin/models/torch/link.py deleted file mode 100644 index f490aeec7f..0000000000 --- a/merlin/models/torch/link.py +++ /dev/null @@ -1,83 +0,0 @@ -import copy -from typing import Dict, Optional, Union - -import torch -from torch import nn - -from merlin.models.torch.registry import TorchRegistryMixin - -LinkType = Union[str, "Link"] - - -class Link(nn.Module, TorchRegistryMixin): - """Base class for different types of network links. - - This is typically used as part of a `Block` to connect different modules. - - Some examples of links are: - - `residual`: Adds the input to the output of the module. - - `shortcut`: Outputs a dictionary with the output of the module and the input. - - `shortcut-concat`: Concatenates the input and the output of the module. - - """ - - def __init__(self, output: Optional[nn.Module] = None): - super().__init__() - - if output is not None: - self.setup_link(output) - - def setup_link(self, output: nn.Module) -> "Link": - """ - Setup function for the link. - - Parameters - ---------- - output : nn.Module - The output module for the link. - - Returns - ------- - Link - The updated Link instance. - """ - - self.output = output - - return self - - def copy(self) -> "Link": - """ - Returns a copy of the link. - - Returns - ------- - Link - The copied link. 
- """ - return copy.deepcopy(self) - - -@Link.registry.register("residual") -class Residual(Link): - """Adds the input to the output of the module.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x + self.output(x) - - -@Link.registry.register("shortcut") -class Shortcut(Link): - """Outputs a dictionary with the output of the module and the input.""" - - def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - return {"output": self.output(x), "shortcut": x} - - -@Link.registry.register("shortcut-concat") -class ShortcutConcat(Link): - """Concatenates the input and the output of the module.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - intermediate_output = self.output(x) - return torch.cat((x, intermediate_output), dim=1) diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index a2d2d9b627..ae950045df 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -20,9 +20,15 @@ from torch import nn import merlin.models.torch as mm -from merlin.models.torch import link from merlin.models.torch.batch import Batch -from merlin.models.torch.block import Block, ParallelBlock, get_pre, set_pre +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ResidualBlock, + ShortcutBlock, + get_pre, + set_pre, +) from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.utils import module_utils from merlin.schema import Tags @@ -63,8 +69,8 @@ def test_insertion(self): assert torch.equal(outputs, inputs + 2) - block.append(PlusOne(), link="residual") - assert isinstance(block[-1], link.Residual) + # block.append(PlusOne(), link="residual") + # assert isinstance(block[-1], link.Residual) def test_copy(self): block = Block(PlusOne()) @@ -89,18 +95,18 @@ def test_repeat(self): with pytest.raises(ValueError, match="n must be greater than 0"): block.repeat(0) - def test_repeat_with_link(self): - block = Block(PlusOne()) + # def test_repeat_with_link(self): + # block = Block(PlusOne()) - repeated = block.repeat(2, link="residual") - assert isinstance(repeated, Block) - assert len(repeated) == 2 - assert isinstance(repeated[-1], link.Residual) + # repeated = block.repeat(2, link="residual") + # assert isinstance(repeated, Block) + # assert len(repeated) == 2 + # assert isinstance(repeated[-1], link.Residual) - inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) - outputs = module_utils.module_test(repeated, inputs) + # inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + # outputs = module_utils.module_test(repeated, inputs) - assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1) + # assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1) def test_from_registry(self): @Block.registry.register("my_block") @@ -282,3 +288,29 @@ def test_input_schema_pre(self): assert input_schema == mm.schema.input(pb2) assert mm.schema.output(pb2) == mm.schema.output(pb) + + +class TestResidualBlock: + def test_forward(self): + input_tensor = torch.randn(1, 3, 64, 64) + conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) + residual = ResidualBlock(conv) + + output_tensor = module_utils.module_test(residual, input_tensor) + expected_tensor = input_tensor + conv(input_tensor) + + assert torch.allclose(output_tensor, expected_tensor) + + +class TestShortcutBlock: + def test_forward(self): + input_tensor = torch.randn(1, 3, 64, 64) + conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) + shortcut = ShortcutBlock(conv) + + output_dict = module_utils.module_test(shortcut, input_tensor) + + 
assert "output" in output_dict + assert "shortcut" in output_dict + assert torch.allclose(output_dict["output"], conv(input_tensor)) + assert torch.allclose(output_dict["shortcut"], input_tensor) diff --git a/tests/unit/torch/test_container.py b/tests/unit/torch/test_container.py index 4479cc70de..d295760001 100644 --- a/tests/unit/torch/test_container.py +++ b/tests/unit/torch/test_container.py @@ -15,13 +15,11 @@ # import pytest -import torch import torch.nn as nn import merlin.models.torch as mm -from merlin.models.torch import link from merlin.models.torch.container import BlockContainer, BlockContainerDict -from merlin.models.torch.utils import module_utils, torchscript_utils +from merlin.models.torch.utils import torchscript_utils from merlin.schema import Tags @@ -39,16 +37,6 @@ def test_append(self): assert len(self.block_container) == 1 assert self.block_container != BlockContainer(name="test_container") - def test_append_link(self): - module = nn.Linear(20, 20) - self.block_container.append(module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def test_prepend(self): module1 = nn.Linear(20, 30) module2 = nn.Linear(30, 40) @@ -57,16 +45,6 @@ def test_prepend(self): assert len(self.block_container) == 2 assert isinstance(self.block_container[0], nn.Linear) - def test_prepend_link(self): - module = nn.Linear(20, 20) - self.block_container.prepend(module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def test_insert(self): module1 = nn.Linear(20, 30) module2 = nn.Linear(30, 40) @@ -75,16 +53,6 @@ def test_insert(self): assert len(self.block_container) == 2 assert isinstance(self.block_container[0], nn.Linear) - def test_insert_link(self): - module = nn.Linear(20, 20) - self.block_container.insert(0, module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def test_len(self): module = nn.Linear(20, 30) self.block_container.append(module) @@ -192,16 +160,10 @@ def test_append_to(self): self.container.append_to("test", self.module) assert "test" in self.container._modules - self.container.append_to("test", self.module, link="residual") - assert isinstance(self.container["test"][-1], link.Residual) - def test_prepend_to(self): self.container.prepend_to("test", self.module) assert "test" in self.container._modules - self.container.prepend_to("test", self.module, link="residual") - assert isinstance(self.container["test"][0], link.Residual) - def test_append_for_each(self): container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()}) @@ -216,10 +178,6 @@ def test_append_for_each(self): assert len(container["b"]) == 3 assert container["a"][-1] == container["b"][-1] - container.append_for_each(to_add, link="residual") - assert isinstance(container["a"][-1], link.Residual) - assert isinstance(container["b"][-1], link.Residual) - def test_prepend_for_each(self): container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()}) @@ -233,7 +191,3 @@ def test_prepend_for_each(self): assert len(container["a"]) == 3 assert len(container["b"]) == 3 assert container["a"][0] == 
container["b"][0] - - container.prepend_for_each(to_add, link="residual") - assert isinstance(container["a"][0], link.Residual) - assert isinstance(container["b"][0], link.Residual) diff --git a/tests/unit/torch/test_link.py b/tests/unit/torch/test_link.py deleted file mode 100644 index 5514154184..0000000000 --- a/tests/unit/torch/test_link.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import torch -from torch import nn - -from merlin.models.torch.link import Link, Residual, Shortcut, ShortcutConcat -from merlin.models.torch.utils import module_utils - - -class TestResidual: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) - residual = Residual(conv) - - output_tensor = module_utils.module_test(residual, input_tensor) - expected_tensor = input_tensor + conv(input_tensor) - - assert torch.allclose(output_tensor, expected_tensor) - - def test_from_registry(self): - residual = Link.parse("residual") - - assert isinstance(residual, Residual) - - -class TestShortcut: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) - shortcut = Shortcut(conv) - - output_dict = module_utils.module_test(shortcut, input_tensor) - - assert "output" in output_dict - assert "shortcut" in output_dict - assert torch.allclose(output_dict["output"], conv(input_tensor)) - assert torch.allclose(output_dict["shortcut"], input_tensor) - - def test_from_registry(self): - shortcut = Link.parse("shortcut") - - assert isinstance(shortcut, Shortcut) - - -class TestShortcutConcat: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d( - 3, 10, kernel_size=3, padding=1 - ) # Output channels are different for concatenation - shortcut_concat = ShortcutConcat(conv) - - output_tensor = module_utils.module_test(shortcut_concat, input_tensor) - expected_tensor = torch.cat((input_tensor, conv(input_tensor)), dim=1) - - assert torch.allclose(output_tensor, expected_tensor) - - def test_from_registry(self): - shortcut_concat = Link.parse("shortcut-concat") - - assert isinstance(shortcut_concat, ShortcutConcat) From ddcf3bc7c4b38c64616318c48dc25fded4cd8e0e Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Sat, 1 Jul 2023 17:32:00 +0200 Subject: [PATCH 2/6] Removing Link in favour of some new Blocks like: ResidualBlock & ShortcutBlock --- merlin/models/torch/block.py | 49 +++++++++++++++++++----------- tests/unit/torch/test_block.py | 8 +++++ tests/unit/torch/test_container.py | 2 ++ 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 77b9844464..33e9cfb155 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -459,15 +459,31 @@ def __init__( self.output_name = output_name def forward( - self, inputs: torch.Tensor, batch: Optional[Batch] = None + self, inputs: Union[torch.Tensor, 
Dict[str, torch.Tensor]], batch: Optional[Batch] = None ) -> Dict[str, torch.Tensor]: - shortcut, output = inputs, inputs + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]): + if self.shortcut_name not in inputs: + raise ValueError(f"Shortcut name {self.shortcut_name} not found in inputs {inputs}") + shortcut = inputs[self.shortcut_name] + else: + shortcut = inputs + + output = inputs for module in self.values: - if getattr(module, "accepts_dict", False): - module_output = module(self._create_dict(shortcut, output), batch=batch) + if getattr(module, "accepts_dict", False) or hasattr(module, "values"): + if torch.jit.isinstance(output, Dict[str, torch.Tensor]): + module_output = module(output, batch=batch) + else: + to_pass: Dict[str, torch.Tensor] = { + self.shortcut_name: shortcut, + self.output_name: torch.jit.annotate(torch.Tensor, output), + } + + module_output = module(to_pass, batch=batch) + if torch.jit.isinstance(module_output, torch.Tensor): output = module_output - elif isinstance(module_output, Dict[str, torch.Tensor]): + elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]): output = module_output[self.output_name] else: raise ValueError( @@ -475,22 +491,19 @@ def forward( f"with key {self.output_name}", ) else: + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]) and torch.jit.isinstance( + output, Dict[str, torch.Tensor] + ): + output = output[self.output_name] output = module(output, batch=batch) - return self._create_dict(shortcut, output) - - def _create_dict(self, shortcut: torch.Tensor, output: torch.Tensor) -> Dict[str, torch.Tensor]: - return {self.shortcut_name: shortcut, self.output_name: output} - - -class CrossBlock(Block): - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - x0 = inputs - current = inputs - for module in self.values: - current = x0 * module(current) + current + to_return = {self.shortcut_name: shortcut} + if torch.jit.isinstance(output, Dict[str, torch.Tensor]): + to_return.update(output) + else: + to_return[self.output_name] = output - return current + return to_return def get_pre(module: nn.Module) -> BlockContainer: diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index ae950045df..b6fb667d9f 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -314,3 +314,11 @@ def test_forward(self): assert "shortcut" in output_dict assert torch.allclose(output_dict["output"], conv(input_tensor)) assert torch.allclose(output_dict["shortcut"], input_tensor) + + def test_nesting(self): + inputs = torch.rand(5, 5) + shortcut = ShortcutBlock(ShortcutBlock(PlusOne())) + output = module_utils.module_test(shortcut, inputs) + + assert torch.equal(output["shortcut"], inputs) + assert torch.equal(output["output"], inputs + 1) diff --git a/tests/unit/torch/test_container.py b/tests/unit/torch/test_container.py index d295760001..4c8b14be9f 100644 --- a/tests/unit/torch/test_container.py +++ b/tests/unit/torch/test_container.py @@ -30,6 +30,7 @@ def setup_method(self): def test_init(self): assert isinstance(self.block_container, BlockContainer) assert self.block_container._name == "test_container" + assert self.block_container != "" def test_append(self): module = nn.Linear(20, 30) @@ -147,6 +148,7 @@ def test_init(self): assert isinstance(self.container, BlockContainerDict) assert self.container._get_name() == "test" assert isinstance(self.container.unwrap()["test"], BlockContainer) + assert self.container != "" def test_empty(self): container = BlockContainerDict() From 
ab0f63504fffd880b3a2f6c7ae1428baf767368d Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Sat, 1 Jul 2023 18:08:58 +0200 Subject: [PATCH 3/6] Add conversion test --- tests/unit/torch/test_block.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index b6fb667d9f..1e46e37005 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -322,3 +322,14 @@ def test_nesting(self): assert torch.equal(output["shortcut"], inputs) assert torch.equal(output["output"], inputs + 1) + + def test_convert(self): + block = Block(PlusOne()) + shortcut = ShortcutBlock(*block) + + assert isinstance(shortcut[0], PlusOne) + inputs = torch.rand(5, 5) + assert torch.equal( + module_utils.module_test(shortcut, inputs)["output"], + module_utils.module_test(ShortcutBlock(PlusOne()), inputs)["output"], + ) From 7c71ac84431846e5992dc8ac5f105970f296bb65 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 10:00:53 +0200 Subject: [PATCH 4/6] Some bug fixes + 100% test-coverage for block.py --- merlin/models/torch/block.py | 31 +++++++++----- tests/unit/torch/test_block.py | 78 ++++++++++++++++++++++++++-------- 2 files changed, 82 insertions(+), 27 deletions(-) diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 33e9cfb155..049a3224cf 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -17,7 +17,7 @@ import inspect import textwrap from copy import deepcopy -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Final, Optional, Tuple, TypeVar, Union import torch from torch import nn @@ -26,7 +26,7 @@ from merlin.models.torch.batch import Batch from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.registry import registry -from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf +from merlin.models.torch.utils.traversal_utils import TraversableMixin from merlin.models.utils.registry import RegistryMixin from merlin.schema import Schema @@ -43,6 +43,7 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin): """ registry = registry + is_block: Final[bool] = True def __init__(self, *module: nn.Module, name: Optional[str] = None): super().__init__(*module, name=name) @@ -350,10 +351,7 @@ def leaf(self) -> nn.Module: raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches") first = list(self.branches.values())[0] - if hasattr(first, "leaf"): - return first.leaf() - - return leaf(first) + return first.leaf() def __getitem__(self, idx: Union[slice, int, str]): if isinstance(idx, str) and idx in self.branches: @@ -451,26 +449,30 @@ def __init__( self, *module: nn.Module, name: Optional[str] = None, + propagate_shortcut: bool = False, shortcut_name: str = "shortcut", output_name: str = "output", ): super().__init__(*module, name=name) self.shortcut_name = shortcut_name self.output_name = output_name + self.propagate_shortcut = propagate_shortcut def forward( self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None ) -> Dict[str, torch.Tensor]: if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]): if self.shortcut_name not in inputs: - raise ValueError(f"Shortcut name {self.shortcut_name} not found in inputs {inputs}") + raise RuntimeError( + f"Shortcut name {self.shortcut_name} not found in inputs {inputs}" + ) shortcut = inputs[self.shortcut_name] else: shortcut = inputs output = inputs for module in 
self.values: - if getattr(module, "accepts_dict", False) or hasattr(module, "values"): + if self.propagate_shortcut: if torch.jit.isinstance(output, Dict[str, torch.Tensor]): module_output = module(output, batch=batch) else: @@ -486,7 +488,7 @@ def forward( elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]): output = module_output[self.output_name] else: - raise ValueError( + raise RuntimeError( f"Module {module} must return a tensor or a dict ", f"with key {self.output_name}", ) @@ -495,7 +497,16 @@ def forward( output, Dict[str, torch.Tensor] ): output = output[self.output_name] - output = module(output, batch=batch) + _output = module(output, batch=batch) + if torch.jit.isinstance(_output, torch.Tensor) or torch.jit.isinstance( + _output, Dict[str, torch.Tensor] + ): + output = _output + else: + raise RuntimeError( + f"Module {module} must return a tensor or a dict ", + f"with key {self.output_name}", + ) to_return = {self.shortcut_name: shortcut} if torch.jit.isinstance(output, Dict[str, torch.Tensor]): diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index 1e46e37005..ea36aaa412 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -69,9 +69,6 @@ def test_insertion(self): assert torch.equal(outputs, inputs + 2) - # block.append(PlusOne(), link="residual") - # assert isinstance(block[-1], link.Residual) - def test_copy(self): block = Block(PlusOne()) @@ -95,19 +92,6 @@ def test_repeat(self): with pytest.raises(ValueError, match="n must be greater than 0"): block.repeat(0) - # def test_repeat_with_link(self): - # block = Block(PlusOne()) - - # repeated = block.repeat(2, link="residual") - # assert isinstance(repeated, Block) - # assert len(repeated) == 2 - # assert isinstance(repeated[-1], link.Residual) - - # inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) - # outputs = module_utils.module_test(repeated, inputs) - - # assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1) - def test_from_registry(self): @Block.registry.register("my_block") class MyBlock(Block): @@ -289,6 +273,22 @@ def test_input_schema_pre(self): assert input_schema == mm.schema.input(pb2) assert mm.schema.output(pb2) == mm.schema.output(pb) + def test_leaf(self): + block = ParallelBlock({"a": PlusOne()}) + + assert isinstance(block.leaf(), PlusOne) + + block.branches["b"] = PlusOne() + with pytest.raises(ValueError): + block.leaf() + + block.prepend(PlusOne()) + with pytest.raises(ValueError): + block.leaf() + + block = ParallelBlock({"a": nn.Sequential(PlusOne())}) + assert isinstance(block.leaf(), PlusOne) + class TestResidualBlock: def test_forward(self): @@ -326,10 +326,54 @@ def test_nesting(self): def test_convert(self): block = Block(PlusOne()) shortcut = ShortcutBlock(*block) + nested = ShortcutBlock(ShortcutBlock(shortcut), propagate_shortcut=True) assert isinstance(shortcut[0], PlusOne) inputs = torch.rand(5, 5) assert torch.equal( module_utils.module_test(shortcut, inputs)["output"], - module_utils.module_test(ShortcutBlock(PlusOne()), inputs)["output"], + module_utils.module_test(nested, inputs)["output"], ) + + def test_with_parallel(self): + parallel = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) + shortcut = ShortcutBlock(parallel) + + inputs = torch.rand(5, 5) + + outputs = shortcut(inputs) + + outputs = module_utils.module_test(shortcut, inputs) + assert torch.equal(outputs["shortcut"], inputs) + assert torch.equal(outputs["a"], inputs + 1) + assert torch.equal(outputs["b"], inputs + 1) + + def 
test_propagate_shortcut(self): + class PlusOneShortcut(nn.Module): + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + return inputs["shortcut"] + 1 + + shortcut = ShortcutBlock(PlusOneShortcut(), propagate_shortcut=True) + shortcut = ShortcutBlock(shortcut, propagate_shortcut=True) + inputs = torch.rand(5, 5) + outputs = module_utils.module_test(shortcut, inputs) + + assert torch.equal(outputs["output"], inputs + 1) + + with pytest.raises(RuntimeError): + shortcut({"a": inputs}) + + def test_exception(self): + with_tuple = Block(PlusOneTuple()) + shortcut = ShortcutBlock(with_tuple) + + with pytest.raises(RuntimeError): + module_utils.module_test(shortcut, torch.rand(5, 5)) + + class PlusOneShortcutTuple(nn.Module): + def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + return inputs["shortcut"] + 1, inputs["shortcut"] + + shortcut_propagate = ShortcutBlock(PlusOneShortcutTuple(), propagate_shortcut=True) + with pytest.raises(RuntimeError): + module_utils.module_test(shortcut_propagate, torch.rand(5, 5)) From edf5725d372626a71f2549edf468182e413fabd6 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 10:12:24 +0200 Subject: [PATCH 5/6] Improve doc-strings --- merlin/models/torch/block.py | 51 ++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 049a3224cf..b450f17eda 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -445,6 +445,35 @@ def forward(self, inputs: torch.Tensor, batch: Optional[Batch] = None): class ShortcutBlock(Block): + """ + A block with a 'shortcut' or a 'skip connection'. + + The shortcut tensor can be propagated through the layers of the module or not, + depending on the value of `propagate_shortcut` argument: + If `propagate_shortcut` is True, the shortcut tensor is passed through + each layer of the module. + If `propagate_shortcut` is False, the shortcut tensor is only used as part of + the final output dictionary. + + Example usage:: + >>> shortcut = mm.ShortcutBlock(nn.Identity()) + >>> shortcut(torch.ones(1, 1)) + {'shortcut': tensor([[1.]]), 'output': tensor([[1.]])} + + Parameters + ---------- + *module : nn.Module + Variable length argument list of PyTorch modules to be contained in the block. + name : str, optional + The name of the module, by default None. + propagate_shortcut : bool, optional + If True, propagates the shortcut tensor through the layers of this block, by default False. + shortcut_name : str, optional + The name to use for the shortcut tensor, by default "shortcut". + output_name : str, optional + The name to use for the output tensor, by default "output". + """ + def __init__( self, *module: nn.Module, @@ -461,6 +490,28 @@ def __init__( def forward( self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None ) -> Dict[str, torch.Tensor]: + """ + Defines the forward propagation of the module. + + Parameters + ---------- + inputs : Union[torch.Tensor, Dict[str, torch.Tensor]] + The input tensor or a dictionary of tensors. + batch : Batch, optional + A batch of inputs, by default None. + + Returns + ------- + Dict[str, torch.Tensor] + The output tensor as a dictionary. + + Raises + ------ + RuntimeError + If the shortcut name is not found in the input dictionary, or + if the module does not return a tensor or a dictionary with a key 'output_name'. 
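+
+        Examples
+        --------
+        With a single identity module, both entries of the returned
+        dictionary equal the input:
+
+        >>> shortcut = mm.ShortcutBlock(nn.Identity())
+        >>> shortcut(torch.ones(1, 1))
+        {'shortcut': tensor([[1.]]), 'output': tensor([[1.]])}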
+        """
+
         if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
             if self.shortcut_name not in inputs:
                 raise RuntimeError(

From e3114ff90567ba182750242a4a2121bb72f4219e Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Mon, 3 Jul 2023 10:18:45 +0200
Subject: [PATCH 6/6] Remove unused is_block again

---
 merlin/models/torch/block.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py
index b450f17eda..42dede5b9b 100644
--- a/merlin/models/torch/block.py
+++ b/merlin/models/torch/block.py
@@ -17,7 +17,7 @@
 import inspect
 import textwrap
 from copy import deepcopy
-from typing import Dict, Final, Optional, Tuple, TypeVar, Union
+from typing import Dict, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import nn
@@ -43,7 +43,6 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin):
     """
 
     registry = registry
-    is_block: Final[bool] = True
 
     def __init__(self, *module: nn.Module, name: Optional[str] = None):
         super().__init__(*module, name=name)
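
For readers migrating existing code, the sketch below (not part of the patch
series itself) shows how the removed link types map onto the new blocks. It
assumes only the exports added to merlin/models/torch/__init__.py in PATCH 1/6;
the variable names are illustrative.

    import torch
    from torch import nn

    import merlin.models.torch as mm

    inputs = torch.ones(2, 4)
    linear = nn.Linear(4, 4)

    # link="residual" -> ResidualBlock: the input is added back onto the
    # output of each contained module.
    residual = mm.ResidualBlock(linear)
    out = residual(inputs)  # equivalent to inputs + linear(inputs)

    # link="shortcut" -> ShortcutBlock: returns a dictionary holding the
    # untouched input under "shortcut" and the module output under "output".
    shortcut = mm.ShortcutBlock(linear)
    out_dict = shortcut(inputs)
    assert torch.equal(out_dict["shortcut"], inputs)

    # link="shortcut-concat" has no one-to-one replacement: DLRMBlock now uses
    # InteractionBlock (merlin/models/torch/blocks/dlrm.py), which concatenates
    # continuous inputs with the interaction output itself.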