Skip to content

Commit

Permalink
fix typo (code dupe cut/paste)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gschwind committed Apr 9, 2024
1 parent 241bac5 commit 568947b
Showing 1 changed file with 0 additions and 80 deletions.
80 changes: 0 additions & 80 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,86 +474,6 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
### HQQ

class WeightOnlyInt4HqqQuantHandler:
    """Int4 weight-only quantization driven by HQQ.

    Round-trips every ``torch.nn.Linear`` weight through HQQ's
    quantize/dequantize, then delegates state-dict packing and runtime
    module conversion to the existing int4 handlers.
    """

    def __init__(self, mod, group_size):
        # mod: model whose Linear layers get quantized
        # group_size: per-group quantization granularity
        self.mod = mod
        self.groupsize = group_size

    def _create_quantized_state_dict(self):
        from hqq.core.quantize import Quantizer  # TODO maybe torchao

        # Collect the Linear layers lazily, then fake-quantize each weight
        # in place (quantize -> dequantize) so the packer below can re-pack it.
        linear_layers = (
            child
            for module in self.mod.modules()
            for _, child in module.named_children()
            if isinstance(child, torch.nn.Linear)
        )
        for layer in linear_layers:
            quantized = Quantizer.quantize(
                layer.weight, nbits=4, group_size=self.groupsize, axis=1
            )
            layer.weight = torch.nn.Parameter(Quantizer.dequantize(*quantized))

        # The plain int4 handler does the actual state-dict packing.
        return WeightOnlyInt4QuantHandler(
            self.mod, self.groupsize
        ).create_quantized_state_dict()

    def _convert_for_runtime(self):
        # Reuse the GPTQ handler's module-swapping logic.
        return WeightOnlyInt4GPTQQuantHandler(
            self.mod, self.groupsize
        ).convert_for_runtime(use_cuda=True)


class WeightOnlyInt4HqqQuantHandler:
    """Int4 weight-only quantization via HQQ.

    NOTE(review): this class is defined multiple times in this file
    (copy/paste duplication); at import time only the last definition
    is bound to the name.
    """

    def __init__(self, mod, groupsize):
        # mod: model whose Linear layers get quantized
        # groupsize: per-group quantization granularity
        self.mod = mod
        self.groupsize = groupsize

    def _create_quantized_state_dict(self):
        from hqq.core.quantize import Quantizer  # TODO maybe torchao

        for module in self.mod.modules():
            for _name, child in module.named_children():
                # Only Linear weights are quantized; skip everything else.
                if not isinstance(child, torch.nn.Linear):
                    continue
                # Fake-quantize in place: quantize, then immediately
                # dequantize so downstream packing sees quantized values.
                q_args = Quantizer.quantize(
                    child.weight,
                    nbits=4,
                    group_size=self.groupsize,
                    axis=1,
                )
                child.weight = torch.nn.Parameter(Quantizer.dequantize(*q_args))

        # Delegate packing to the plain int4 handler.
        handler = WeightOnlyInt4QuantHandler(self.mod, self.groupsize)
        return handler.create_quantized_state_dict()

    def _convert_for_runtime(self):
        # Reuse the GPTQ handler's runtime module conversion.
        gptq_handler = WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize)
        return gptq_handler.convert_for_runtime(use_cuda=True)
class WeightOnlyInt4HqqQuantHandler:
    """Create an int4 weight-only quantized model via HQQ.

    NOTE(review): this is a duplicate definition — the same class appears
    earlier in the file (copy/paste); only the last definition wins at
    import time. The duplicates should be removed.
    """

    def __init__(self, mod, groupsize):
        # mod: the model whose Linear layers will be quantized
        # groupsize: per-group quantization granularity
        self.mod = mod
        self.groupsize = groupsize

    def _create_quantized_state_dict(self):
        from hqq.core.quantize import Quantizer  # TODO maybe torchao

        # Round-trip each Linear weight through HQQ quantize/dequantize so
        # the fake-quantized weights can be re-packed by the handler below.
        for m in self.mod.modules():
            for name, child in m.named_children():
                if isinstance(child, torch.nn.Linear):
                    child.weight = torch.nn.Parameter(
                        Quantizer.dequantize(
                            *Quantizer.quantize(
                                child.weight,
                                nbits=4,
                                group_size=self.groupsize,
                                axis=1,
                            )
                        )
                    )

        # Delegate the actual int4 state-dict packing to the plain handler.
        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()

    def _convert_for_runtime(self):
        # Reuse the GPTQ handler's module-swapping logic for runtime use.
        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
class WeightOnlyInt4HqqQuantHandler:
def __init__(self, mod, groupsize):
self.mod = mod
self.groupsize = groupsize
Expand Down

0 comments on commit 568947b

Please sign in to comment.