Skip to content

Commit

Permalink
Adding some extra methods to make working with blocks easier (#1154)
Browse files Browse the repository at this point in the history
* First commit

* Adding leaf to traversal_utils

* Add some doc-strings

* Adding extra test for leaf

* Some fixes
  • Loading branch information
marcromeyn authored Jun 26, 2023
1 parent aa501f0 commit 7b22df2
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 107 deletions.
29 changes: 26 additions & 3 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import textwrap
from copy import deepcopy
from typing import Dict, Optional, TypeVar, Union
from typing import Dict, Optional, Tuple, TypeVar, Union

import torch
from torch import nn
Expand All @@ -27,11 +27,12 @@
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf
from merlin.models.utils.registry import RegistryMixin
from merlin.schema import Schema


class Block(BlockContainer, RegistryMixin):
class Block(BlockContainer, RegistryMixin, TraversableMixin):
"""A base-class that calls its modules sequentially.
Parameters
Expand Down Expand Up @@ -113,6 +114,15 @@ def copy(self) -> "Block":
"""
return deepcopy(self)

@torch.jit.ignore
def select(self, selection: schema.Selection) -> "Block":
    """Select the part of this block that matches ``selection``.

    Delegates to the module-level ``_select_block`` helper, passing this
    block and ``selection`` through unchanged. Excluded from TorchScript
    compilation via ``torch.jit.ignore``.

    Parameters
    ----------
    selection : schema.Selection
        The selection to apply to this block.

    Returns
    -------
    Block
        Whatever ``_select_block(self, selection)`` returns.
    """
    return _select_block(self, selection)

@torch.jit.ignore
def extract(self, selection: schema.Selection) -> Tuple[nn.Module, nn.Module]:
    """Split this block into an extracted version and the selected part.

    The selection is computed first via ``self.select``; it is then handed
    to ``_extract_block`` as the route. Excluded from TorchScript
    compilation via ``torch.jit.ignore``.

    Parameters
    ----------
    selection : schema.Selection
        The selection to extract from this block.

    Returns
    -------
    Tuple[nn.Module, nn.Module]
        ``(result of _extract_block, selected)`` in that order.
    """
    selected = self.select(selection)
    remainder = _extract_block(self, selection, selected)
    return remainder, selected


class ParallelBlock(Block):
"""A base-class that calls its modules in parallel.
Expand Down Expand Up @@ -338,6 +348,19 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock":

return output

def leaf(self) -> nn.Module:
    """Return the deepest module of this block's single branch.

    Returns
    -------
    nn.Module
        The result of the branch's own ``leaf()`` method when it has one,
        otherwise the result of the module-level ``leaf`` helper applied
        to the branch.

    Raises
    ------
    ValueError
        If ``self.pre`` is truthy, or if the number of branches is not
        exactly one.
    """
    if self.pre:
        raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage")

    if len(self.branches) != 1:
        raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches")

    # Exactly one branch at this point; take it without building a list.
    only_branch = next(iter(self.branches.values()))

    # Prefer the branch's own ``leaf`` implementation over the generic helper.
    return only_branch.leaf() if hasattr(only_branch, "leaf") else leaf(only_branch)

def __getitem__(self, idx: Union[slice, int, str]):
if isinstance(idx, str) and idx in self.branches:
return self.branches[idx]
Expand Down Expand Up @@ -541,7 +564,7 @@ def _extract_parallel(main, selection, route, name=None):


@schema.extract.register(BlockContainer)
def _(main, selection, route, name=None):
def _extract_block(main, selection, route, name=None):
if isinstance(main, ParallelBlock):
return _extract_parallel(main, selection, route=route, name=name)

Expand Down
2 changes: 1 addition & 1 deletion merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def configure_optimizers(self):

def model_outputs(self) -> List[ModelOutput]:
"""Finds all instances of `ModelOutput` in the model."""
return module_utils.find_all_instances(self, ModelOutput)
return self.find(ModelOutput)

def first(self) -> nn.Module:
"""Returns the first block in the model."""
Expand Down
2 changes: 2 additions & 0 deletions merlin/models/torch/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def add_route(
branch = module
else:
if self.prepend_routing_module:
if not routing_module:
return self
branch = routing_module
else:
raise ValueError("Must provide a module.")
Expand Down
6 changes: 6 additions & 0 deletions merlin/models/torch/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,15 @@ def __call__(self, to_select: ToSelectT, selection: Selection) -> ToSelectT:

return output

def dispatched(self, to_select: ToSelectT, selection: Selection) -> ToSelectT:
    """Run the parent class's ``__call__`` directly.

    This forwards to ``super().__call__`` rather than to this class's own
    ``__call__``.

    Parameters
    ----------
    to_select : ToSelectT
        The object to select from.
    selection : Selection
        The selection to apply.

    Returns
    -------
    ToSelectT
        Whatever the parent class's ``__call__`` returns.
    """
    return super().__call__(to_select, selection)


class _ExtractDispatch(LazyDispatcher):
def __call__(self, module: nn.Module, selection: Selection) -> Tuple[nn.Module, nn.Module]:
if hasattr(module, "extract"):
return module.extract(selection)

extraction = select(module, selection)
module_with_extraction = self.extract(module, selection, extraction)

Expand Down
38 changes: 1 addition & 37 deletions merlin/models/torch/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import inspect
from typing import Dict, List, Tuple, Type, TypeVar, Union
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -196,42 +196,6 @@ def _all_close_dict(left, right):
raise ValueError("The outputs of the original and scripted modules are not the same")


ToSearch = TypeVar("ToSearch", bound=Type[nn.Module])


def find_all_instances(module: nn.Module, to_search: ToSearch) -> List[ToSearch]:
    """
    Collect the modules inside ``module`` that match a given module type.

    The search is depth-first over ``module.children()``. As soon as a module
    matches, it is returned and its own children are not searched.

    Parameters
    ----------
    module: nn.Module
        The PyTorch module to search through.
    to_search: ToSearch
        The module type to look for. If an ``nn.Module`` instance is passed,
        its type is used instead.

    Returns
    -------
    List[ToSearch]
        All matches found in ``module``, in traversal order. This is
        ``[module]`` when ``module`` itself matches.
    """
    # An instance stands in for its own type.
    target = type(to_search) if isinstance(to_search, nn.Module) else to_search

    if isinstance(module, target) or module == target:
        return [module]

    return [
        match
        for child in module.children()
        for match in find_all_instances(child, target)
    ]


def initialize(module, data: Union[Dataset, Loader, Batch], dtype=torch.float32):
"""
This function is useful for initializing a PyTorch module with specific
Expand Down
174 changes: 174 additions & 0 deletions merlin/models/torch/utils/traversal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, List, Tuple, Type, TypeVar, Union

import torch
from torch import nn
from typing_extensions import Self

from merlin.models.torch import schema

ModuleType = TypeVar("ModuleType", bound=nn.Module)
PredicateFn = Callable[[ModuleType], bool]


def find(module: nn.Module, to_search: Union[PredicateFn, Type[ModuleType]]) -> List[ModuleType]:
    """
    Traverse a PyTorch Module and find submodules matching a given condition.

    The traversal uses ``nn.Module.apply``, so ``module`` itself is tested as
    well as every submodule.

    Finding a module-type::

        >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
        >>> find(model, nn.Linear)  # find all Linear layers
        [Linear(in_features=10, out_features=20, bias=True),
         Linear(in_features=20, out_features=10, bias=True)]

    Finding a module-type with a condition::

        >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
        >>> find(model, lambda x: isinstance(x, nn.Linear) and x.out_features == 10)
        [Linear(in_features=20, out_features=10, bias=True)]

    Parameters
    ----------
    module : nn.Module
        The PyTorch module to traverse.
    to_search : Union[Callable[[ModuleType], bool], Type[ModuleType]]
        The condition to match. Can be either a subclass of nn.Module, in which case
        submodules of that type are searched, or a Callable, which is applied to each
        submodule and should return True for matches.

    Returns
    -------
    List[ModuleType]
        List of matching submodules, in the order ``nn.Module.apply`` visits them.

    Raises
    ------
    ValueError
        If `to_search` is neither a subclass of nn.Module nor a Callable.
    """
    if isinstance(to_search, type) and issubclass(to_search, nn.Module):
        module_type = to_search

        def predicate(candidate: nn.Module) -> bool:
            return isinstance(candidate, module_type)

    elif callable(to_search):
        predicate = to_search
    else:
        raise ValueError("to_search must be either a subclass of nn.Module or a callable.")

    matches: List[ModuleType] = []

    def collect(candidate: nn.Module) -> None:
        # The list is only mutated, never rebound, so no ``nonlocal`` is needed.
        if predicate(candidate):
            matches.append(candidate)

    # ``apply`` is kept (rather than ``modules()``) to preserve visiting order.
    module.apply(collect)

    return matches


def leaf(module) -> nn.Module:
    """
    Recursively fetch the deepest child module.

    Only container modules (``nn.Sequential``, ``nn.ModuleList``,
    ``nn.ModuleDict``, ``BlockContainerDict``, ``BlockContainer``) are
    descended into. Any other module is returned as-is, even when it has
    children of its own.

    Example usage::

        >>> model = nn.Sequential(nn.Linear(10, 20))
        >>> print(leaf(model))
        Linear(in_features=10, out_features=20, bias=True)

    Parameters
    ----------
    module : torch.nn.Module
        PyTorch module to fetch the deepest child from.

    Returns
    -------
    torch.nn.Module
        The deepest child module.

    Raises
    ------
    ValueError
        If a container module on the path has more than one child.
    """
    # Imported here rather than at module level, as in the original.
    from merlin.models.torch.container import BlockContainer, BlockContainerDict

    container_types = (
        nn.Sequential,
        nn.ModuleList,
        nn.ModuleDict,
        BlockContainerDict,
        BlockContainer,
    )

    sub_modules = list(module.children())

    # Nothing to descend into: childless modules and non-containers are leaves.
    if not sub_modules or not isinstance(module, container_types):
        return module

    if len(sub_modules) > 1:
        raise ValueError(
            f"Module {module} has multiple children, cannot determine the deepest child."
        )

    only_child = sub_modules[0]

    # Children exposing ``unwrap`` are replaced by what it returns.
    if hasattr(only_child, "unwrap"):
        only_child = only_child.unwrap()

    # Let the child decide how to resolve its own leaf when it can.
    if hasattr(only_child, "leaf"):
        return only_child.leaf()

    return leaf(only_child)


class TraversableMixin:
    """Mixin adding ``find``, ``leaf``, ``select`` and ``extract`` as methods.

    Each method forwards to a module-level function or to the ``schema``
    dispatchers, passing ``self`` as the module to operate on.
    """

    def find(self, to_search: Union[PredicateFn, Type[ModuleType]]) -> List[ModuleType]:
        """
        Traverse the current module and find submodules matching a given condition.

        Parameters
        ----------
        to_search : Union[Callable[[ModuleType], bool], Type[ModuleType]]
            The condition to match. Can be either a subclass of nn.Module, in which case
            submodules of that type are searched, or a Callable, which is applied to each
            submodule and should return True for matches.

        Returns
        -------
        List[ModuleType]
            List of matching submodules.
        """
        return find(self, to_search)

    def leaf(self) -> nn.Module:
        """
        Recursively fetch the deepest child module.

        Returns
        -------
        torch.nn.Module
            The deepest child module.
        """
        return leaf(self)

    @torch.jit.ignore
    def select(self, selection: schema.Selection) -> Self:
        """
        Select the part of the current module matching ``selection``.

        Forwards to ``schema.select.dispatched``.

        Parameters
        ----------
        selection : schema.Selection
            The selection to apply.

        Returns
        -------
        Self
            Whatever ``schema.select.dispatched(self, selection)`` returns.
        """
        return schema.select.dispatched(self, selection)

    @torch.jit.ignore
    def extract(self, selection: schema.Selection) -> Tuple[nn.Module, nn.Module]:
        """
        Extract ``selection`` from the current module.

        The selection is computed first with ``schema.select`` and then
        passed to ``schema.extract.extract``.

        Parameters
        ----------
        selection : schema.Selection
            The selection to extract.

        Returns
        -------
        Tuple[nn.Module, nn.Module]
            ``(result of schema.extract.extract, selected module)`` in that order.
        """
        selected = schema.select(self, selection)
        remainder = schema.extract.extract(self, selection, selected)

        return remainder, selected
4 changes: 2 additions & 2 deletions tests/unit/torch/inputs/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_extract_route_two_tower(self):
assert set(mm.schema.input(towers).column_names) == input_cols
assert mm.schema.output(towers).column_names == ["user", "item"]

categorical = mm.schema.select(towers, Tags.CATEGORICAL)
categorical = towers.select(Tags.CATEGORICAL)
outputs = module_utils.module_test(towers, self.batch)

assert mm.schema.extract(towers, Tags.CATEGORICAL)[1] == categorical
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_extract_route_nesting(self):
outputs = module_utils.module_test(input_block, self.batch)
assert outputs.shape == (10, 107)

no_user_id, user_id_route = mm.schema.extract(input_block, ColumnSchema("user_id"))
no_user_id, user_id_route = input_block.extract(ColumnSchema("user_id"))

assert no_user_id

Expand Down
Loading

0 comments on commit 7b22df2

Please sign in to comment.