Commit

typo
Michael Gschwind committed Apr 16, 2024
1 parent a6e2ca3 commit 7276777
Showing 1 changed file with 7 additions and 6 deletions: quantize.py
@@ -785,7 +785,7 @@ def create_quantized_state_dict(self):
 
 
     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
+        replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod
 
     def quantized_model(self) -> nn.Module:
@@ -1295,8 +1295,8 @@ def make_names_and_values_dict_func(q, qparams):
         super().__init__()
 
 
-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
     def quantized_model(self) -> nn.Module:
@@ -1384,6 +1384,7 @@ def quantized_model(self) -> nn.Module:
 class WeightOnlyInt4HqqQuantHandler:
     def __init__(self, mod, device, *, groupsize):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
 
     def create_quantized_state_dict(self):
@@ -1407,15 +1408,15 @@ def create_quantized_state_dict(self):
         # we use Int4 packaged in an int8 for now, packing to follow
         # return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
         return WeightOnlyInt8QuantHandler(
-            self.mod, bitwidth=4, groupsize=self.groupsize
+            self.mod, self.device, bitwidth=4, groupsize=self.groupsize
         ).create_quantized_state_dict()
 
     def convert_for_runtime(self):
         # we use Int4 packaged in an int8 for now, packing to follow
         # ALSO: all code must work for CPU, CUDA, MPS
-        # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
+        # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime()
         return WeightOnlyInt4GPTQQuantHandler(
-            self.mod, bitwidth=4, groupsize=self.groupsize
+            self.mod, self.device, bitwidth=4, groupsize=self.groupsize
         ).convert_for_runtime()
 
     def quantized_model(self) -> nn.Module:
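For context: after this change every quant handler is constructed with an explicit device, and convert_for_runtime() no longer takes a use_cuda flag. Below is a minimal usage sketch of that convention. The handler name and its (module, device, bitwidth, groupsize) constructor come from the diff above; ToyModel, the groupsize value, and the create/convert/load sequence are illustrative assumptions modeled on the quantized_model() helpers in this file, not the repository's actual driver code, and the import assumes quantize.py is on the Python path.

import torch
import torch.nn as nn

from quantize import WeightOnlyInt8QuantHandler  # assumption: this repo's quantize.py is importable

# Illustrative stand-in for the real model; any module containing nn.Linear layers would do.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.proj(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyModel().to(device)

# Post-commit convention: handlers take (module, device, ...); no use_cuda flag anywhere.
handler = WeightOnlyInt8QuantHandler(model, device, bitwidth=4, groupsize=128)

quantized_state_dict = handler.create_quantized_state_dict()  # quantize the weights
model = handler.convert_for_runtime()                         # swap nn.Linear for quantized runtime modules
model.load_state_dict(quantized_state_dict)                   # load the quantized tensors into the new modules

print(model)  # the Linear layers should now appear as their quantized counterparts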
