diff --git a/merlin/models/tf/core/tabular.py b/merlin/models/tf/core/tabular.py index 33b1ed5b42..3f2bb8e574 100644 --- a/merlin/models/tf/core/tabular.py +++ b/merlin/models/tf/core/tabular.py @@ -1,5 +1,5 @@ import abc -import collections +import collections.abc import copy from typing import Dict, List, Optional, Sequence, Union, overload @@ -600,7 +600,7 @@ def get_config(self): def select_by_tag(self, tags: Tags) -> Optional["Filter"]: if isinstance(self.feature_names, Tags): schema = self.schema.select_by_tag(self.feature_names).select_by_tag(tags) - elif isinstance(self.feature_names, collections.Sequence): + elif isinstance(self.feature_names, collections.abc.Sequence): schema = self.schema.select_by_name(self.feature_names).select_by_tag(tags) else: raise RuntimeError( diff --git a/merlin/models/tf/inputs/embedding.py b/merlin/models/tf/inputs/embedding.py index aff30184a6..156ef974e8 100644 --- a/merlin/models/tf/inputs/embedding.py +++ b/merlin/models/tf/inputs/embedding.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import collections +import collections.abc import inspect from copy import deepcopy from dataclasses import dataclass @@ -268,7 +268,7 @@ def select_by_tag(self, tags: Union[Tags, Sequence[Tags]]) -> Optional["Embeddin ------- An EmbeddingTable if the tags match. If no features match, it returns None. """ - if not isinstance(tags, collections.Sequence): + if not isinstance(tags, collections.abc.Sequence): tags = [tags] selected_schema = self.schema.select_by_tag(tags)