diff --git a/exllamav2/module.py b/exllamav2/module.py
index 511a7ceb..fbf8b79b 100644
--- a/exllamav2/module.py
+++ b/exllamav2/module.py
@@ -25,44 +25,51 @@ def device(self):
         return _torch_device(self.device_idx)
 
 
+    def load_multi(self, keys):
+
+        tensors = {}
+        submap = {}
+        submap_i = {}
+
+        for k in keys:
+            ck = self.key + "." + k
+            if ck in self.model.config.tensor_file_map:
+                submap[k] = self.model.config.tensor_file_map[ck]
+
+        for k, v in submap.items():
+            if v not in submap_i:
+                submap_i[v] = []
+            submap_i[v].append(k)
+
+        for v, ks in submap_i.items():
+            with safe_open(v, framework="pt", device="cpu") as st:
+                for k in ks:
+                    tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
+
+        return tensors
+
+
     def load_weight(self):
 
         # EXL2
 
         if self.key + ".q_weight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".q_weight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                qtensors = {}
-                qtensors["q_weight"] = st.get_tensor(self.key + ".q_weight").to(self.device())
-                qtensors["q_invperm"] = st.get_tensor(self.key + ".q_invperm").to(self.device())
-                qtensors["q_scale"] = st.get_tensor(self.key + ".q_scale").to(self.device())
-                qtensors["q_scale_max"] = st.get_tensor(self.key + ".q_scale_max").to(self.device())
-                qtensors["q_groups"] = st.get_tensor(self.key + ".q_groups").to(self.device())
-                qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
-                return qtensors
+            qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"])
+            qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
+            return qtensors
 
         # GPTQ
 
         if self.key + ".qweight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".qweight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                qtensors = {}
-                qtensors["qweight"] = st.get_tensor(self.key + ".qweight").to(self.device())
-                qtensors["qzeros"] = st.get_tensor(self.key + ".qzeros").to(self.device())
-                qtensors["scales"] = st.get_tensor(self.key + ".scales").to(self.device())
-                if self.key + ".g_idx" in self.model.config.tensor_file_map:
-                    qtensors["g_idx"] = st.get_tensor(self.key + ".g_idx").to(self.device())
-                return qtensors
+            qtensors = self.load_multi(["qweight", "qzeros", "scales", "g_idx"])
+            return qtensors
 
         # Torch
 
         if self.key + ".weight" in self.model.config.tensor_file_map:
-            filename = self.model.config.tensor_file_map[self.key + ".weight"]
-            with safe_open(filename, framework = "pt", device = "cpu") as st:
-                weight = st.get_tensor(self.key + ".weight")
-                weight = weight.half()
-                weight = weight.to(self.device())
-                return nn.Parameter(weight)
+            tensor = self.load_multi(["weight"])["weight"]
+            tensor = tensor.half()
+            return nn.Parameter(tensor)
 
 
     def set_device_idx(self, idx):