Quantization Update #150

Open
wants to merge 35 commits into base: dev

Commits (35)
b6ee526
added utils
sydarb Aug 31, 2023
55f74c6
removed reconstruct
sydarb Aug 31, 2023
8871329
refactored quant modules
sydarb Aug 31, 2023
fcff933
modified quant methods
sydarb Aug 31, 2023
678b45e
modified quant base
sydarb Aug 31, 2023
a13775f
modified lapq
sydarb Aug 31, 2023
4ce8872
modified brecq
sydarb Aug 31, 2023
d214aa4
modified bitsplit
sydarb Aug 31, 2023
a0597b8
added qconfig helpers
sydarb Sep 3, 2023
c7e7e2d
added temp torch.ao.quantizer modules
sydarb Sep 3, 2023
2d7e239
updated quant model with bn act fusion
sydarb Sep 3, 2023
3a4462e
added new quantizer modules
sydarb Sep 3, 2023
e1b6108
updated quantizer methods
sydarb Sep 3, 2023
be251a8
fixed lapq method name bug
sydarb Sep 3, 2023
8a6b95f
fixed lapq refactoring errors
sydarb Sep 3, 2023
a51eaf9
customized model forward pass
sydarb Sep 4, 2023
761515f
modified qconfig generation
sydarb Sep 4, 2023
c4dd4b7
added base true quant
sydarb Sep 4, 2023
d0cf17b
added true quant to lapq
sydarb Sep 4, 2023
d5ef9dc
debug lapq true quant
sydarb Sep 5, 2023
a124452
made test progress bar as optional
sydarb Sep 5, 2023
1fde13e
modified quant dtype with reduced byte length
sydarb Sep 5, 2023
e31290e
made base mobilenetv2 model class cleaner
sydarb Sep 6, 2023
3c9b170
fixed quant module refactor bugs
sydarb Sep 6, 2023
7765fd5
added quantized basicblock
sydarb Sep 6, 2023
ec445c9
Revert "made base mobilenetv2 model class cleaner"
sydarb Sep 6, 2023
7afa24a
fixed lapq bugs
sydarb Sep 6, 2023
f8506bd
modified lapq expt nb
sydarb Sep 29, 2023
a37928a
added quant procedure info
sydarb Sep 29, 2023
797567a
removed nnq files
sydarb Sep 29, 2023
82bce33
changed class structure
sydarb Sep 30, 2023
a96122d
added diagrams to info
sydarb Oct 2, 2023
fb5b798
added observer methods
sydarb Oct 2, 2023
e760182
refactoring for observer approach
sydarb Oct 2, 2023
77ff4e8
modified diagram
sydarb Oct 2, 2023
1,064 changes: 565 additions & 499 deletions experiments/quantization/LAPQ/lapq_demo.ipynb

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions trailmet/algorithms/algorithms.py
Expand Up @@ -180,20 +180,20 @@ def accuracy(self, output, target, topk=(1, )):
res.append(correct_k.mul_(100.0 / batch_size))
return res

def test(self, model, dataloader, loss_fn=None, device=None):
def test(self, model, dataloader, loss_fn=None, device=None, progress=True):
"""This method is used to test the performance of the trained model."""
if device is None:
device = next(model.parameters()).device
else:
model.to(device)
model.eval()
counter = 0
tk1 = tqdm_notebook(dataloader, total=len(dataloader))
running_acc1 = 0
running_acc5 = 0
running_loss = 0
pbar = tqdm_notebook(dataloader, total=len(dataloader)) if progress else dataloader
with torch.no_grad():
for images, targets in tk1:
for images, targets in pbar:
counter += 1
images = images.to(device)
targets = targets.to(device)
@@ -204,13 +204,15 @@ def test(self, model, dataloader, loss_fn=None, device=None):
                 if loss_fn is not None:
                     loss = loss_fn(outputs, targets)
                     running_loss += loss.item()
-                    tk1.set_postfix(
-                        loss=running_loss / counter,
-                        acc1=running_acc1 / counter,
-                        acc5=running_acc5 / counter,
-                    )
+                    if progress:
+                        pbar.set_postfix(
+                            loss=running_loss / counter,
+                            acc1=running_acc1 / counter,
+                            acc5=running_acc5 / counter,
+                        )
                 else:
-                    tk1.set_postfix(acc1=running_acc1 / counter,
+                    if progress:
+                        pbar.set_postfix(acc1=running_acc1 / counter,
                                         acc5=running_acc5 / counter)
         if loss_fn is not None:
             return running_acc1 / counter, running_loss / counter
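A minimal usage sketch of the new progress flag; the algo, model, val_loader, and criterion names below are placeholders for an existing algorithm instance, trained model, DataLoader, and loss function, not part of this diff:

    import torch.nn as nn

    # Placeholders: `algo` is an algorithm instance exposing this test method,
    # `model` a trained network, `val_loader` a validation DataLoader.
    criterion = nn.CrossEntropyLoss()

    # Silent evaluation, e.g. in scripted runs: no tqdm notebook bar is created.
    top1, avg_loss = algo.test(model, val_loader, loss_fn=criterion, progress=False)

    # Default behaviour keeps the notebook progress bar with running loss/accuracy postfix.
    top1, avg_loss = algo.test(model, val_loader, loss_fn=criterion)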
55 changes: 5 additions & 50 deletions trailmet/algorithms/quantize/__init__.py
@@ -19,53 +19,8 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from .bitsplit import BitSplit
-from .brecq import BRECQ
-from .lapq import LAPQ
-from .methods import (
-    UniformAffineQuantizer,
-    AdaRoundQuantizer,
-    BitSplitQuantizer,
-    ActQuantizer,
-    QuantizationBase,
-    UniformQuantization,
-    ClippedUniformQuantization,
-    FixedClipValueQuantization,
-    MaxAbsStaticQuantization,
-    LearnedStepSizeQuantization,
-    LpNormQuantization,
-)
-from .qmodel import (
-    QuantBasicBlock,
-    QuantBottleneck,
-    QuantInvertedResidual,
-    QuantModule,
-    BaseQuantBlock,
-    QBasicBlock,
-    QBottleneck,
-    QInvertedResidual,
-    ActivationModuleWrapper,
-    ParameterModuleWrapper,
-)
-from .quantize import (
-    BaseQuantization,
-    StraightThrough,
-    RoundSTE,
-    Conv2dFunctor,
-    LinearFunctor,
-    FoldBN,
-)
-from .reconstruct import (
-    StopForwardException,
-    DataSaverHook,
-    GetLayerInpOut,
-    save_inp_oup_data,
-    GradSaverHook,
-    GetLayerGrad,
-    save_grad_data,
-    LinearTempDecay,
-    LayerLossFunction,
-    layer_reconstruction,
-    BlockLossFunction,
-    block_reconstruction,
-)
+
+from . import quantize
+from . import lapq
+from . import bitsplit
+from . import brecq
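Since the package __init__ now re-exports whole submodules instead of individual classes, downstream code imports through the submodule. A short sketch of the change; the class names are assumed to still live in those submodules after the refactor:

    # Old, package-level re-exports:
    #     from trailmet.algorithms.quantize import LAPQ, BitSplit, BRECQ
    # New, submodule-level access (assumed class locations):
    from trailmet.algorithms.quantize import lapq, bitsplit, brecq

    method_cls = lapq.LAPQ  # resolve the algorithm class through its submodule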
105 changes: 105 additions & 0 deletions trailmet/algorithms/quantize/_methods.py
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from typing import Dict, Callable
from trailmet.algorithms.quantize.observers import BaseObserver, MinMaxObserver, LpNormObserver
from trailmet.algorithms.quantize.utils import reshape_qparams_by_channel



OBSERVER_MAPPING: Dict[str, Callable] = {
'min_max': MinMaxObserver,
'lp_norm': LpNormObserver
}


class RoundSTE(torch.autograd.Function):
"""grad enabled round function"""
@staticmethod
def forward(ctx, input):
return torch.round(input)

@staticmethod
def backward(ctx, grad_output):
return grad_output


class FloorSTE(torch.autograd.Function):
"""grad enabled floor function"""
@staticmethod
def forward(ctx, input):
return torch.floor(input)

@staticmethod
def backward(ctx, grad_output):
return grad_output


class BaseQuantizer(nn.Module):
    def __init__(self, kwargs: dict):
        super().__init__()  # nn.Module init must run before buffers/parameters are registered
        self.observer: BaseObserver = OBSERVER_MAPPING[kwargs.get(
            'observer', 'min_max')](**kwargs)
self.quant_min = self.observer.quant_min
self.quant_max = self.observer.quant_max
self.per_channel = kwargs.get('per_channel', False)
self.ch_axis = kwargs.get('ch_axis', 0)
self.enable_observation = True
self.enable_quantization = True

def __register_buffer__(self, name, value):
if hasattr(self, name):
delattr(self, name)
self.register_buffer(name, value)

def __register_parameter__(self, name, value):
if hasattr(self, name):
delattr(self, name)
self.register_parameter(name, nn.Parameter(value))

def quantize(self, x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
round_mode: str = 'nearest'):
if self.per_channel:
scale, zero_point = reshape_qparams_by_channel(
x, scale, zero_point, self.ch_axis)
if round_mode == 'nearest':
x_int = RoundSTE.apply(x / scale)
elif round_mode == 'stochastic':
x_floor = FloorSTE.apply(x / scale)
x_int = x_floor + torch.bernoulli((x / scale) - x_floor)
else:
raise NotImplementedError
x_quant = torch.clamp(x_int + zero_point, self.quant_min, self.quant_max)
return x_quant

def dequantize(self, x_quant: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
x_dequant = (x_quant - zero_point) * scale
return x_dequant

def reset_bitwidth(self, n_bits: int):
self.observer.reset_bitwidth(n_bits)
self.quant_min = self.observer.quant_min
self.quant_max = self.observer.quant_max


class UniformQuantizer(BaseQuantizer):
def __init__(self, kwargs: dict):
super().__init__(kwargs)
self.__register_buffer__('scale', torch.tensor([1.0], dtype=torch.float))
self.__register_buffer__('zero_point', torch.tensor([0], dtype=torch.int))

def forward(self, x: torch.Tensor):
if self.enable_observation:
x = self.observer(x)

if self.enable_quantization:
self.scale, self.zero_point = self.observer.calculate_qparams()
self.scale, self.zero_point = self.scale.to(x.device), self.zero_point.to(x.device)
x_quant = self.quantize(x, self.scale, self.zero_point)
x_dequant = self.dequantize(x_quant, self.scale, self.zero_point)
return x_dequant

return x


class AdaRoundQuantizer(BaseQuantizer):
def __init__(self, kwargs: dict):
super().__init__(kwargs)
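For reviewers, a minimal sketch of the fake-quantization round trip that BaseQuantizer.quantize and dequantize implement, using made-up scale and zero-point values, plus a check that RoundSTE passes gradients straight through the rounding step:

    import torch

    from trailmet.algorithms.quantize._methods import RoundSTE

    # Made-up qparams for an unsigned 8-bit range, purely for illustration.
    x = torch.tensor([-1.0, -0.3, 0.0, 0.4, 1.2])
    scale, zero_point = torch.tensor(0.01), torch.tensor(128)
    quant_min, quant_max = 0, 255

    x_int = torch.round(x / scale)                                   # nearest rounding
    x_quant = torch.clamp(x_int + zero_point, quant_min, quant_max)  # clamp in the integer domain
    x_dequant = (x_quant - zero_point) * scale                       # back onto the float grid

    # Straight-through estimator: rounding is treated as identity in backward,
    # whereas plain torch.round would yield an all-zero gradient.
    w = torch.tensor([0.37, -0.92], requires_grad=True)
    RoundSTE.apply(w / scale).sum().backward()
    print(w.grad)  # tensor([100., 100.]) == 1 / scale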