Skip to content

Commit

Permalink
fix the behavior of _parse_layer_name
Browse files Browse the repository at this point in the history
  • Loading branch information
ganow committed Dec 15, 2023
1 parent 9ed810d commit bcc1772
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
13 changes: 7 additions & 6 deletions bdpy/dl/torch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _parse_layer_name(model: nn.Module, layer_name: str) -> nn.Module:
Network model.
layer_name : str
Layer name. It accepts the following formats: 'layer_name',
'layer_name[index]', 'parent_name.child_name', and combinations of them.
'[index]', 'parent_name.child_name', and combinations of them.
Returns
-------
Expand Down Expand Up @@ -123,18 +123,19 @@ def _get_value_by_indices(array, indices):
model = _parse_layer_name(model, top_most_layer_name)
return _parse_layer_name(model, child_layer_name)

# parse layer name having index (e.g., 'features[0]', 'backbone[0][1]')
pattern = re.compile(r'^(?P<layer_name>\w+)(?P<index>(\[(\d+)\])+)$')
# parse layer name having index (e.g., '[0]', 'features[0]', 'backbone[0][1]')
pattern = re.compile(r'^(?P<layer_name>[a-zA-Z_]+[a-zA-Z0-9_]*)?(?P<index>(\[(\d+)\])+)$')
m = pattern.match(layer_name)
if m is not None:
layer_name = m.group('layer_name')
layer_name = m.group('layer_name') # NOTE: layer_name can be None
index_str = m.group('index')

indeces = re.findall(r'\[(\d+)\]', index_str)
indeces = [int(i) for i in indeces]

if hasattr(model, layer_name):
return _get_value_by_indices(getattr(model, layer_name), indeces)
if isinstance(layer_name, str) and hasattr(model, layer_name):
model = getattr(model, layer_name)
return _get_value_by_indices(model, indeces)

raise ValueError(
f"Invalid layer name: '{layer_name}'. Either the syntax of '{layer_name}' is not supported, "
Expand Down
23 changes: 23 additions & 0 deletions test/dl/torch/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from bdpy.dl.torch import models


def _removeprefix(text: str, prefix: str) -> str:
"""Remove prefix from text. (Workaround for Python 3.8)"""
if text.startswith(prefix):
return text[len(prefix):]
return text


class MockModule(nn.Module):
def __init__(self):
super(MockModule, self).__init__()
Expand Down Expand Up @@ -69,6 +76,22 @@ def test_parse_layer_name(self):
self.assertRaises(
ValueError, models._parse_layer_name, self.mock, 'layers["key"]')

def test_parse_layer_name_for_sequential(self):
"""Test _parse_layer_name for nn.Sequential.
nn.Sequential is a special case because the submodules are directly
accessible like a list. For example, `model[0]` will return the first
module in the model.
"""
sequential_module = self.mock.layers
accessors = [accessor for accessor in self.accessors if accessor['name'].startswith('layers')]
for accessor in accessors:
accsessor_key = _removeprefix(accessor['name'], 'layers')
layer = models._parse_layer_name(sequential_module, accsessor_key)
self.assertIsInstance(layer, accessor['type'])
for attr, value in accessor['attrs'].items():
self.assertEqual(getattr(layer, attr), value)


class TestVGG19(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit bcc1772

Please sign in to comment.