Merge pull request #65 from paganpasta/dev
Dev
paganpasta authored Oct 25, 2022
2 parents 2e40082 + 5413e5f commit a1401ff
Showing 17 changed files with 277 additions and 41 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -48,7 +48,7 @@ Picking a model and doing a forward pass is as simple as ...
 
 ## What's New?
 
-- `FCN` and `DeepLabV3` added as new image segmentation models.
+- `FCN`, `DeepLabV3` and `LRASPP` added as new image segmentation models.
 - Backward incompatible changes to `v0.2.0` for loading a `pretrained` model.
 - Almost all image classification models are ported from `torchvision`.
 - New tutorial for generating `adversarial examples` and others coming soon.
docs/api/models/segmentation/lraspp.md (12 changes: 12 additions & 0 deletions)

@@ -0,0 +1,12 @@
+# LRASPP
+
+
+::: eqxvision.models.LRASPP
+    selection:
+        members:
+            - __init__
+            - __call__
+
+---
+
+::: eqxvision.models.lraspp_mobilenet_v3_large
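For orientation (not part of the diff): a minimal sketch of calling the newly documented factory. It assumes `lraspp_mobilenet_v3_large` follows the same conventions as the other eqxvision segmentation factories; the `num_classes` keyword, input layout, and `key` handling are assumptions, not confirmed by this commit.

```python
import jax.random as jr

from eqxvision.models import lraspp_mobilenet_v3_large

key = jr.PRNGKey(0)
# Untrained by default; a `torch_weights` path/URL can load a PyTorch checkpoint.
net = lraspp_mobilenet_v3_large(num_classes=21, key=key)

img = jr.normal(key, (3, 224, 224))  # single CHW image (assumed layout)
out = net(img, key=key)  # expected shape: (num_classes, 224, 224)
```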
docs/index.md (2 changes: 1 addition & 1 deletion)

@@ -36,8 +36,8 @@ pip install eqxvision
 ```
 
 ## What's New?
+- `FCN`, `DeepLabV3` and `LRASPP` segmentation models are now supported (checkout the [tutorial](getting_started/FCN_Segmentation.ipynb)).
 - Backward incompatible changes to `v0.2.0` for loading a `pretrained` model.
-- `FCN` and `DeepLabV3` segmentation models are now supported (checkout the [tutorial](getting_started/FCN_Segmentation.ipynb)).
 - Almost all image classification models are ported from `torchvision`.
 - New tutorial for generating [adversarial examples](getting_started/Adversarial_Attack.ipynb) and others coming soon.
 
eqxvision/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,4 +1,4 @@
 r"""Root package info."""
-__version__ = "0.2.5"
+__version__ = "0.2.6"
 
 from . import experimental, layers, models, utils
eqxvision/experimental.py (38 changes: 23 additions & 15 deletions)

@@ -1,7 +1,7 @@
 from typing import Any, Callable
 
 import equinox as eqx
-from jaxtyping import PyTree
+import equinox.nn as nn
 
 
 class AuxData:

@@ -33,7 +33,7 @@ def __call__(self, x, *, key=None):
 
 
 def intermediate_layer_getter(
-    model: PyTree, get_target_layers: Callable
+    model: "eqx.Module", get_target_layers: Callable
 ) -> "eqx.Module":
     """Wraps intermediate layers of a model for accessing intermediate activations. Based on a discussion
     [here](https://github.com/patrick-kidger/equinox/issues/186).

@@ -49,26 +49,34 @@ def intermediate_layer_getter(
         of layers from the `model`
 
     **Returns:**
     A `PyTree`, encapsulating `model` for storing intermediate outputs from target layers.
-    The returned model will now return a `tuple` with
+    !!! info
+        The returned model will now return a `tuple` with
 
-    0. The final output of `model`
-    1. An ordered list of intermediate activations
+        1. The final output of `model`
+        2. An ordered list of intermediate activations
     """
     target_layers = get_target_layers(model)
     auxs, wrappers = zip(
         *[_make_intermediate_layer_wrapper() for _ in range(len(target_layers))]
     )
-    model = eqx.tree_at(
-        where=get_target_layers,
-        pytree=model,
-        replace=[
-            wrapper(target_layer)
-            for (wrapper, target_layer) in zip(wrappers, target_layers)
-        ],
-    )
+    if isinstance(model, nn.Sequential):
+        new_modules, updated_count = [], 0
+        for idx, module in enumerate(model.layers):
+            if idx in target_layers:
+                new_modules.append(wrappers[updated_count](module))
+                updated_count += 1
+            else:
+                new_modules.append(module)
+        model = nn.Sequential(new_modules)
+    else:
+        model = eqx.tree_at(
+            where=get_target_layers,
+            pytree=model,
+            replace=[
+                wrapper(target_layer)
+                for (wrapper, target_layer) in zip(wrappers, target_layers)
+            ],
+        )
 
 class IntermediateLayerGetter(eqx.Module):
     model: eqx.Module
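The new branch above special-cases `nn.Sequential`, where `get_target_layers` is expected to return integer indices into `model.layers` rather than a `tree_at`-style selector. A minimal sketch of exercising it (the toy network and index choice are hypothetical; the tuple return follows the updated docstring):

```python
import equinox.nn as nn
import jax
import jax.random as jr

from eqxvision.experimental import intermediate_layer_getter

k1, k2, k3 = jr.split(jr.PRNGKey(0), 3)
net = nn.Sequential(
    [
        nn.Linear(8, 16, key=k1),
        nn.Lambda(jax.nn.relu),
        nn.Linear(16, 4, key=k2),
    ]
)

# For an nn.Sequential, indices 0 and 2 select the two Linear layers;
# each selected layer is wrapped in-place instead of going through eqx.tree_at.
wrapped = intermediate_layer_getter(net, lambda m: [0, 2])

x = jr.normal(k3, (8,))
# Per the updated docstring: (final output, ordered list of intermediate activations).
output, intermediates = wrapped(x)
```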
eqxvision/models/__init__.py (1 change: 1 addition & 0 deletions)

@@ -102,3 +102,4 @@
 )
 from .segmentation.deeplabv3 import DeepLabV3, deeplabv3
 from .segmentation.fcn import FCN, fcn
+from .segmentation.lraspp import LRASPP, lraspp_mobilenet_v3_large
eqxvision/models/classification/mobilenetv3.py (5 changes: 4 additions & 1 deletion)

@@ -361,7 +361,10 @@ def mobilenet_v3_large(torch_weights: str = None, **kwargs: Any) -> MobileNetV3:
     """
     arch = "mobilenet_v3_large"
-    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
+    dilated = kwargs.pop("dilated", False)
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(
+        arch, dilated=dilated, **kwargs
+    )
     model = _mobilenet_v3(arch, inverted_residual_setting, last_channel, **kwargs)
     if torch_weights:
         model = load_torch_weights(model, torch_weights=torch_weights)
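The `dilated` flag is popped out of `kwargs` so it reaches `_mobilenet_v3_conf` without also being forwarded to the model constructor below. A sketch of the intended use, assuming the factory is re-exported from `eqxvision.models`:

```python
from eqxvision.models import mobilenet_v3_large

# Replace the late strides with dilation, as dilated backbones for
# segmentation heads (e.g. the new LRASPP) require. Before this fix,
# `dilated` presumably leaked into _mobilenet_v3 via **kwargs.
backbone = mobilenet_v3_large(dilated=True)
```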
eqxvision/models/segmentation/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-from . import deeplabv3, fcn
+from . import deeplabv3, fcn, lraspp
eqxvision/models/segmentation/deeplabv3.py (18 changes: 10 additions & 8 deletions)

@@ -9,7 +9,7 @@
 from jaxtyping import Array
 
 from ...experimental import intermediate_layer_getter
-from ...utils import CLASSIFICATION_URLS, load_torch_weights
+from ...utils import load_torch_weights
 from ..classification import resnet
 from ._utils import _SimpleSegmentationModel
 from .fcn import FCNHead

@@ -139,17 +139,17 @@ def deeplabv3(
     num_classes: Optional[int] = 21,
     backbone: "eqx.Module" = None,
     intermediate_layers: Callable = None,
-    classifier_module: "eqx.Module" = DeepLabHead,
+    classifier_module: "eqx.Module" = None,
     classifier_in_channels: int = 2048,
-    aux_classifier_module: "eqx.Module" = FCNHead,
+    aux_classifier_module: "eqx.Module" = None,
     aux_in_channels: int = 1024,
     silence_layers: Callable = None,
     torch_weights: str = None,
     *,
     key: Optional["jax.random.PRNGKey"] = None,
 ) -> DeepLabV3:
     """Implements DeepLabV3 model from
-    ["Rethinking Atrous Convolution for Semantic Image Segmentation"](https://arxiv.org/abs/1706.05587) paper.
+    [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) paper.
 
     !!! info "Sample call"
         ```python

@@ -167,7 +167,7 @@
     - `num_classes`: Number of classes in the segmentation task.
         Also controls the final output shape `(num_classes, height, width)`. Defaults to `21`
     - `backbone`: The neural network to use for extracting features. If `None`, then all params are set to
-        `DeepLabV3_RESNET50` with a **pre-trained** backbone but **untrained** DeepLabV3 heads
+        `DeepLabV3_RESNET50` with `untrained` weights
     - `intermediate_layers`: Layers from `backbone` to be used for generating output maps. Default sets it to
         `layer3` and `layer4` from `DeepLabV3_RESNET50`
     - `classifier_module`: Uses the `DeepLabHead` by default

@@ -179,15 +179,17 @@
         the `fc` layers can be dropped. This is particularly useful when loading weights from `torchvision`. By
         default, `.fc` layer of a model is set to identity to avoid tracking weights.
     - `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
+    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
     """
     if key is None:
         key = jr.PRNGKey(0)
     keys = jr.split(key, 2)
 
+    if not classifier_module:
+        classifier_module = DeepLabHead
+    if not aux_classifier_module:
+        aux_classifier_module = FCNHead
     if backbone is None:
         backbone = resnet.resnet50(
-            torch_weights=CLASSIFICATION_URLS["resnet50"],
             replace_stride_with_dilation=[False, True, True],
         )
     num_layers = len(intermediate_layers(backbone))
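A sketch of what the new defaults imply for a bare call: the backbone is now an untrained `resnet50` (no `CLASSIFICATION_URLS` download), and `DeepLabHead`/`FCNHead` are filled in when the classifier arguments are left as `None`. The input shape and forward-call signature here are assumptions based on the docstring:

```python
import jax.random as jr

from eqxvision.models import deeplabv3

# Everything untrained by default after this commit; pass `torch_weights`
# to start from a PyTorch checkpoint instead.
net = deeplabv3(num_classes=21, key=jr.PRNGKey(0))

img = jr.normal(jr.PRNGKey(1), (3, 480, 480))  # assumed CHW input
out = net(img, key=jr.PRNGKey(2))  # (num_classes, 480, 480) per the docstring
```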
eqxvision/models/segmentation/fcn.py (12 changes: 6 additions & 6 deletions)

@@ -7,7 +7,7 @@
 import jax.random as jr
 
 from ...experimental import intermediate_layer_getter
-from ...utils import CLASSIFICATION_URLS, load_torch_weights
+from ...utils import load_torch_weights
 from ..classification import resnet
 from ._utils import _SimpleSegmentationModel

@@ -38,7 +38,7 @@ def fcn(
     num_classes: Optional[int] = 21,
     backbone: "eqx.Module" = None,
     intermediate_layers: Callable = None,
-    classifier_module: "eqx.Module" = FCNHead,
+    classifier_module: "eqx.Module" = None,
     classifier_in_channels: int = 2048,
     aux_in_channels: int = None,
     silence_layers: Callable = None,

@@ -64,7 +64,7 @@
     - `num_classes`: Number of classes in the segmentation task.
         Also controls the final output shape `(num_classes, height, width)`. Defaults to `21`
     - `backbone`: The neural network to use for extracting features. If `None`, then all params are set to
-        `FCN_RESNET50` with a **pre-trained** backbone but an **untrained** FCN
+        `FCN_RESNET50` with `untrained` weights
     - `intermediate_layers`: Layers from `backbone` to be used for generating output maps. Default sets it to
         `layer3` and `layer4` from `FCN_RESNET50`
     - `classifier_module`: Uses the `FCNHead` by default

@@ -75,15 +75,15 @@
         the `fc` layers can be dropped. This is particularly useful when loading weights from `torchvision`. By
         default, `.fc` layer of a model is set to identity to avoid tracking weights.
     - `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
+    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
     """
     if key is None:
         key = jr.PRNGKey(0)
     keys = jr.split(key, 2)
 
+    if classifier_module is None:
+        classifier_module = FCNHead
     if backbone is None:
         backbone = resnet.resnet50(
-            torch_weights=CLASSIFICATION_URLS["resnet50"],
             replace_stride_with_dilation=[False, True, True],
         )
     num_layers = len(intermediate_layers(backbone))
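And the analogous sketch for `fcn`, which gets the same treatment: untrained end to end by default, with `FCNHead` substituted when `classifier_module` is left as `None` (the checkpoint path below is hypothetical):

```python
import jax.random as jr

from eqxvision.models import fcn

net = fcn(num_classes=21, key=jr.PRNGKey(0))

# To start from torchvision weights instead, point `torch_weights` at a
# downloaded checkpoint (hypothetical local path):
# net = fcn(num_classes=21, torch_weights="./fcn_resnet50_coco.pth")
```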