Skip to content

Commit

Permalink
fix shufflenet v2
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Jan 15, 2024
1 parent c24b402 commit ea6de69
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions trojanvision/models/torchvision/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class _ShuffleNetV2(_ImageModel):
def __init__(self, name: str = 'shufflenetv2_x1_0', **kwargs):
def __init__(self, name: str = 'shufflenet_v2_x1_0', **kwargs):
try:
assert name in ShuffleNetV2.available_models, f'{name=}'
except Exception:
Expand Down Expand Up @@ -50,11 +50,11 @@ class ShuffleNetV2(ImageModel):
.. code-block:: python3
{'shufflenetv2', 'shufflenetv2_comp',
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
'shufflenetv2_x1_5', 'shufflenetv2_x2_0',
'shufflenetv2_x0_5_comp', 'shufflenetv2_x1_0_comp',
'shufflenetv2_x1_5_comp', 'shufflenetv2_x2_0_comp'}
{'shufflenet_v2', 'shufflenet_v2_comp',
'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
'shufflenet_v2_x0_5_comp', 'shufflenet_v2_x1_0_comp',
'shufflenet_v2_x1_5_comp', 'shufflenet_v2_x2_0_comp'}
See Also:
* torchvision: :any:`torchvision.models.shufflenet_v2_x0_5`
Expand All @@ -70,20 +70,20 @@ class ShuffleNetV2(ImageModel):
.. _ShuffleNet V2\: Practical Guidelines for Efficient CNN Architecture Design:
https://arxiv.org/abs/1807.11164
"""
available_models = {'shufflenetv2', 'shufflenetv2_comp',
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
'shufflenetv2_x1_5', 'shufflenetv2_x2_0',
'shufflenetv2_x0_5_comp', 'shufflenetv2_x1_0_comp',
'shufflenetv2_x1_5_comp', 'shufflenetv2_x2_0_comp'}
available_models = {'shufflenet_v2', 'shufflenet_v2_comp',
'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
'shufflenet_v2_x0_5_comp', 'shufflenet_v2_x1_0_comp',
'shufflenet_v2_x1_5_comp', 'shufflenet_v2_x2_0_comp'}

weights = {
'shufflenetv2_x0_5': ShuffleNet_V2_X0_5_Weights,
'shufflenetv2_x1_0': ShuffleNet_V2_X1_0_Weights,
'shufflenetv2_x1_5': ShuffleNet_V2_X1_5_Weights,
'shufflenetv2_x2_0': ShuffleNet_V2_X2_0_Weights,
'shufflenet_v2_x0_5': ShuffleNet_V2_X0_5_Weights,
'shufflenet_v2_x1_0': ShuffleNet_V2_X1_0_Weights,
'shufflenet_v2_x1_5': ShuffleNet_V2_X1_5_Weights,
'shufflenet_v2_x2_0': ShuffleNet_V2_X2_0_Weights,
}

def __init__(self, name: str = 'shufflenetv2', layer: str = '_x0_5',
def __init__(self, name: str = 'shufflenet_v2', layer: str = '_x0_5',
model: type[_ShuffleNetV2] = _ShuffleNetV2, **kwargs):
super().__init__(name=name, layer=layer, model=model, **kwargs)

Expand Down

0 comments on commit ea6de69

Please sign in to comment.