From 509cbc6d44fa853dd1fe2a6b9e71c1f7ed1b509d Mon Sep 17 00:00:00 2001 From: sararb Date: Tue, 27 Jun 2023 16:48:01 +0000 Subject: [PATCH 1/4] first version of sequence transforms applied to Batch input --- merlin/models/torch/sequences.py | 422 ++++++++++++++++++++++++++++++ tests/unit/torch/test_sequence.py | 206 +++++++++++++++ 2 files changed, 628 insertions(+) create mode 100644 merlin/models/torch/sequences.py create mode 100644 tests/unit/torch/test_sequence.py diff --git a/merlin/models/torch/sequences.py b/merlin/models/torch/sequences.py new file mode 100644 index 0000000000..b0b5a3e952 --- /dev/null +++ b/merlin/models/torch/sequences.py @@ -0,0 +1,422 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from merlin.models.torch.batch import Batch, Sequence +from merlin.schema import ColumnSchema, Schema, Tags + +MASK_PREFIX = "__mask" + + +class TabularBatchPadding(nn.Module): + """A PyTorch module for padding tabular sequence data. + + Parameters + ---------- + schema : Schema + The schema of the tabular data, which defines the column names of input features. + max_sequence_length : Optional[int], default=None + The maximum length of the sequences after padding. + If None, sequences will be padded to the maximum length in the current batch. 
+ + + Examples: + features = { + 'feature1': torch.tensor([[4, 3], [5, 2]), + 'feature2': torch.tensor([[3,8], [7,9]]) + } + schema = Schema(["feature1", "feature2"]) + _max_sequence_length = 10 + padding_op = TabularBatchPadding( + schema=schema, max_sequence_length=_max_sequence_length + ) + padded_batch = padding_op(Batch(feaures)) + """ + + def __init__( + self, + schema: Schema, + max_sequence_length: Optional[int] = None, + ): + super().__init__() + self.schema = schema + self.max_sequence_length = max_sequence_length + self.features: List[str] = self.schema.column_names + self.sparse_features = self.schema.select_by_tag(Tags.SEQUENCE).column_names + self.padding_idx = 0 + + def forward(self, batch: Batch) -> Batch: + _max_sequence_length = self.max_sequence_length + if not _max_sequence_length: + # Infer the maximum length from the current batch + batch_max_sequence_length = 0 + for key, val in batch.features.items(): + if key.endswith("__offsets"): + offsets = val + max_row_length = int(torch.max(offsets[1:] - offsets[:-1])) + batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) + _max_sequence_length = batch_max_sequence_length + + # Store the non-padded lengths of list features + seq_inputs_lengths = self._get_sequence_lengths(batch.features) + seq_shapes = list(seq_inputs_lengths.values()) + if not all(torch.all(x == seq_shapes[0]) for x in seq_shapes): + raise ValueError( + "The sequential inputs must have the same length for each row in the batch, " + f"but they are different: {seq_shapes}" + ) + + # Pad the features of the batch + batch_padded = {} + for key, value in batch.features.items(): + if key.endswith("__offsets"): + col_name = key[: -len("__offsets")] + padded_values = self._pad_ragged_tensor( + batch.features[f"{col_name}__values"], value, _max_sequence_length + ) + batch_padded[col_name] = padded_values + elif key.endswith("__values"): + continue + else: + col_name = key + if seq_inputs_lengths.get(col_name) is not None: + # pad dense list features + batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length) + else: + # context features are not modified + batch_padded[col_name] = value + + # Pad targets of the batch + targets_padded = None + if batch.targets is not None: + targets_padded = {} + for key, value in batch.targets.items(): + if key.endswith("__offsets"): + col_name = key[: -len("__offsets")] + padded_values = self._pad_ragged_tensor( + batch.targets[f"{col_name}__values"], value, _max_sequence_length + ) + targets_padded[col_name] = padded_values + elif key.endswith("__values"): + continue + else: + targets_padded[key] = value + # TODO: do we store lengths of sequential targets features too? 
+ return Batch( + features=batch_padded, targets=targets_padded, sequences=Sequence(seq_inputs_lengths) + ) + + def _get_sequence_lengths(self, sequences: Dict[str, torch.Tensor]): + """Compute the effective length of each sequence in a dictionary of sequences.""" + seq_inputs_lengths = {} + for key, val in sequences.items(): + if key.endswith("__offsets"): + seq_inputs_lengths[key[: -len("__offsets")]] = val[1:] - val[:-1] + elif key in self.sparse_features: + seq_inputs_lengths[key] = (val != self.padding_idx).sum(-1) + return seq_inputs_lengths + + def _squeeze(self, tensor: torch.Tensor): + """Squeeze a tensor of shape (N,1) to shape (N).""" + if len(tensor.shape) == 2: + return tensor.squeeze(1) + return tensor + + def _get_indices(self, offsets: torch.Tensor, diff_offsets: torch.Tensor): + """Compute indices for a sparse tensor from offsets and their differences.""" + row_ids = torch.arange(len(offsets) - 1, device=offsets.device) + row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets) + row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets) + col_ids = ( + torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated + ) + indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], dim=1) + return indices + + def _pad_ragged_tensor(self, values: torch.Tensor, offsets: torch.Tensor, padding_length: int): + """Pad a ragged features represented by "values" and "offsets" to a dense tensor + of length `padding_length`. + """ + values = self._squeeze(values) + offsets = self._squeeze(offsets) + num_rows = len(offsets) - 1 + diff_offsets = offsets[1:] - offsets[:-1] + max_length = int(diff_offsets.max()) + indices = self._get_indices(offsets, diff_offsets) + sparse_tensor = torch.sparse_coo_tensor( + indices.T, values, torch.Size([num_rows, max_length]), device=values.device + ) + + return self._pad_dense_tensor(sparse_tensor.to_dense(), padding_length) + + def _pad_dense_tensor(self, t: torch.Tensor, length: int) -> torch.Tensor: + """Pad a dense tensor along its second dimension to a specified length.""" + if len(t.shape) == 2: + pad_diff = length - t.shape[1] + return F.pad(input=t, pad=(0, pad_diff, 0, 0)) + return t + + +class TabularSequenceTransform(nn.Module): + """Base class for preparing targets from a batch of sequential inputs. + Parameters + ---------- + schema : Schema + The schema with the sequential columns to be truncated + target : Union[str, Tags, ColumnSchema, Schema] + The sequential input column that will be used to extract the target. + For multiple targets usecase, one should provide a Schema containing + all target columns. + """ + + def __init__( + self, + schema: Schema, + target: Union[str, Tags, ColumnSchema, Schema], # TODO: multiple-targets support? 
+ ): + super().__init__() + self.schema = schema + self.target = target + self.target_name = self._get_target(target) + self.padding_idx = 0 + + def _get_target(self, target): + if ( + (isinstance(target, str) and target not in self.schema.column_names) + or (isinstance(target, Tags) and len(self.schema.select_by_tag(target)) > 0) + or (isinstance(target, ColumnSchema) and target not in self.schema) + ): + raise ValueError("The target column needs to be part of the provided sequential schema") + + target_name = target + if isinstance(target, ColumnSchema): + target_name = [target.name] + if isinstance(target, Tags): + if len(self.schema.select_by_tag(target)) > 1: + raise ValueError( + "Only 1 column should the Tag ({target}) provided for target, but" + f"the following columns have that tag: " + f"{self.schema.select_by_tag(target).column_names}" + ) + target_name = self.schema.select_by_tag(target).column_names + if isinstance(target, Schema): + target_name = target.column_names + if isinstance(target, str): + target_name = [target] + return target_name + + def forward(self, inputs: Batch, **kwargs) -> Tuple: + raise NotImplementedError() + + def _check_seq_inputs_targets(self, batch: Batch): + self._check_input_sequence_lengths(batch) + self._check_target_shape(batch) + + def _check_target_shape(self, batch): + for name in self.target_name: + if name not in batch.features: + raise ValueError(f"Inputs features do not contain target column ({name})") + + target = batch.features[name] + if target.ndim < 2: + raise ValueError( + f"Sequential target column ({name}) " + f"must be a 2D tensor, but shape is {target.ndim}" + ) + lengths = batch.sequences.length(name) + if any(lengths <= 1): + raise ValueError( + f"2nd dim of target column ({name})" + "must be greater than 1 for sequential input to be shifted as target" + ) + + def _check_input_sequence_lengths(self, batch): + if batch.sequences is None: + raise ValueError( + "The input `batch` should include information about input sequences lengths" + ) + sequence_lengths = torch.stack( + [batch.sequences.length(name) for name in self.schema.column_names] + ) + assert torch.all(sequence_lengths.eq(sequence_lengths[0])), ( + "All tabular sequence features need to have the same sequence length, " + f"found {sequence_lengths}" + ) + + +class TabularPredictNext(TabularSequenceTransform): + """Prepares sequential inputs and targets for next-item prediction. + The target is extracted from the shifted sequence of the target feature and + the sequential input features are truncated in the last position. + + Parameters + ---------- + schema : Schema + The schema with the sequential columns to be truncated + target : Union[str, List[str], Tags, ColumnSchema, Schema] + The sequential input column(s) that will be used to extract the target. + Targets can be one or multiple input features with the same sequence length. + + Examples: + transform = TabularPredictNext( + schema=schema.select_by_tag(Tags.SEQUENCE), target="a" + ) + batch_output = transform(padded_batch) + + """ + + def _generate_causal_mask(self, seq_lengths, max_len): + """ + Generate a 2D mask from a tensor of sequence lengths. 
+ """ + return torch.arange(max_len)[None, :] < seq_lengths[:, None] + + def forward(self, batch: Batch, **kwargs) -> Tuple: + self._check_seq_inputs_targets(batch) + + # Shifts the target column to be the next item of corresponding input column + targets = batch.targets + new_targets = {} + for name in self.target_name: + new_target = batch.features[name] + new_target = new_target[:, 1:] + new_targets[name] = new_target + if targets is None: + targets = new_targets + elif isinstance(targets, dict): + targets.update(new_targets) + else: + raise ValueError("Targets should be None or a dict of tensors") + + # Removes the last item of the sequence, as it belongs to the target + new_inputs = dict() + for k, v in batch.features.items(): + if k in self.schema.column_names: + new_inputs[k] = v[:, :-1] + else: + new_inputs[k] = v + + # Generates information about new lengths and causal masks + new_lengths, causal_mask = batch.sequences.lengths, batch.sequences.masks + if causal_mask is None: + causal_mask = {} + _max_length = new_target.shape[-1] # all new targets have same output sequence length + for name in self.schema.column_names: + new_lengths[name] = new_lengths[name] - 1 + causal_mask[name] = self._generate_causal_mask(new_lengths[name], _max_length) + + return Batch( + features=new_inputs, + targets=targets, + sequences=Sequence(new_lengths, masks=causal_mask), + ) + + +class TabularMaskRandom(TabularSequenceTransform): + """This transform implements the Masked Language Modeling (MLM) training approach + introduced in BERT (NLP) and later adapted to RecSys by BERT4Rec [1]. + Given an input `Batch` with input features including the sequence of candidates ids, + some positions are randomly selected (masked) to be the targets for prediction. + The targets are output being the same as the input candidates ids sequence. + The target masks are returned within the `Bathc.Sequence` object. + + References + ---------- + .. [1] Sun, Fei, et al. "BERT4Rec: Sequential recommendation with bidirectional encoder + representations from transformer." Proceedings of the 28th ACM international + conference on information and knowledge management. 2019. + + Parameters + ---------- + schema : Schema + The schema with the sequential inputs to be masked + target : Union[str, List[str], Tags, ColumnSchema, Schema] + The sequential input column(s) that will be used to compute the masked positions. + Targets can be one or multiple input features with the same sequence length. + masking_prob : float, optional + Probability of a candidate to be selected (masked) as a label of the given sequence. 
+ Note: We enforce that at least one candidate is masked for each sequence, so that it + is useful for training, by default 0.2 + """ + + def __init__( + self, + schema: Schema, + target: Union[str, Tags, ColumnSchema], + masking_prob: float = 0.2, + **kwargs, + ): + self.masking_prob = masking_prob + super().__init__(schema, target, **kwargs) + + def forward(self, batch: Batch, **kwargs) -> Tuple: + self._check_seq_inputs_targets(batch) + + new_targets = dict({name: torch.clone(batch.features[name]) for name in self.target_name}) + targets = batch.targets + if targets is None: + targets = new_targets + elif isinstance(targets, dict): + targets.update(new_targets) + else: + raise ValueError("Targets should be None or a dict of tensors") + + # Generates mask information for the group of input sequences + target_mask = self._generate_random_mask(new_targets[self.target_name[0]]) + random_mask = batch.sequences.masks + if random_mask is None: + random_mask = {} + for name in self.schema.column_names: + random_mask[name] = target_mask + + return Batch( + features=batch.features, + targets=targets, + sequences=Sequence(batch.sequences.lengths, masks=random_mask), + ) + + def _generate_random_mask(self, new_target: torch.Tensor) -> torch.Tensor: + """Generate mask information at random positions from a 2D target sequence""" + + non_padded_mask = new_target != self.padding_idx + rows_ids = torch.arange(new_target.size(0), dtype=torch.long, device=new_target.device) + + # 1. Selects a percentage of non-padded candidates to be masked (selected as targets) + probability_matrix = torch.full( + new_target.shape, self.masking_prob, device=new_target.device + ) + mask_targets = torch.bernoulli(probability_matrix).bool() & non_padded_mask + + # 2. Set at least one candidate in the sequence to mask, so that the network + # can learn something with this session + one_random_index_by_row = torch.multinomial( + non_padded_mask.float(), num_samples=1 + ).squeeze() + mask_targets[rows_ids, one_random_index_by_row] = True + + # 3. 
If a sequence has only masked targets, unmasks one of the targets + sequences_with_only_labels = mask_targets.sum(dim=1) == non_padded_mask.sum(dim=1) + sampled_targets_to_unmask = torch.multinomial(mask_targets.float(), num_samples=1).squeeze() + targets_to_unmask = torch.masked_select( + sampled_targets_to_unmask, sequences_with_only_labels + ) + rows_to_unmask = torch.masked_select(rows_ids, sequences_with_only_labels) + mask_targets[rows_to_unmask, targets_to_unmask] = False + return mask_targets diff --git a/tests/unit/torch/test_sequence.py b/tests/unit/torch/test_sequence.py new file mode 100644 index 0000000000..04b47c1c7f --- /dev/null +++ b/tests/unit/torch/test_sequence.py @@ -0,0 +1,206 @@ +import re +from itertools import accumulate + +import pytest +import torch + +from merlin.models.torch.batch import Batch, Sequence +from merlin.models.torch.sequences import ( + TabularBatchPadding, + TabularMaskRandom, + TabularPredictNext, + TabularSequenceTransform, +) +from merlin.schema import ColumnSchema, Schema, Tags + + +def _get_values_offsets(data): + values = [] + row_lengths = [] + for row in data: + row_lengths.append(len(row)) + values += row + offsets = [0] + list(accumulate(row_lengths)) + return torch.tensor(values), torch.tensor(offsets) + + +class TestPadBatch: + @pytest.fixture + def sequence_batch(self): + a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]]) + b_values, b_offsets = _get_values_offsets([[34, 30], [], [33, 23, 50]]) + features = { + "a__values": a_values, + "a__offsets": a_offsets, + "b__values": b_values, + "b__offsets": b_offsets, + "c_dense": torch.Tensor([[1, 2, 0], [0, 0, 0], [4, 5, 6]]), + "d_context": torch.Tensor([1, 2, 3]), + } + targets = None + return Batch(features, targets) + + @pytest.fixture + def sequence_schema(self): + return Schema( + [ + ColumnSchema("a", tags=[Tags.SEQUENCE]), + ColumnSchema("b", tags=[Tags.SEQUENCE]), + ColumnSchema("c_dense", tags=[Tags.SEQUENCE]), + ColumnSchema("d_context", tags=[Tags.CONTEXT]), + ] + ) + + def test_padded_features(self, sequence_batch, sequence_schema): + _max_sequence_length = 8 + padding_op = TabularBatchPadding( + schema=sequence_schema, max_sequence_length=_max_sequence_length + ) + padded_batch = padding_op(sequence_batch) + + assert torch.equal(padded_batch.sequences.length("a"), torch.Tensor([2, 0, 3])) + assert set(padded_batch.features.keys()) == set(sequence_schema.column_names) + for feature in ["a", "b", "c_dense"]: + assert padded_batch.features[feature].shape[1] == _max_sequence_length + assert torch.equal(padded_batch.features["d_context"], sequence_batch.features["d_context"]) + + def test_batch_invalid_lengths(self): + # Test when targets is not a tensor nor a dictionary of tensors + a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]]) + b_values, b_offsets = _get_values_offsets([[34], [23, 56], [33, 23, 50, 4]]) + + with pytest.raises( + ValueError, + match="The sequential inputs must have the same length for each row in the batch", + ): + padding_op = TabularBatchPadding(schema=Schema(["a", "b"])) + padding_op( + Batch( + { + "a__values": a_values, + "a__offsets": a_offsets, + "b__values": b_values, + "b__offsets": b_offsets, + } + ) + ) + + def test_padded_targets(self, sequence_batch, sequence_schema): + _max_sequence_length = 8 + target_values, target_offsets = _get_values_offsets([[10, 11], [], [12, 13, 14]]) + sequence_batch.targets = { + "target_1": torch.Tensor([3, 4, 6]), + "target_2__values": target_values, + "target_2__offsets": 
target_offsets, + } + padding_op = TabularBatchPadding( + schema=sequence_schema, max_sequence_length=_max_sequence_length + ) + padded_batch = padding_op(sequence_batch) + + assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length + assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"]) + + +class TestTabularSequenceTransform: + @pytest.fixture + def sequence_batch(self): + a_values, a_offsets = _get_values_offsets(data=[[1, 2, 3], [3, 6], [3, 4, 5, 6]]) + b_values, b_offsets = _get_values_offsets([[34, 30, 31], [30, 31], [33, 23, 50, 51]]) + features = { + "a__values": a_values, + "a__offsets": a_offsets, + "b__values": b_values, + "b__offsets": b_offsets, + "c_dense": torch.Tensor([[1, 2, 3, 0], [5, 6, 0, 0], [4, 5, 6, 7]]), + "d_context": torch.Tensor([1, 2, 3, 4]), + } + targets = None + return Batch(features, targets) + + @pytest.fixture + def sequence_schema(self): + return Schema( + [ + ColumnSchema("a", tags=[Tags.SEQUENCE]), + ColumnSchema("b", tags=[Tags.SEQUENCE]), + ColumnSchema("c_dense", tags=[Tags.SEQUENCE]), + ColumnSchema("d_context", tags=[Tags.CONTEXT]), + ] + ) + + @pytest.fixture + def padded_batch(self, sequence_schema, sequence_batch): + _max_sequence_length = 5 + padding_op = TabularBatchPadding( + schema=sequence_schema, max_sequence_length=_max_sequence_length + ) + return padding_op(sequence_batch) + + def test_tabular_sequence_transform_wrong_inputs(self, padded_batch, sequence_schema): + transform = TabularSequenceTransform( + schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a" + ) + with pytest.raises( + ValueError, + match="The input `batch` should include information about input sequences lengths", + ): + transform._check_input_sequence_lengths(Batch(padded_batch.features["b"])) + + with pytest.raises( + ValueError, + match="Inputs features do not contain target column", + ): + transform._check_target_shape(Batch(padded_batch.features["b"])) + + with pytest.raises( + ValueError, match="must be greater than 1 for sequential input to be shifted as target" + ): + transform._check_target_shape( + Batch( + {"a": torch.Tensor([[1, 2], [1, 0], [3, 4]])}, + sequences=Sequence(lengths={"a": torch.Tensor([2, 1, 2])}), + ) + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Sequential target column (d_context) must be a 2D tensor, but shape is 1" + ), + ): + transform = TabularSequenceTransform(schema=sequence_schema, target="d_context") + transform._check_target_shape(padded_batch) + + def test_transform_predict_next(self, padded_batch, sequence_schema): + transform = TabularPredictNext( + schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a" + ) + assert transform.target_name == ["a"] + + batch_output = transform(padded_batch) + + assert list(batch_output.features.keys()) == ["a", "b", "c_dense", "d_context"] + for k in ["a", "b", "c_dense"]: + assert torch.equal(batch_output.features[k], padded_batch.features[k][:, :-1]) + assert torch.equal(batch_output.features["d_context"], padded_batch.features["d_context"]) + assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3])) + + def test_transform_mask_random(self, padded_batch, sequence_schema): + transform = TabularMaskRandom( + schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a" + ) + assert transform.target_name == ["a"] + + batch_output = transform(padded_batch) + + assert list(batch_output.features.keys()) == ["a", "b", "c_dense", "d_context"] + for name in ["a", "b", "c_dense", "d_context"]: + 
assert torch.equal(batch_output.features[name], padded_batch.features[name]) + assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([3, 2, 4])) + + # check not all candidates are masked + pad_mask = padded_batch.features["a"] != 0 + assert torch.all(batch_output.sequences.mask("a").sum(1) != pad_mask.sum(1)) + # check that at least one candidate is masked + assert torch.all(batch_output.sequences.mask("a").sum(1) > 0) From eabad772d5cab132ce4f5d0aa2593cdb0d7b0181 Mon Sep 17 00:00:00 2001 From: sararb Date: Wed, 28 Jun 2023 21:15:08 +0000 Subject: [PATCH 2/4] updates output of tabular transforms and add TabularMaskLast block --- merlin/models/torch/sequences.py | 181 +++++++++++++++++------------- tests/unit/torch/test_sequence.py | 65 ++++++----- 2 files changed, 141 insertions(+), 105 deletions(-) diff --git a/merlin/models/torch/sequences.py b/merlin/models/torch/sequences.py index b0b5a3e952..5f75cb227f 100644 --- a/merlin/models/torch/sequences.py +++ b/merlin/models/torch/sequences.py @@ -20,12 +20,13 @@ from torch import nn from merlin.models.torch.batch import Batch, Sequence +from merlin.models.torch.schema import Selection, select from merlin.schema import ColumnSchema, Schema, Tags MASK_PREFIX = "__mask" -class TabularBatchPadding(nn.Module): +class TabularPadding(nn.Module): """A PyTorch module for padding tabular sequence data. Parameters @@ -48,6 +49,10 @@ class TabularBatchPadding(nn.Module): schema=schema, max_sequence_length=_max_sequence_length ) padded_batch = padding_op(Batch(feaures)) + + Note: + - If the schema includes continuous list features, please make sure they are normalized between [0,1] + As we will pad them to `max_sequence_length` using the minimum value `0.0`. """ def __init__( @@ -88,20 +93,18 @@ def forward(self, batch: Batch) -> Batch: for key, value in batch.features.items(): if key.endswith("__offsets"): col_name = key[: -len("__offsets")] - padded_values = self._pad_ragged_tensor( - batch.features[f"{col_name}__values"], value, _max_sequence_length - ) - batch_padded[col_name] = padded_values + if col_name in self.features: + padded_values = self._pad_ragged_tensor( + batch.features[f"{col_name}__values"], value, _max_sequence_length + ) + batch_padded[col_name] = padded_values elif key.endswith("__values"): continue else: col_name = key - if seq_inputs_lengths.get(col_name) is not None: + if col_name in self.features and seq_inputs_lengths.get(col_name) is not None: # pad dense list features batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length) - else: - # context features are not modified - batch_padded[col_name] = value # Pad targets of the batch targets_padded = None @@ -189,40 +192,26 @@ class TabularSequenceTransform(nn.Module): def __init__( self, schema: Schema, - target: Union[str, Tags, ColumnSchema, Schema], # TODO: multiple-targets support? 
+ target: Selection, + apply_padding: bool = True, + max_sequence_length: int = None, ): super().__init__() self.schema = schema - self.target = target - self.target_name = self._get_target(target) + self.features = schema.column_names + self.target = select(self.schema, target) + self.target_name = self._get_target(self.target) self.padding_idx = 0 + self.apply_padding = apply_padding + if self.apply_padding: + self.padding_operator = TabularPadding( + schema=self.schema, max_sequence_length=max_sequence_length + ) def _get_target(self, target): - if ( - (isinstance(target, str) and target not in self.schema.column_names) - or (isinstance(target, Tags) and len(self.schema.select_by_tag(target)) > 0) - or (isinstance(target, ColumnSchema) and target not in self.schema) - ): - raise ValueError("The target column needs to be part of the provided sequential schema") - - target_name = target - if isinstance(target, ColumnSchema): - target_name = [target.name] - if isinstance(target, Tags): - if len(self.schema.select_by_tag(target)) > 1: - raise ValueError( - "Only 1 column should the Tag ({target}) provided for target, but" - f"the following columns have that tag: " - f"{self.schema.select_by_tag(target).column_names}" - ) - target_name = self.schema.select_by_tag(target).column_names - if isinstance(target, Schema): - target_name = target.column_names - if isinstance(target, str): - target_name = [target] - return target_name - - def forward(self, inputs: Batch, **kwargs) -> Tuple: + return target.column_names + + def forward(self, batch: Batch, **kwargs) -> Tuple: raise NotImplementedError() def _check_seq_inputs_targets(self, batch: Batch): @@ -252,9 +241,7 @@ def _check_input_sequence_lengths(self, batch): raise ValueError( "The input `batch` should include information about input sequences lengths" ) - sequence_lengths = torch.stack( - [batch.sequences.length(name) for name in self.schema.column_names] - ) + sequence_lengths = torch.stack([batch.sequences.length(name) for name in self.features]) assert torch.all(sequence_lengths.eq(sequence_lengths[0])), ( "All tabular sequence features need to have the same sequence length, " f"found {sequence_lengths}" @@ -288,44 +275,34 @@ def _generate_causal_mask(self, seq_lengths, max_len): """ return torch.arange(max_len)[None, :] < seq_lengths[:, None] - def forward(self, batch: Batch, **kwargs) -> Tuple: + def forward(self, batch: Batch, **kwargs) -> Batch: + if self.apply_padding: + batch = self.padding_operator(batch) self._check_seq_inputs_targets(batch) - # Shifts the target column to be the next item of corresponding input column - targets = batch.targets new_targets = {} for name in self.target_name: new_target = batch.features[name] new_target = new_target[:, 1:] new_targets[name] = new_target - if targets is None: - targets = new_targets - elif isinstance(targets, dict): - targets.update(new_targets) - else: - raise ValueError("Targets should be None or a dict of tensors") # Removes the last item of the sequence, as it belongs to the target new_inputs = dict() for k, v in batch.features.items(): - if k in self.schema.column_names: + if k in self.features: new_inputs[k] = v[:, :-1] - else: - new_inputs[k] = v # Generates information about new lengths and causal masks - new_lengths, causal_mask = batch.sequences.lengths, batch.sequences.masks - if causal_mask is None: - causal_mask = {} + new_lengths, causal_masks = {}, {} _max_length = new_target.shape[-1] # all new targets have same output sequence length - for name in 
self.schema.column_names: - new_lengths[name] = new_lengths[name] - 1 - causal_mask[name] = self._generate_causal_mask(new_lengths[name], _max_length) + for name in self.features: + new_lengths[name] = batch.sequences.lengths[name] - 1 + causal_masks[name] = self._generate_causal_mask(new_lengths[name], _max_length) return Batch( features=new_inputs, - targets=targets, - sequences=Sequence(new_lengths, masks=causal_mask), + targets=new_targets, + sequences=Sequence(new_lengths, masks=causal_masks), ) @@ -367,32 +344,24 @@ def __init__( super().__init__(schema, target, **kwargs) def forward(self, batch: Batch, **kwargs) -> Tuple: + if self.apply_padding: + batch = self.padding_operator(batch) self._check_seq_inputs_targets(batch) - new_targets = dict({name: torch.clone(batch.features[name]) for name in self.target_name}) - targets = batch.targets - if targets is None: - targets = new_targets - elif isinstance(targets, dict): - targets.update(new_targets) - else: - raise ValueError("Targets should be None or a dict of tensors") + new_inputs = {feat: batch.features[feat] for feat in self.features} + sequence_lengths = {feat: batch.sequences.length(feat) for feat in self.features} # Generates mask information for the group of input sequences - target_mask = self._generate_random_mask(new_targets[self.target_name[0]]) - random_mask = batch.sequences.masks - if random_mask is None: - random_mask = {} - for name in self.schema.column_names: - random_mask[name] = target_mask + target_mask = self._generate_mask(new_targets[self.target_name[0]]) + random_mask = {name: target_mask for name in self.features} return Batch( - features=batch.features, - targets=targets, - sequences=Sequence(batch.sequences.lengths, masks=random_mask), + features=new_inputs, + targets=new_targets, + sequences=Sequence(sequence_lengths, masks=random_mask), ) - def _generate_random_mask(self, new_target: torch.Tensor) -> torch.Tensor: + def _generate_mask(self, new_target: torch.Tensor) -> torch.Tensor: """Generate mask information at random positions from a 2D target sequence""" non_padded_mask = new_target != self.padding_idx @@ -420,3 +389,59 @@ def _generate_random_mask(self, new_target: torch.Tensor) -> torch.Tensor: rows_to_unmask = torch.masked_select(rows_ids, sequences_with_only_labels) mask_targets[rows_to_unmask, targets_to_unmask] = False return mask_targets + + +class TabularMaskLast(TabularSequenceTransform): + """This transform copies one of the sequence input features to be + the target feature. The last item of the target sequence is selected (masked) + to be predicted. + The target masks are returned by copying the related input features. + + + Parameters + ---------- + schema : Schema + The schema with the sequential inputs to be masked + target : Union[str, List[str], Tags, ColumnSchema, Schema] + The sequential input column(s) that will be used to compute the masked positions. + Targets can be one or multiple input features with the same sequence length. 
+ """ + + def __init__( + self, + schema: Schema, + target: Union[str, Tags, ColumnSchema], + masking_prob: float = 0.2, + **kwargs, + ): + self.masking_prob = masking_prob + super().__init__(schema, target, **kwargs) + + def forward(self, batch: Batch, **kwargs) -> Tuple: + if self.apply_padding: + batch = self.padding_operator(batch) + self._check_seq_inputs_targets(batch) + new_targets = dict({name: torch.clone(batch.features[name]) for name in self.target_name}) + new_inputs = {feat: batch.features[feat] for feat in self.features} + sequence_lengths = {feat: batch.sequences.length(feat) for feat in self.features} + + # Generates mask information for the group of input sequences + target_mask = self._generate_mask(new_targets[self.target_name[0]]) + masks = {name: target_mask for name in self.features} + + return Batch( + features=new_inputs, + targets=new_targets, + sequences=Sequence(sequence_lengths, masks=masks), + ) + + def _generate_mask(self, new_target: torch.Tensor) -> torch.Tensor: + """Generate mask information at last positions from a 2D target sequence""" + target_mask = new_target != self.padding_idx + last_non_padded_indices = (target_mask.sum(dim=1) - 1).unsqueeze(-1) + + mask_targets = ( + torch.arange(target_mask.size(1), device=target_mask.device).unsqueeze(0) + == last_non_padded_indices + ) + return mask_targets diff --git a/tests/unit/torch/test_sequence.py b/tests/unit/torch/test_sequence.py index 04b47c1c7f..be5440fa12 100644 --- a/tests/unit/torch/test_sequence.py +++ b/tests/unit/torch/test_sequence.py @@ -6,8 +6,9 @@ from merlin.models.torch.batch import Batch, Sequence from merlin.models.torch.sequences import ( - TabularBatchPadding, + TabularMaskLast, TabularMaskRandom, + TabularPadding, TabularPredictNext, TabularSequenceTransform, ) @@ -53,16 +54,15 @@ def sequence_schema(self): def test_padded_features(self, sequence_batch, sequence_schema): _max_sequence_length = 8 - padding_op = TabularBatchPadding( + padding_op = TabularPadding( schema=sequence_schema, max_sequence_length=_max_sequence_length ) padded_batch = padding_op(sequence_batch) assert torch.equal(padded_batch.sequences.length("a"), torch.Tensor([2, 0, 3])) - assert set(padded_batch.features.keys()) == set(sequence_schema.column_names) + assert set(padded_batch.features.keys()) == set(["a", "b", "c_dense"]) for feature in ["a", "b", "c_dense"]: assert padded_batch.features[feature].shape[1] == _max_sequence_length - assert torch.equal(padded_batch.features["d_context"], sequence_batch.features["d_context"]) def test_batch_invalid_lengths(self): # Test when targets is not a tensor nor a dictionary of tensors @@ -73,7 +73,7 @@ def test_batch_invalid_lengths(self): ValueError, match="The sequential inputs must have the same length for each row in the batch", ): - padding_op = TabularBatchPadding(schema=Schema(["a", "b"])) + padding_op = TabularPadding(schema=Schema(["a", "b"])) padding_op( Batch( { @@ -93,7 +93,7 @@ def test_padded_targets(self, sequence_batch, sequence_schema): "target_2__values": target_values, "target_2__offsets": target_offsets, } - padding_op = TabularBatchPadding( + padding_op = TabularPadding( schema=sequence_schema, max_sequence_length=_max_sequence_length ) padded_batch = padding_op(sequence_batch) @@ -131,10 +131,7 @@ def sequence_schema(self): @pytest.fixture def padded_batch(self, sequence_schema, sequence_batch): - _max_sequence_length = 5 - padding_op = TabularBatchPadding( - schema=sequence_schema, max_sequence_length=_max_sequence_length - ) + padding_op = 
TabularPadding(schema=sequence_schema) return padding_op(sequence_batch) def test_tabular_sequence_transform_wrong_inputs(self, padded_batch, sequence_schema): @@ -163,39 +160,29 @@ def test_tabular_sequence_transform_wrong_inputs(self, padded_batch, sequence_sc ) ) - with pytest.raises( - ValueError, - match=re.escape( - "Sequential target column (d_context) must be a 2D tensor, but shape is 1" - ), - ): - transform = TabularSequenceTransform(schema=sequence_schema, target="d_context") - transform._check_target_shape(padded_batch) - - def test_transform_predict_next(self, padded_batch, sequence_schema): + def test_transform_predict_next(self, sequence_batch, padded_batch, sequence_schema): transform = TabularPredictNext( schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a" ) assert transform.target_name == ["a"] - batch_output = transform(padded_batch) + batch_output = transform(sequence_batch) - assert list(batch_output.features.keys()) == ["a", "b", "c_dense", "d_context"] + assert list(batch_output.features.keys()) == ["a", "b", "c_dense"] for k in ["a", "b", "c_dense"]: assert torch.equal(batch_output.features[k], padded_batch.features[k][:, :-1]) - assert torch.equal(batch_output.features["d_context"], padded_batch.features["d_context"]) assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3])) - def test_transform_mask_random(self, padded_batch, sequence_schema): + def test_transform_mask_random(self, sequence_batch, padded_batch, sequence_schema): transform = TabularMaskRandom( schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a" ) assert transform.target_name == ["a"] - batch_output = transform(padded_batch) + batch_output = transform(sequence_batch) - assert list(batch_output.features.keys()) == ["a", "b", "c_dense", "d_context"] - for name in ["a", "b", "c_dense", "d_context"]: + assert list(batch_output.features.keys()) == ["a", "b", "c_dense"] + for name in ["a", "b", "c_dense"]: assert torch.equal(batch_output.features[name], padded_batch.features[name]) assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([3, 2, 4])) @@ -204,3 +191,27 @@ def test_transform_mask_random(self, padded_batch, sequence_schema): assert torch.all(batch_output.sequences.mask("a").sum(1) != pad_mask.sum(1)) # check that at least one candidate is masked assert torch.all(batch_output.sequences.mask("a").sum(1) > 0) + + def test_transform_mask_last(self, sequence_batch, padded_batch, sequence_schema): + transform = TabularMaskLast(schema=sequence_schema.select_by_tag(Tags.SEQUENCE), target="a") + assert transform.target_name == ["a"] + + batch_output = transform(sequence_batch) + + assert list(batch_output.features.keys()) == ["a", "b", "c_dense"] + for name in ["a", "b", "c_dense"]: + assert torch.equal(batch_output.features[name], padded_batch.features[name]) + assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([3, 2, 4])) + + # check one candidate (last) per row is masked + assert torch.all(batch_output.sequences.mask("a").sum(1) == 1) + assert torch.all( + batch_output.sequences.mask("a") + == torch.Tensor( + [ + [False, False, True, False], + [False, True, False, False], + [False, False, False, True], + ] + ) + ) From e48029943d5e80faa8105014ef060b04fb8c29f6 Mon Sep 17 00:00:00 2001 From: sararb Date: Wed, 28 Jun 2023 21:18:27 +0000 Subject: [PATCH 3/4] fix linting --- merlin/models/torch/sequences.py | 6 ++++-- tests/unit/torch/test_sequence.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git 
a/merlin/models/torch/sequences.py b/merlin/models/torch/sequences.py index 5f75cb227f..36426e6558 100644 --- a/merlin/models/torch/sequences.py +++ b/merlin/models/torch/sequences.py @@ -51,8 +51,10 @@ class TabularPadding(nn.Module): padded_batch = padding_op(Batch(feaures)) Note: - - If the schema includes continuous list features, please make sure they are normalized between [0,1] - As we will pad them to `max_sequence_length` using the minimum value `0.0`. + If the schema contains continuous list features, + ensure that they are normalized within the range of [0, 1]. + This is necessary because we will be padding them + to a max_sequence_length using the minimum value of 0.0. """ def __init__( diff --git a/tests/unit/torch/test_sequence.py b/tests/unit/torch/test_sequence.py index be5440fa12..70df495a6a 100644 --- a/tests/unit/torch/test_sequence.py +++ b/tests/unit/torch/test_sequence.py @@ -1,4 +1,3 @@ -import re from itertools import accumulate import pytest From caa72f8d362e9c8d036ac7cd239c7fd694a3348c Mon Sep 17 00:00:00 2001 From: sararb Date: Wed, 28 Jun 2023 21:29:55 +0000 Subject: [PATCH 4/4] update docstrings --- merlin/models/torch/sequences.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/merlin/models/torch/sequences.py b/merlin/models/torch/sequences.py index 36426e6558..89c0320f09 100644 --- a/merlin/models/torch/sequences.py +++ b/merlin/models/torch/sequences.py @@ -267,7 +267,7 @@ class TabularPredictNext(TabularSequenceTransform): transform = TabularPredictNext( schema=schema.select_by_tag(Tags.SEQUENCE), target="a" ) - batch_output = transform(padded_batch) + batch_output = transform(batch) """ @@ -333,6 +333,13 @@ class TabularMaskRandom(TabularSequenceTransform): Probability of a candidate to be selected (masked) as a label of the given sequence. Note: We enforce that at least one candidate is masked for each sequence, so that it is useful for training, by default 0.2 + + Examples: + transform = TabularMaskRandom( + schema=schema.select_by_tag(Tags.SEQUENCE), target="a", masking_prob=0.4 + ) + batch_output = transform(batch) + """ def __init__( @@ -407,17 +414,14 @@ class TabularMaskLast(TabularSequenceTransform): target : Union[str, List[str], Tags, ColumnSchema, Schema] The sequential input column(s) that will be used to compute the masked positions. Targets can be one or multiple input features with the same sequence length. - """ - def __init__( - self, - schema: Schema, - target: Union[str, Tags, ColumnSchema], - masking_prob: float = 0.2, - **kwargs, - ): - self.masking_prob = masking_prob - super().__init__(schema, target, **kwargs) + Examples: + transform = TabularMaskLast( + schema=schema.select_by_tag(Tags.SEQUENCE), target="a" + ) + batch_output = transform(batch) + + """ def forward(self, batch: Batch, **kwargs) -> Tuple: if self.apply_padding:
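Taken together, the transforms introduced in this PR are intended to be used along these lines (an illustrative sketch based on the unit tests above, not an excerpt from the patch; the exact `Batch`/`Schema` behaviour depends on the rest of the merlin-models torch API):

    import torch

    from merlin.models.torch.batch import Batch
    from merlin.models.torch.sequences import TabularMaskRandom, TabularPredictNext
    from merlin.schema import ColumnSchema, Schema, Tags

    schema = Schema([ColumnSchema("a", tags=[Tags.SEQUENCE])])
    batch = Batch(
        {
            "a__values": torch.tensor([1, 2, 3, 4, 5, 6, 7]),
            "a__offsets": torch.tensor([0, 3, 7]),  # rows: [1, 2, 3] and [4, 5, 6, 7]
        }
    )

    # Causal (next-item) training: inputs are truncated, targets are the shifted sequence
    next_item = TabularPredictNext(schema=schema, target="a")(batch)
    # next_item.features["a"], next_item.targets["a"], next_item.sequences.masks["a"]

    # Masked (BERT4Rec-style) training: targets copy the inputs, masks mark sampled positions
    masked = TabularMaskRandom(schema=schema, target="a", masking_prob=0.4)(batch)
    # masked.sequences.mask("a") marks at least one position per row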