```python
# if load:  # this is deprecated
from pathlib import Path

import torch
from tqdm import tqdm

ckpts = sorted(Path(llama_ckpt_dir).glob("*.pth"))
for ckpt_path in tqdm(ckpts, desc="Loading LLaMA ckpt"):
    ckpt = torch.load(ckpt_path, map_location='cuda:0')
    names = self.llama.state_dict().keys()
    ckpt_names = ckpt.keys()
    # Warn about checkpoint keys that have no counterpart in the model.
    for n in ckpt_names:
        if n not in names:
            print(f"Warning: {n} not in llama model")
    # Load each shard into the model; strict=False tolerates missing keys.
    self.llama.load_state_dict(ckpt, strict=False)
self.llama_keys = ["llama." + i for i in ckpt_names]
```
Hello, I'd like to ask why loading is so slow in this code that loads the LLaMA checkpoints:
We enforce loading the LLaMA checkpoint.
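One possible contributor to the slowness (an assumption, not something confirmed in this thread) is that the loop above calls `load_state_dict` once per shard file, traversing the whole model each time, and loads every shard directly onto `cuda:0`. A common alternative is to merge the shard dictionaries first and load once. The sketch below illustrates the merge step with plain dicts; the tensor names and the commented `torch` call are illustrative only:

```python
# Hypothetical sketch: merge sharded checkpoint dicts before a single load,
# instead of loading each shard into the model one at a time.
def merge_shards(shards):
    """Merge a list of state-dict-like mappings into one dict.
    Later shards overwrite earlier ones on key collisions."""
    merged = {}
    for shard in shards:
        merged.update(shard)
    return merged

# Toy shards standing in for torch.load(...) results (key names are made up).
shard_a = {"layers.0.weight": 1, "layers.0.bias": 2}
shard_b = {"layers.1.weight": 3}
state_dict = merge_shards([shard_a, shard_b])

# With torch, a single load would then follow, e.g.:
#   self.llama.load_state_dict(state_dict, strict=False)
print(sorted(state_dict))
# → ['layers.0.bias', 'layers.0.weight', 'layers.1.weight']
```

Loading each shard with `map_location='cpu'` and moving the model to the GPU once at the end is another frequently used variant, since it avoids per-shard host-to-device transfers.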