Skip to content

Commit

Permalink
Merge pull request #82 from KamitaniLab/fix-parse-layer-name
Browse files Browse the repository at this point in the history
fix the behavior of _parse_layer_name
  • Loading branch information
ShuntaroAoki authored Dec 18, 2023
2 parents 9ed810d + 6f4c52f commit bc75ff6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
14 changes: 8 additions & 6 deletions bdpy/dl/torch/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Model definitions."""

from __future__ import annotations

from typing import Dict, Union, Optional, Sequence

Expand Down Expand Up @@ -93,7 +94,7 @@ def _parse_layer_name(model: nn.Module, layer_name: str) -> nn.Module:
Network model.
layer_name : str
Layer name. It accepts the following formats: 'layer_name',
'layer_name[index]', 'parent_name.child_name', and combinations of them.
'[index]', 'parent_name.child_name', and combinations of them.
Returns
-------
Expand Down Expand Up @@ -123,18 +124,19 @@ def _get_value_by_indices(array, indices):
model = _parse_layer_name(model, top_most_layer_name)
return _parse_layer_name(model, child_layer_name)

# parse layer name having index (e.g., 'features[0]', 'backbone[0][1]')
pattern = re.compile(r'^(?P<layer_name>\w+)(?P<index>(\[(\d+)\])+)$')
# parse layer name having index (e.g., '[0]', 'features[0]', 'backbone[0][1]')
pattern = re.compile(r'^(?P<layer_name>[a-zA-Z_]+[a-zA-Z0-9_]*)?(?P<index>(\[(\d+)\])+)$')
m = pattern.match(layer_name)
if m is not None:
layer_name = m.group('layer_name')
layer_name: str | None = m.group('layer_name') # NOTE: layer_name can be None
index_str = m.group('index')

indeces = re.findall(r'\[(\d+)\]', index_str)
indeces = [int(i) for i in indeces]

if hasattr(model, layer_name):
return _get_value_by_indices(getattr(model, layer_name), indeces)
if isinstance(layer_name, str) and hasattr(model, layer_name):
model = getattr(model, layer_name)
return _get_value_by_indices(model, indeces)

raise ValueError(
f"Invalid layer name: '{layer_name}'. Either the syntax of '{layer_name}' is not supported, "
Expand Down
23 changes: 23 additions & 0 deletions test/dl/torch/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from bdpy.dl.torch import models


def _removeprefix(text: str, prefix: str) -> str:
"""Remove prefix from text. (Workaround for Python 3.8)"""
if text.startswith(prefix):
return text[len(prefix):]
return text


class MockModule(nn.Module):
def __init__(self):
super(MockModule, self).__init__()
Expand Down Expand Up @@ -69,6 +76,22 @@ def test_parse_layer_name(self):
self.assertRaises(
ValueError, models._parse_layer_name, self.mock, 'layers["key"]')

def test_parse_layer_name_for_sequential(self):
"""Test _parse_layer_name for nn.Sequential.
nn.Sequential is a special case because the submodules are directly
accessible like a list. For example, `model[0]` will return the first
module in the model.
"""
sequential_module = self.mock.layers
accessors = [accessor for accessor in self.accessors if accessor['name'].startswith('layers')]
for accessor in accessors:
accsessor_key = _removeprefix(accessor['name'], 'layers')
layer = models._parse_layer_name(sequential_module, accsessor_key)
self.assertIsInstance(layer, accessor['type'])
for attr, value in accessor['attrs'].items():
self.assertEqual(getattr(layer, attr), value)


class TestVGG19(unittest.TestCase):
def setUp(self):
Expand Down

2 comments on commit bc75ff6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
bdpy
   __init__.py330%9–11
bdpy/bdata
   __init__.py220%7–8
   bdata.py3983980%26–919
   featureselector.py64640%8–124
   metadata.py67670%8–154
   utils.py1131130%4–263
bdpy/dataform
   __init__.py440%7–10
   datastore.py1071070%7–265
   features.py2982980%8–549
   pd.py990%7–44
   sparse.py67670%6–126
   utils.py12120%3–18
bdpy/dataset
   utils.py45450%3–98
bdpy/distcomp
   __init__.py110%6
   distcomp.py92920%7–127
bdpy/dl
   caffe.py60600%4–129
bdpy/dl/torch
   __init__.py220%1–2
   base.py43430%6–105
   models.py3343340%3–876
   torch.py1091090%3–258
bdpy/evals
   metrics.py95950%3–179
bdpy/feature
   __init__.py110%3
   feature.py30300%1–74
bdpy/fig
   __init__.py440%6–9
   draw_group_image_set.py90900%3–182
   fig.py88880%16–164
   makeplots.py3363360%1–729
   tile_images.py59590%1–193
bdpy/ml
   __init__.py770%8–14
   crossvalidation.py59590%7–196
   ensemble.py13130%5–46
   learning.py3083080%4–613
   model.py1401400%4–285
   regress.py11110%6–38
   searchlight.py16160%4–51
bdpy/mri
   __init__.py770%7–13
   fmriprep.py4974970%4–866
   glm.py40400%4–95
   image.py24240%4–54
   load_epi.py28280%7–88
   load_mri.py19190%4–36
   roi.py2482480%4–499
   spm.py1581580%1–300
bdpy/opendata
   __init__.py110%1
   openneuro.py2102100%1–329
bdpy/preproc
   __init__.py330%8–10
   interface.py52520%8–217
   preprocessor.py1291290%8–236
   select_top.py22220%8–61
   util.py660%6–22
bdpy/recon
   utils.py55550%4–146
bdpy/recon/torch
   __init__.py110%1
   icnn.py1611610%15–478
bdpy/stats
   __init__.py110%13
   corr.py43430%6–112
bdpy/util
   __init__.py330%7–9
   info.py47470%4–79
   math.py13130%4–38
   utils.py36360%7–145
TOTAL489148910% 

Tests Skipped Failures Errors Time
115 0 💤 13 ❌ 6 🔥 9.267s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
bdpy
   __init__.py330%9–11
bdpy/bdata
   __init__.py220%7–8
   bdata.py3983980%26–919
   featureselector.py64640%8–124
   metadata.py67670%8–154
   utils.py1131130%4–263
bdpy/dataform
   __init__.py440%7–10
   datastore.py1071070%7–265
   features.py2982980%8–549
   pd.py990%7–44
   sparse.py67670%6–126
   utils.py12120%3–18
bdpy/dataset
   utils.py45450%3–98
bdpy/distcomp
   __init__.py110%6
   distcomp.py92920%7–127
bdpy/dl
   caffe.py60600%4–129
bdpy/dl/torch
   __init__.py220%1–2
   base.py43430%6–105
   models.py3343340%3–876
   torch.py1091090%3–258
bdpy/evals
   metrics.py95950%3–179
bdpy/feature
   __init__.py110%3
   feature.py30300%1–74
bdpy/fig
   __init__.py440%6–9
   draw_group_image_set.py90900%3–182
   fig.py88880%16–164
   makeplots.py3363360%1–729
   tile_images.py59590%1–193
bdpy/ml
   __init__.py770%8–14
   crossvalidation.py59590%7–196
   ensemble.py13130%5–46
   learning.py3083080%4–613
   model.py1401400%4–285
   regress.py11110%6–38
   searchlight.py16160%4–51
bdpy/mri
   __init__.py770%7–13
   fmriprep.py4974970%4–866
   glm.py40400%4–95
   image.py24240%4–54
   load_epi.py28280%7–88
   load_mri.py19190%4–36
   roi.py2482480%4–499
   spm.py1581580%1–300
bdpy/opendata
   __init__.py110%1
   openneuro.py2102100%1–329
bdpy/preproc
   __init__.py330%8–10
   interface.py52520%8–217
   preprocessor.py1291290%8–236
   select_top.py22220%8–61
   util.py660%6–22
bdpy/recon
   utils.py55550%4–146
bdpy/recon/torch
   __init__.py110%1
   icnn.py1611610%15–478
bdpy/stats
   __init__.py110%13
   corr.py43430%6–112
bdpy/util
   __init__.py330%7–9
   info.py47470%4–79
   math.py13130%4–38
   utils.py36360%7–145
TOTAL489148910% 

Tests Skipped Failures Errors Time
115 0 💤 13 ❌ 6 🔥 9.073s ⏱️

Please sign in to comment.