# Patched ("+") side of the exllamav2/module.py change, written out as source.
# Both functions are methods of the module class (they take ``self``) and rely
# on names that file already has in scope: safe_open, torch and nn.

def load_multi(self, keys):
    """Load several of this module's tensors, opening each file only once.

    Every entry of ``keys`` is a name suffix; the full tensor name is
    ``self.key + "." + suffix``. Suffixes with no entry in
    ``self.model.config.tensor_file_map`` are skipped rather than treated as
    errors, so callers may list optional tensors and inspect the result.

    Args:
        keys: Iterable of tensor-name suffixes, e.g. ``["qweight", "g_idx"]``.

    Returns:
        Dict mapping each suffix that was found to its tensor, moved to
        ``self.device()``. Empty when none of the suffixes is mapped.
    """
    file_map = self.model.config.tensor_file_map
    prefix = self.key + "."

    # Group the requested suffixes by the file that stores them, so a file
    # holding several tensors is opened a single time. dict.fromkeys drops
    # repeated suffixes while keeping first-seen order.
    by_file = {}
    for suffix in dict.fromkeys(keys):
        full_name = prefix + suffix
        if full_name in file_map:
            by_file.setdefault(file_map[full_name], []).append(suffix)

    tensors = {}
    if not by_file:
        return tensors

    device = self.device()
    for filename, suffixes in by_file.items():
        # Tensors are read on the CPU and then moved to the target device.
        with safe_open(filename, framework="pt", device="cpu") as st:
            for suffix in suffixes:
                tensors[suffix] = st.get_tensor(prefix + suffix).to(device)

    return tensors


def load_weight(self):
    """Load this module's weights in whichever stored format is present.

    Formats are probed in order by looking for a marker tensor name in
    ``self.model.config.tensor_file_map``:

    * EXL2 (``.q_weight``): dict with ``q_weight``, ``q_invperm``,
      ``q_scale``, ``q_scale_max``, ``q_groups`` (those that are mapped) plus
      ``q_perm``, computed as ``argsort(q_invperm)`` cast to ``torch.int``.
    * GPTQ (``.qweight``): dict with ``qweight``, ``qzeros``, ``scales`` and,
      when mapped, ``g_idx``.
    * Torch (``.weight``): the weight converted to half precision and wrapped
      in ``nn.Parameter``.

    Returns:
        A dict of tensors, an ``nn.Parameter``, or ``None`` when none of the
        marker tensors is mapped.

    Raises:
        KeyError: In the EXL2 case, if ``q_invperm`` is not mapped.
    """
    file_map = self.model.config.tensor_file_map

    # EXL2

    if self.key + ".q_weight" in file_map:
        # "q_perm" is deliberately not requested: it is always recomputed
        # from q_invperm below, so loading a stored copy would be wasted work.
        qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"])
        qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
        return qtensors

    # GPTQ

    if self.key + ".qweight" in file_map:
        # g_idx is optional; load_multi simply omits it when it is not mapped.
        return self.load_multi(["qweight", "qzeros", "scales", "g_idx"])

    # Torch

    if self.key + ".weight" in file_map:
        # NOTE(review): load_multi moves the tensor to the device before the
        # .half() conversion here, so the tensor is transferred at its stored
        # precision and converted on the device.
        weight = self.load_multi(["weight"])["weight"]
        return nn.Parameter(weight.half())

    # No recognised format for this key.
    return None