Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing BatchBlock #1192

Merged
merged 12 commits into from
Jul 11, 2023
6 changes: 6 additions & 0 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import (
BatchBlock,
Block,
ParallelBlock,
ResidualBlock,
Expand All @@ -42,6 +43,7 @@
)
from merlin.models.torch.outputs.regression import RegressionOutput
from merlin.models.torch.outputs.tabular import TabularOutputBlock
from merlin.models.torch.predict import DaskEncoder, DaskPredictor, EncoderBlock
from merlin.models.torch.router import RouterBlock
from merlin.models.torch.transforms.agg import Concat, Stack
from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding
Expand All @@ -55,6 +57,7 @@
"Batch",
"BinaryOutput",
"Block",
"BatchBlock",
"DLRMBlock",
"MLPBlock",
"Model",
Expand Down Expand Up @@ -95,4 +98,7 @@
"CGCBlock",
"TabularPadding",
"BroadcastToSequence",
"EncoderBlock",
"DaskEncoder",
"DaskPredictor",
]
126 changes: 125 additions & 1 deletion merlin/models/torch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(
def __contains__(self, name: str) -> bool:
return name in self.lengths

def __bool__(self) -> bool:
return bool(self.lengths)

def length(self, name: str = "default") -> torch.Tensor:
"""Retrieves a length tensor from a sequence by name.

Expand Down Expand Up @@ -117,6 +120,16 @@ def device(self) -> torch.device:

raise ValueError("Sequence is empty")

def flatten_to_dict(self) -> Dict[str, torch.Tensor]:
outputs: Dict[str, torch.Tensor] = {}
for key, value in self.lengths.items():
outputs["lengths." + key] = value

for key, value in self.masks.items():
outputs["masks." + key] = value

return outputs


@torch.jit.script
class Batch:
Expand Down Expand Up @@ -164,7 +177,12 @@ def __init__(
else:
raise ValueError("Targets must be a tensor or a dictionary of tensors")
self.targets: Dict[str, torch.Tensor] = _targets
self.sequences: Optional[Sequence] = sequences
if torch.jit.isinstance(sequences, Sequence):
_sequences = Sequence(sequences.lengths, sequences.masks)
else:
_masks: Dict[str, torch.Tensor] = {}
_sequences = Sequence(_masks)
self.sequences: Sequence = _sequences

@staticmethod
@torch.jit.ignore
Expand Down Expand Up @@ -277,6 +295,112 @@ def target(self, name: str = "default") -> torch.Tensor:

raise ValueError("Batch has multiple target, please specify a target name")

def inputs(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if len(self.features) == 1 and "default" in self.features:
return self.features["default"]

return self.features

def flatten_as_dict(self, inputs: Optional["Batch"]) -> Dict[str, torch.Tensor]:
"""
Flatten features, targets, and sequences into a dictionary of tensors.

Each key should be prefixed with "features.", "targets.", "masks." or "lengths."

If inputs is provided, it includes all keys that are present in both self and inputs,
with the value from self being used when a key is present in both.

Parameters
----------
inputs : Batch, optional
Another Batch object to include in the flattening process. The keys from the input
batch are also added with a prefix of "inputs.", by default None

Returns
-------
Dict[str, torch.Tensor]
A dictionary containing all the flattened features, targets, and sequences.
"""
flat_dict: Dict[str, torch.Tensor] = self._flatten()
dummy_tensor = torch.tensor(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we need the dummy_tensor variable?

Copy link
Contributor Author

@marcromeyn marcromeyn Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We never use it, but we need to have it in order to keep the type of the Dict[str, torch.Tensor]. We could store the original value but that might take more memory, so that's why I added the dummy. We only care about the keys of the original inputs.


if torch.jit.isinstance(inputs, Batch) and inputs is not self:
_input_dict: Dict[str, torch.Tensor] = inputs._flatten()
for key in _input_dict:
flat_dict["inputs." + key] = dummy_tensor

return flat_dict

def _flatten(self) -> Dict[str, torch.Tensor]:
"""
Helper function to flatten features, targets, and sequences of the current batch.

Returns
-------
Dict[str, torch.Tensor]
A dictionary containing all the flattened features, targets, and sequences.
"""
flat_dict = {}

for key, value in self.features.items():
flat_dict["features." + key] = value

for key, value in self.targets.items():
flat_dict["targets." + key] = value

_sequence_dict = self.sequences.flatten_to_dict()
if _sequence_dict:
flat_dict.update(_sequence_dict)

return flat_dict

@staticmethod
def from_partial_dict(input: Dict[str, torch.Tensor], batch: "Batch") -> "Batch":
"""
The input param comes from flatten_as_dict.

It could be that certain keys are missing from the input dict, in which case
we should use the values from the batch object.
"""
features = {}
targets = {}
lengths = {}
masks = {}

for key, value in input.items():
key_split = key.split(".")
if key_split[0] == "features":
features[key_split[1]] = value
elif key_split[0] == "targets":
targets[key_split[1]] = value
elif key_split[0] == "lengths":
lengths[key_split[1]] = value
elif key_split[0] == "masks":
masks[key_split[1]] = value

# If a key is missing in the input dict and was in the inputs of flatten_as_dict,
# use the value from the batch object
for key in batch.features:
if f"inputs.features.{key}" not in input:
features[key] = batch.features[key]
for key in batch.targets:
if f"inputs.targets.{key}" not in input:
targets[key] = batch.targets[key]
if batch.sequences is not None:
for key in batch.sequences.lengths:
if f"inputs.lengths.{key}" not in input:
lengths[key] = batch.sequences.lengths[key]
for key in batch.sequences.masks:
if f"inputs.masks.{key}" not in input:
masks[key] = batch.sequences.masks[key]

if lengths or masks:
sequences = Sequence(lengths, masks)
else:
sequences = None

return Batch(features, targets, sequences)

def __bool__(self) -> bool:
return bool(self.features)

Expand Down
107 changes: 100 additions & 7 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from torch import nn

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.traversal_utils import TraversableMixin
from merlin.models.utils.registry import RegistryMixin
from merlin.schema import Schema

TensorOrDict = Union[torch.Tensor, Dict[str, torch.Tensor]]


@runtime_checkable
class HasKeys(Protocol):
Expand All @@ -53,9 +55,7 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin):
def __init__(self, *module: nn.Module, name: Optional[str] = None):
super().__init__(*module, name=name)

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
def forward(self, inputs: TensorOrDict, batch: Optional[Batch] = None):
"""
Forward pass through the block. Applies each contained module sequentially on the input.

Expand Down Expand Up @@ -167,9 +167,7 @@ def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]]):
self.branches = branches
self.post = post

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
def forward(self, inputs: TensorOrDict, batch: Optional[Batch] = None):
"""Forward pass through the block.

The steps are as follows:
Expand Down Expand Up @@ -210,6 +208,9 @@ def forward(
raise RuntimeError(f"Duplicate output name: {key}")

outputs.update(branch_out)
elif torch.jit.isinstance(branch_out, Batch):
_flattened_batch: Dict[str, torch.Tensor] = branch_out.flatten_as_dict(batch)
outputs.update(_flattened_batch)
else:
raise TypeError(
f"Branch output must be a tensor or a dictionary of tensors. Got {_inputs}"
Expand Down Expand Up @@ -574,6 +575,98 @@ def forward(
return to_return


class BatchBlock(Block):
"""
Class to use for `Batch` creation. We can use this class to create a `Batch` from
- a tensor or a dictionary of tensors
- a `Batch` object
- a tuple of features and targets

Example usage::
>>> batch = mm.BatchBlock()(torch.ones(1, 1))
>>> batch
Batch(features={"default": tensor([[1.]])})

"""

def forward(
self,
inputs: Union[Batch, TensorOrDict],
targets: Optional[TensorOrDict] = None,
sequences: Optional[Sequence] = None,
batch: Optional[Batch] = None,
) -> Batch:
"""
Perform forward propagation on either a Batch object, or on inputs, targets and sequences
which are then packed into a Batch.

Parameters
----------
inputs : Union[Batch, TensorOrDict]
Either a Batch object or a dictionary of tensors.

targets : Optional[TensorOrDict], optional
A dictionary of tensors, by default None

sequences : Optional[Sequence], optional
A sequence of tensors, by default None

batch : Optional[Batch], optional
A Batch object, by default None

Returns
-------
Batch
The resulting Batch after forward propagation.
"""
if torch.jit.isinstance(batch, Batch):
return self.forward_batch(batch)
if torch.jit.isinstance(inputs, Batch):
return self.forward_batch(inputs)

return self.forward_batch(Batch(inputs, targets, sequences))

def forward_batch(self, batch: Batch) -> Batch:
"""
Perform forward propagation on a Batch object.

For each module in the block, this method performs a forward pass with the
current output features and the original batch object.
- If a module returns a Batch object, this becomes the new output.
- If a module returns a dictionary of tensors, a new Batch object is created
from this dictionary and the original batch object. The new Batch replaces
the current output. This is useful when a module only modifies a subset of
the batch.


Parameters
----------
batch : Batch
A Batch object.

Returns
-------
Batch
The resulting Batch after forward propagation.

Raises
------
RuntimeError
When the output of a module is neither a Batch object nor a dictionary of tensors.
"""
output = batch
for module in self.values:
module_out = module(output.features, batch=output)
if torch.jit.isinstance(module_out, Batch):
output = module_out
elif torch.jit.isinstance(module_out, Dict[str, torch.Tensor]):
output = Batch.from_partial_dict(module_out, batch)
else:
raise RuntimeError("Module must return a Batch or a dict of tensors")

return output


def _validate_n(n: int) -> None:
if not isinstance(n, int):
raise TypeError("n must be an integer")
Expand Down
22 changes: 18 additions & 4 deletions merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block
from merlin.models.torch.block import BatchBlock, Block
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.utils import module_utils
from merlin.models.utils.registry import camelcase_to_snakecase
Expand Down Expand Up @@ -73,14 +73,26 @@ class Model(LightningModule, Block):
... trainer.fit(model, Loader(dataset, batch_size=16))
"""

def __init__(self, *blocks: nn.Module, initialization="auto"):
def __init__(
self,
*blocks: nn.Module,
optimizer=torch.optim.Adam,
initialization="auto",
pre: Optional[BatchBlock] = None,
):
super().__init__()

# Copied from BlockContainer.__init__
self.values = nn.ModuleList()
for module in blocks:
self.values.append(self.wrap_module(module))
self.initialization = initialization
if isinstance(pre, BatchBlock):
self.pre = pre
elif pre is None:
self.pre = BatchBlock()
else:
raise ValueError(f"Invalid pre: {pre}, must be a BatchBlock")

@property
@torch.jit.ignore
Expand Down Expand Up @@ -128,9 +140,11 @@ def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
):
"""Performs a forward pass through the model."""
outputs = inputs
_batch: Batch = self.pre(inputs, batch=batch)

outputs = _batch.inputs()
for block in self.values:
marcromeyn marked this conversation as resolved.
Show resolved Hide resolved
outputs = block(outputs, batch=batch)
outputs = block(outputs, batch=_batch)
return outputs

def training_step(self, batch, batch_idx):
Expand Down
Loading