diff --git a/merlin/models/torch/transforms/sequences.py b/merlin/models/torch/transforms/sequences.py index e4dd847929..94f506de69 100644 --- a/merlin/models/torch/transforms/sequences.py +++ b/merlin/models/torch/transforms/sequences.py @@ -77,11 +77,12 @@ def __init__( ): super().__init__() if schema: - self.setup_schema(schema) + self.initialize_from_schema(schema) + self._initialized_from_schema = True self.max_sequence_length = max_sequence_length self.padding_idx = 0 - def setup_schema(self, schema: Schema): + def initialize_from_schema(self, schema: Schema): self.schema = schema self.features: List[str] = self.schema.column_names self.sparse_features = self.schema.select_by_tag(Tags.SEQUENCE).column_names