
updates output of tabular transforms and adds TabularMaskLast block
sararb committed Jun 28, 2023
1 parent 509cbc6 commit eabad77
Showing 2 changed files with 141 additions and 105 deletions.
merlin/models/torch/sequences.py: 181 changes (103 additions, 78 deletions)
@@ -20,12 +20,13 @@
 from torch import nn

 from merlin.models.torch.batch import Batch, Sequence
+from merlin.models.torch.schema import Selection, select
 from merlin.schema import ColumnSchema, Schema, Tags

 MASK_PREFIX = "__mask"


-class TabularBatchPadding(nn.Module):
+class TabularPadding(nn.Module):
     """A PyTorch module for padding tabular sequence data.

     Parameters
@@ -48,6 +49,10 @@ class TabularBatchPadding(nn.Module):
             schema=schema, max_sequence_length=_max_sequence_length
         )
         padded_batch = padding_op(Batch(features))
+
+    Note:
+        - If the schema includes continuous list features, make sure they are normalized to [0, 1],
+          as they will be padded to `max_sequence_length` using the minimum value `0.0`.
     """

     def __init__(
@@ -88,20 +93,18 @@ def forward(self, batch: Batch) -> Batch:
         for key, value in batch.features.items():
             if key.endswith("__offsets"):
                 col_name = key[: -len("__offsets")]
-                padded_values = self._pad_ragged_tensor(
-                    batch.features[f"{col_name}__values"], value, _max_sequence_length
-                )
-                batch_padded[col_name] = padded_values
+                if col_name in self.features:
+                    padded_values = self._pad_ragged_tensor(
+                        batch.features[f"{col_name}__values"], value, _max_sequence_length
+                    )
+                    batch_padded[col_name] = padded_values
             elif key.endswith("__values"):
                 continue
             else:
                 col_name = key
-                if seq_inputs_lengths.get(col_name) is not None:
+                if col_name in self.features and seq_inputs_lengths.get(col_name) is not None:
                     # pad dense list features
                     batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length)
-                else:
-                    # context features are not modified
-                    batch_padded[col_name] = value

         # Pad targets of the batch
         targets_padded = None
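
For context when reading this hunk, here is a minimal, self-contained sketch of the ragged-to-padded conversion that `_pad_ragged_tensor` performs on `__values`/`__offsets` features. It is plain PyTorch with illustrative values, not the committed implementation:

import torch

# A ragged feature stored as flat values plus offsets, holding the
# three sequences [1, 2], [3], and [4, 5, 6].
values = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 3, 6])
max_sequence_length = 4

lengths = offsets[1:] - offsets[:-1]  # tensor([2, 1, 3])
padded = torch.zeros(len(lengths), max_sequence_length, dtype=values.dtype)
for row, (start, length) in enumerate(zip(offsets[:-1], lengths)):
    length = min(int(length), max_sequence_length)
    padded[row, :length] = values[start : start + length]

print(padded)
# tensor([[1, 2, 0, 0],
#         [3, 0, 0, 0],
#         [4, 5, 6, 0]])

With the change above, only columns that are part of the transform's schema (`self.features`) are padded and returned; other keys are dropped from the output.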
@@ -189,40 +192,26 @@ class TabularSequenceTransform(nn.Module):
     def __init__(
         self,
         schema: Schema,
-        target: Union[str, Tags, ColumnSchema, Schema],  # TODO: multiple-targets support?
+        target: Selection,
         apply_padding: bool = True,
         max_sequence_length: int = None,
     ):
         super().__init__()
         self.schema = schema
-        self.target = target
-        self.target_name = self._get_target(target)
+        self.features = schema.column_names
+        self.target = select(self.schema, target)
+        self.target_name = self._get_target(self.target)
         self.padding_idx = 0
         self.apply_padding = apply_padding
         if self.apply_padding:
             self.padding_operator = TabularPadding(
                 schema=self.schema, max_sequence_length=max_sequence_length
             )

     def _get_target(self, target):
-        if (
-            (isinstance(target, str) and target not in self.schema.column_names)
-            or (isinstance(target, Tags) and len(self.schema.select_by_tag(target)) > 0)
-            or (isinstance(target, ColumnSchema) and target not in self.schema)
-        ):
-            raise ValueError("The target column needs to be part of the provided sequential schema")
-
-        target_name = target
-        if isinstance(target, ColumnSchema):
-            target_name = [target.name]
-        if isinstance(target, Tags):
-            if len(self.schema.select_by_tag(target)) > 1:
-                raise ValueError(
-                    "Only 1 column should the Tag ({target}) provided for target, but"
-                    f"the following columns have that tag: "
-                    f"{self.schema.select_by_tag(target).column_names}"
-                )
-            target_name = self.schema.select_by_tag(target).column_names
-        if isinstance(target, Schema):
-            target_name = target.column_names
-        if isinstance(target, str):
-            target_name = [target]
-        return target_name
-
-    def forward(self, inputs: Batch, **kwargs) -> Tuple:
+        return target.column_names
+
+    def forward(self, batch: Batch, **kwargs) -> Tuple:
         raise NotImplementedError()

     def _check_seq_inputs_targets(self, batch: Batch):
@@ -252,9 +241,7 @@ def _check_input_sequence_lengths(self, batch):
             raise ValueError(
                 "The input `batch` should include information about input sequences lengths"
             )
-        sequence_lengths = torch.stack(
-            [batch.sequences.length(name) for name in self.schema.column_names]
-        )
+        sequence_lengths = torch.stack([batch.sequences.length(name) for name in self.features])
         assert torch.all(sequence_lengths.eq(sequence_lengths[0])), (
             "All tabular sequence features need to have the same sequence length, "
             f"found {sequence_lengths}"
@@ -288,44 +275,34 @@ def _generate_causal_mask(self, seq_lengths, max_len):
         """
         return torch.arange(max_len)[None, :] < seq_lengths[:, None]

-    def forward(self, batch: Batch, **kwargs) -> Tuple:
+    def forward(self, batch: Batch, **kwargs) -> Batch:
         if self.apply_padding:
             batch = self.padding_operator(batch)
         self._check_seq_inputs_targets(batch)

         # Shifts the target column to be the next item of corresponding input column
-        targets = batch.targets
         new_targets = {}
         for name in self.target_name:
             new_target = batch.features[name]
             new_target = new_target[:, 1:]
             new_targets[name] = new_target
-        if targets is None:
-            targets = new_targets
-        elif isinstance(targets, dict):
-            targets.update(new_targets)
-        else:
-            raise ValueError("Targets should be None or a dict of tensors")

         # Removes the last item of the sequence, as it belongs to the target
         new_inputs = dict()
         for k, v in batch.features.items():
-            if k in self.schema.column_names:
+            if k in self.features:
                 new_inputs[k] = v[:, :-1]
-            else:
-                new_inputs[k] = v

         # Generates information about new lengths and causal masks
-        new_lengths, causal_mask = batch.sequences.lengths, batch.sequences.masks
-        if causal_mask is None:
-            causal_mask = {}
+        new_lengths, causal_masks = {}, {}
         _max_length = new_target.shape[-1]  # all new targets have same output sequence length
-        for name in self.schema.column_names:
-            new_lengths[name] = new_lengths[name] - 1
-            causal_mask[name] = self._generate_causal_mask(new_lengths[name], _max_length)
+        for name in self.features:
+            new_lengths[name] = batch.sequences.lengths[name] - 1
+            causal_masks[name] = self._generate_causal_mask(new_lengths[name], _max_length)

         return Batch(
             features=new_inputs,
-            targets=targets,
-            sequences=Sequence(new_lengths, masks=causal_mask),
+            targets=new_targets,
+            sequences=Sequence(new_lengths, masks=causal_masks),
         )
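
The predict-next transform above drops the last position from each input, shifts the target column by one step, and builds a causal mask from the shortened lengths. A small hand-worked sketch of that shift in plain PyTorch (values are made up for illustration and are not part of the commit):

import torch

# One padded row of item ids (0 is the padding index); the true length is 4.
item_ids = torch.tensor([[5, 3, 7, 9, 0]])
length = torch.tensor([4])

inputs = item_ids[:, :-1]   # tensor([[5, 3, 7, 9]]), last position dropped
targets = item_ids[:, 1:]   # tensor([[3, 7, 9, 0]]), next-item targets
new_length = length - 1     # tensor([3])

# Causal mask over the shifted sequence: True marks valid target positions,
# mirroring the comparison in _generate_causal_mask above.
max_len = targets.shape[-1]
causal_mask = torch.arange(max_len)[None, :] < new_length[:, None]
print(causal_mask)          # tensor([[ True,  True,  True, False]])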


@@ -367,32 +344,24 @@ def __init__(
         super().__init__(schema, target, **kwargs)

     def forward(self, batch: Batch, **kwargs) -> Tuple:
         if self.apply_padding:
             batch = self.padding_operator(batch)
         self._check_seq_inputs_targets(batch)

         new_targets = dict({name: torch.clone(batch.features[name]) for name in self.target_name})
-        targets = batch.targets
-        if targets is None:
-            targets = new_targets
-        elif isinstance(targets, dict):
-            targets.update(new_targets)
-        else:
-            raise ValueError("Targets should be None or a dict of tensors")
+        new_inputs = {feat: batch.features[feat] for feat in self.features}
+        sequence_lengths = {feat: batch.sequences.length(feat) for feat in self.features}

         # Generates mask information for the group of input sequences
-        target_mask = self._generate_random_mask(new_targets[self.target_name[0]])
-        random_mask = batch.sequences.masks
-        if random_mask is None:
-            random_mask = {}
-        for name in self.schema.column_names:
-            random_mask[name] = target_mask
+        target_mask = self._generate_mask(new_targets[self.target_name[0]])
+        random_mask = {name: target_mask for name in self.features}

         return Batch(
-            features=batch.features,
-            targets=targets,
-            sequences=Sequence(batch.sequences.lengths, masks=random_mask),
+            features=new_inputs,
+            targets=new_targets,
+            sequences=Sequence(sequence_lengths, masks=random_mask),
         )

-    def _generate_random_mask(self, new_target: torch.Tensor) -> torch.Tensor:
+    def _generate_mask(self, new_target: torch.Tensor) -> torch.Tensor:
         """Generate mask information at random positions from a 2D target sequence"""

         non_padded_mask = new_target != self.padding_idx
@@ -420,3 +389,59 @@ def _generate_random_mask(self, new_target: torch.Tensor) -> torch.Tensor:
         rows_to_unmask = torch.masked_select(rows_ids, sequences_with_only_labels)
         mask_targets[rows_to_unmask, targets_to_unmask] = False
         return mask_targets
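
The full body of the random-mask generation is collapsed in this view. Purely as an illustration of the idea, the sketch below draws Bernoulli mask positions over non-padded items and falls back to masking the last valid item when a row would otherwise have no target. It is an assumption about the general behaviour, not the committed code:

import torch

torch.manual_seed(0)
padding_idx, masking_prob = 0, 0.2
targets = torch.tensor([[5, 3, 7, 0],
                        [2, 9, 0, 0]])

non_padded = targets != padding_idx
# Bernoulli draw over every position, restricted to non-padded ones.
candidate = torch.rand(targets.shape) < masking_prob
mask = candidate & non_padded

# Fallback so every row keeps at least one masked (target) position:
# mask the last non-padded item of rows that ended up with no mask.
empty_rows = ~mask.any(dim=1)
last_valid = non_padded.sum(dim=1) - 1
mask[empty_rows, last_valid[empty_rows]] = True
print(mask)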


+class TabularMaskLast(TabularSequenceTransform):
+    """This transform copies one of the sequence input features to be
+    the target feature. The last item of each target sequence is selected (masked)
+    to be predicted.
+    The same target mask is shared across all the related input features.
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema with the sequential inputs to be masked
+    target : Union[str, List[str], Tags, ColumnSchema, Schema]
+        The sequential input column(s) that will be used to compute the masked positions.
+        Targets can be one or multiple input features with the same sequence length.
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        target: Union[str, Tags, ColumnSchema],
+        masking_prob: float = 0.2,
+        **kwargs,
+    ):
+        self.masking_prob = masking_prob
+        super().__init__(schema, target, **kwargs)
+
+    def forward(self, batch: Batch, **kwargs) -> Tuple:
+        if self.apply_padding:
+            batch = self.padding_operator(batch)
+        self._check_seq_inputs_targets(batch)
+        new_targets = dict({name: torch.clone(batch.features[name]) for name in self.target_name})
+        new_inputs = {feat: batch.features[feat] for feat in self.features}
+        sequence_lengths = {feat: batch.sequences.length(feat) for feat in self.features}
+
+        # Generates mask information for the group of input sequences
+        target_mask = self._generate_mask(new_targets[self.target_name[0]])
+        masks = {name: target_mask for name in self.features}
+
+        return Batch(
+            features=new_inputs,
+            targets=new_targets,
+            sequences=Sequence(sequence_lengths, masks=masks),
+        )
+
+    def _generate_mask(self, new_target: torch.Tensor) -> torch.Tensor:
+        """Generate mask information at the last position of a 2D target sequence"""
+        target_mask = new_target != self.padding_idx
+        last_non_padded_indices = (target_mask.sum(dim=1) - 1).unsqueeze(-1)
+
+        mask_targets = (
+            torch.arange(target_mask.size(1), device=target_mask.device).unsqueeze(0)
+            == last_non_padded_indices
+        )
+        return mask_targets
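
To make the new block concrete, the last-position computation from `TabularMaskLast._generate_mask` can be run on a small hand-made tensor (the input values below are illustrative only):

import torch

padding_idx = 0
targets = torch.tensor([[5, 3, 7, 0, 0],
                        [2, 9, 4, 6, 0]])

# Same logic as TabularMaskLast._generate_mask: find the last non-padded
# position of each row and mark only that position as a target.
target_mask = targets != padding_idx
last_non_padded_indices = (target_mask.sum(dim=1) - 1).unsqueeze(-1)  # tensor([[2], [3]])
mask_targets = torch.arange(target_mask.size(1)).unsqueeze(0) == last_non_padded_indices
print(mask_targets)
# tensor([[False, False,  True, False, False],
#         [False, False, False,  True, False]])

Each row ends up with exactly one True at its last non-padded position, which is the only position whose target is predicted. In a model pipeline the block would presumably be built with the item-id column (or tag) as `target` and applied to a padded `Batch`, mirroring how the other masking transforms in this file are used.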
