Commit 1dc03c1
achaiah authored and committed on Jan 15, 2020
1 parent 1a15aff
Showing 8 changed files with 316 additions and 62 deletions.
@@ -29,8 +29,18 @@ have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.

## What's New (highlights)
- **Jan. 15, 2020**
  - New release: 0.5.5
  - Mish activation function (SoTA)
  - [rwightman's](https://github.com/rwightman/gen-efficientnet-pytorch) pretrained/ported model variants for classification (44 total)
    - efficientnet TensorFlow ports b0-b8, with and without AP, el/em/es, cc
    - mixnet L/M/S
    - mobilenetv3
    - mnasnet
    - spnasnet
  - Additional loss functions
- **Aug. 1, 2019**
  - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
  - New loss functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
  - Major restructuring and standardization of NN models and loading functionality
  - General bug fixes and code improvements

@@ -40,8 +50,7 @@ work in progress though so apologies for anything that's broken.

or specific version from git:

`pip install git+https://github.com/achaiah/[email protected]`

## ModuleTrainer
The `ModuleTrainer` class provides a high-level training interface which abstracts
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_autofn.py (Apache 2.0)

import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto']


class SwishAutoFn(torch.autograd.Function):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    Memory efficient variant from:
    https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
    """
    @staticmethod
    def forward(ctx, x):
        result = x.mul(torch.sigmoid(x))
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        x_sigmoid = torch.sigmoid(x)
        return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))


def swish_auto(x, inplace=False):
    # inplace ignored
    return SwishAutoFn.apply(x)


class SwishAuto(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishAuto, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return SwishAutoFn.apply(x)


class MishAutoFn(torch.autograd.Function):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    Experimental memory-efficient variant
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        x_sigmoid = torch.sigmoid(x)
        x_tanh_sp = F.softplus(x).tanh()
        return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


def mish_auto(x, inplace=False):
    # inplace ignored
    return MishAutoFn.apply(x)


class MishAuto(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishAuto, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return MishAutoFn.apply(x)
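Both `autograd.Function` variants above save only the raw input for the backward pass and recompute the cheap sigmoid/tanh terms there, which is where the memory saving comes from; the hand-coded Mish gradient is `tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)`. A minimal sanity check, not part of this commit and assuming `MishAutoFn` from the file above is in scope, that this gradient matches what autograd derives from the plain formula:

```python
import torch
from torch.nn import functional as F

def mish_reference(x):
    # Plain Mish; let autograd derive the gradient for comparison.
    return x * torch.tanh(F.softplus(x))

x1 = torch.randn(16, dtype=torch.double, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

MishAutoFn.apply(x1).sum().backward()   # hand-written backward from the file above
mish_reference(x2).sum().backward()     # autograd-derived backward

print(torch.allclose(x1.grad, x2.grad))  # expected: True
```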
@@ -0,0 +1,114 @@
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_jit.py (Apache 2.0)

import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit']
# 'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit']


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """torch.jit.script optimised Swish
    Inspired by conversation between Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_jit(x, inplace=False):
    # inplace ignored
    return SwishJitAutoFn.apply(x)


class SwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishJit, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_jit(x, inplace=False):
    # inplace ignored
    return MishJitAutoFn.apply(x)


class MishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishJit, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return MishJitAutoFn.apply(x)


# @torch.jit.script
# def hard_swish_jit(x, inplace: bool = False):
#     return x.mul(F.relu6(x + 3.).mul_(1./6.))
#
#
# class HardSwishJit(nn.Module):
#     def __init__(self, inplace: bool = False):
#         super(HardSwishJit, self).__init__()
#
#     def forward(self, x):
#         return hard_swish_jit(x)
#
#
# @torch.jit.script
# def hard_sigmoid_jit(x, inplace: bool = False):
#     return F.relu6(x + 3.).mul(1./6.)
#
#
# class HardSigmoidJit(nn.Module):
#     def __init__(self, inplace: bool = False):
#         super(HardSigmoidJit, self).__init__()
#
#     def forward(self, x):
#         return hard_sigmoid_jit(x)
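The JIT variants compute the same forward and backward as the file above; the only difference is that the elementwise helpers are compiled with `torch.jit.script` so their operations can be fused. A short, illustrative usage sketch (not from the commit), assuming `MishJit` from this file is in scope; it is an ordinary `nn.Module`, so it drops in wherever `nn.ReLU()` would go:

```python
import torch
from torch import nn

# MishJit used as a drop-in activation in a tiny classifier.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    MishJit(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```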
@@ -0,0 +1,24 @@
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations.py (Apache 2.0)
# Note. Cuda-compiled source can be found here: https://github.com/thomasbrandon/mish-cuda (MIT)

import torch.nn as nn
import torch.nn.functional as F


def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    """
    return x.mul(F.softplus(x).tanh())


class Mish(nn.Module):
    """
    Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
    https://arxiv.org/abs/1908.08681v1
    implemented for PyTorch / FastAI by lessw2020
    github: https://github.com/lessw2020/mish
    """
    def __init__(self, inplace: bool = False):
        super(Mish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return mish(x, self.inplace)
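This last file is the straightforward version: it relies on autograd for the gradient, so it is the simplest of the three Mish implementations but keeps more intermediates alive during backprop. A quick numerical check (illustrative only, assuming `mish`, `MishAuto`, and `MishJit` from the files in this commit are all in scope) that the three implementations agree in the forward pass:

```python
import torch

x = torch.randn(1000)

y_plain = mish(x)        # this file: plain autograd version
y_auto = MishAuto()(x)   # activations_autofn.py: custom memory-efficient backward
y_jit = MishJit()(x)     # activations_jit.py: torch.jit.script-compiled helpers

print(torch.allclose(y_plain, y_auto), torch.allclose(y_plain, y_jit))  # True True
```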