Fix load quant tensors that span multiple shards
turboderp committed Sep 11, 2023
1 parent 5dc32f0 commit 49c8d9e
Showing 1 changed file with 32 additions and 25 deletions.
exllamav2/module.py: 57 changes (32 additions, 25 deletions)
@@ -25,44 +25,51 @@ def device(self):
         return _torch_device(self.device_idx)
 
 
+    def load_multi(self, keys):
+
+        tensors = {}
+        submap = {}
+        submap_i = {}
+
+        for k in keys:
+            ck = self.key + "." + k
+            if ck in self.model.config.tensor_file_map:
+                submap[k] = self.model.config.tensor_file_map[ck]
+
+        for k, v in submap.items():
+            if v not in submap_i:
+                submap_i[v] = []
+            submap_i[v].append(k)
+
+        for v, ks in submap_i.items():
+            with safe_open(v, framework="pt", device="cpu") as st:
+                for k in ks:
+                    tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
+
+        return tensors
+
+
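This grouping is the heart of the fix: rather than opening whichever shard holds the first key and reading every tensor from it, load_multi maps each sub-key to its shard file and then inverts that map, so each shard is opened exactly once and each tensor is read from the shard that actually contains it. A minimal, self-contained sketch of the grouping step, using a hypothetical two-shard tensor_file_map (the key and file names here are illustrative, not from the repo):

```python
key = "model.layers.0.mlp.down_proj"  # hypothetical module key
tensor_file_map = {                   # hypothetical map: tensors span two shards
    key + ".q_weight":    "model-00001-of-00002.safetensors",
    key + ".q_invperm":   "model-00001-of-00002.safetensors",
    key + ".q_scale":     "model-00002-of-00002.safetensors",
    key + ".q_scale_max": "model-00002-of-00002.safetensors",
    key + ".q_groups":    "model-00002-of-00002.safetensors",
}

def group_by_shard(keys):
    # Same two passes as load_multi: sub-key -> shard, then shard -> sub-keys.
    submap = {k: tensor_file_map[key + "." + k]
              for k in keys if key + "." + k in tensor_file_map}
    submap_i = {}
    for k, v in submap.items():
        submap_i.setdefault(v, []).append(k)
    return submap_i

print(group_by_shard(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]))
# {'model-00001-of-00002.safetensors': ['q_weight', 'q_invperm'],
#  'model-00002-of-00002.safetensors': ['q_scale', 'q_scale_max', 'q_groups']}
```

Under the old single-file code path below, all five tensors would have been requested from the first shard, and the ones living in the second shard could not be found.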


     def load_weight(self):
 
         # EXL2
 
         if self.key + ".q_weight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".q_weight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                qtensors = {}
-                qtensors["q_weight"] = st.get_tensor(self.key + ".q_weight").to(self.device())
-                qtensors["q_invperm"] = st.get_tensor(self.key + ".q_invperm").to(self.device())
-                qtensors["q_scale"] = st.get_tensor(self.key + ".q_scale").to(self.device())
-                qtensors["q_scale_max"] = st.get_tensor(self.key + ".q_scale_max").to(self.device())
-                qtensors["q_groups"] = st.get_tensor(self.key + ".q_groups").to(self.device())
-                qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
-                return qtensors
+            qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"])
+            qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
+            return qtensors
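The rewritten EXL2 branch still recomputes q_perm rather than relying on a stored copy, which works because argsort inverts a permutation: torch.argsort(q_invperm) recovers q_perm. A quick sanity check of that identity:

```python
import torch

# argsort inverts a permutation: sorting q_invperm's values yields, for each
# destination position, the index it came from, which is exactly q_perm.
q_perm = torch.tensor([2, 0, 3, 1])
q_invperm = torch.argsort(q_perm)                    # tensor([1, 3, 0, 2])
assert torch.equal(torch.argsort(q_invperm), q_perm)
```

Note that "q_perm" is also in the key list passed to load_multi, so a stored q_perm is loaded when present but immediately overwritten by the recomputed one; since load_multi skips missing keys, this is harmless either way.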

         # GPTQ
 
         if self.key + ".qweight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".qweight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                qtensors = {}
-                qtensors["qweight"] = st.get_tensor(self.key + ".qweight").to(self.device())
-                qtensors["qzeros"] = st.get_tensor(self.key + ".qzeros").to(self.device())
-                qtensors["scales"] = st.get_tensor(self.key + ".scales").to(self.device())
-                if self.key + ".g_idx" in self.model.config.tensor_file_map:
-                    qtensors["g_idx"] = st.get_tensor(self.key + ".g_idx").to(self.device())
-                return qtensors
+            qtensors = self.load_multi(["qweight", "qzeros", "scales", "g_idx"])
+            return qtensors
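The explicit g_idx existence check from the old code disappears because load_multi only returns keys it actually finds in tensor_file_map, so the optional GPTQ tensor simply drops out of the result when absent. A small stand-in (with a hypothetical map and module key) showing that behavior:

```python
def load_multi_sketch(tensor_file_map, key, keys):
    # Mirrors load_multi's filtering: sub-keys absent from the map are skipped,
    # so optional tensors like GPTQ's g_idx just never appear in the result.
    return {k: "<tensor>" for k in keys if key + "." + k in tensor_file_map}

act_order = {"m.qweight": "s1", "m.qzeros": "s1", "m.scales": "s1", "m.g_idx": "s1"}
plain     = {"m.qweight": "s1", "m.qzeros": "s1", "m.scales": "s1"}

assert "g_idx" in load_multi_sketch(act_order, "m", ["qweight", "qzeros", "scales", "g_idx"])
assert "g_idx" not in load_multi_sketch(plain, "m", ["qweight", "qzeros", "scales", "g_idx"])
```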

         # Torch
 
         if self.key + ".weight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".weight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                weight = st.get_tensor(self.key + ".weight")
-                weight = weight.half()
-                weight = weight.to(self.device())
-                return nn.Parameter(weight)
+            tensor = self.load_multi(["weight"])["weight"]
+            tensor = tensor.half()
+            return nn.Parameter(tensor)
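One thing the diff leaves unchanged is that load_weight is polymorphic: the EXL2 and GPTQ branches return a dict of tensors, while the Torch branch returns a single half-precision nn.Parameter. A hedged caller-side sketch (not repo code) of handling both shapes:

```python
import torch
import torch.nn as nn

def consume(loaded):
    # Hypothetical consumer: quantized modules yield a dict of named tensors,
    # unquantized ones a single half-precision nn.Parameter.
    if isinstance(loaded, dict):
        return loaded.get("q_weight", loaded.get("qweight"))
    return loaded.data

# Stand-in for the Torch branch: cast to half, then wrap as a Parameter.
param = nn.Parameter(torch.randn(4, 4).half())
assert consume(param).dtype == torch.float16
```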


     def set_device_idx(self, idx):
