diff --git a/quantize.py b/quantize.py index 4ca4bb1fd..4d44e90fd 100644 --- a/quantize.py +++ b/quantize.py @@ -785,7 +785,7 @@ def create_quantized_state_dict(self): def convert_for_runtime(self): - replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda) + replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed) return self.mod def quantized_model(self) -> nn.Module: @@ -1295,8 +1295,8 @@ def make_names_and_values_dict_func(q, qparams): super().__init__() - def convert_for_runtime(self, use_cuda): - replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding) return self.mod def quantized_model(self) -> nn.Module: @@ -1384,6 +1384,7 @@ def quantized_model(self) -> nn.Module: class WeightOnlyInt4HqqQuantHandler: def __init__(self, mod, device, *, groupsize): self.mod = mod + self.device = device self.groupsize = groupsize def create_quantized_state_dict(self): @@ -1407,15 +1408,15 @@ def create_quantized_state_dict(self): # we use Int4 packaged in an int8 for now, packing to follow # return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict() return WeightOnlyInt8QuantHandler( - self.mod, bitwidth=4, groupsize=self.groupsize + self.mod, self.device, bitwidth=4, groupsize=self.groupsize ).create_quantized_state_dict() def convert_for_runtime(self): # we use Int4 packaged in an int8 for now, packing to follow # ALSO: all code must work for CPU, CUDA, MPS - # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True) + # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime() return WeightOnlyInt4GPTQQuantHandler( - self.mod, bitwidth=4, groupsize=self.groupsize + self.mod, self.device, bitwidth=4, groupsize=self.groupsize ).convert_for_runtime() def quantized_model(self) -> nn.Module: