diff --git a/bdpy/dl/torch/models.py b/bdpy/dl/torch/models.py index e83dc40e..5ab0d61a 100644 --- a/bdpy/dl/torch/models.py +++ b/bdpy/dl/torch/models.py @@ -1,5 +1,6 @@ """Model definitions.""" +from __future__ import annotations from typing import Dict, Union, Optional, Sequence @@ -127,7 +128,7 @@ def _get_value_by_indices(array, indices): pattern = re.compile(r'^(?P[a-zA-Z_]+[a-zA-Z0-9_]*)?(?P(\[(\d+)\])+)$') m = pattern.match(layer_name) if m is not None: - layer_name = m.group('layer_name') # NOTE: layer_name can be None + layer_name: str | None = m.group('layer_name') # NOTE: layer_name can be None index_str = m.group('index') indeces = re.findall(r'\[(\d+)\]', index_str)