Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed Apr 5, 2024
1 parent d9eec1d commit 7893225
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions gguf_util/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping
from typing import Any, Mapping, Dict
import logging

import gguf
Expand Down Expand Up @@ -100,6 +100,8 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:


def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
arch = metadata["general.architecture"]
assert arch == "llama", f"Only general.architecture=llama is supported, but got general.architecture={arch}"
return GGUFModelArgs(
arch=arch,
embedding_length=metadata[f"{arch}.embedding_length"],
Expand Down Expand Up @@ -147,16 +149,15 @@ def _fqn_last(fqn: str) -> str:


def _load_by_state_dict(pt_model: torch.nn.Module, state_dict: Dict[str, Any], fqn: str, gguf_tensor: ReaderTensor) -> bool:
assert fqn in state_dict
if gguf.tensor_type in (GGMLQuantizationType.F32, GGMLQuantizationType.F16):
reversed_shape = tensor.shape[::-1]
new_tensor = tensor.data.reshape(reversed_shape)
if gguf_tensor.tensor_type in (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16):
reversed_shape = gguf_tensor.shape[::-1]
new_tensor = gguf_tensor.data.reshape(reversed_shape)
state_dict[fqn] = torch.from_numpy(new_tensor)
return True
elif tensor.tensor_type == GGMLQuantizationType.Q4_0 and tensor.name == "token_embd.weight":
unpacked = Q4_0.to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
elif gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and gguf_tensor.name == "token_embd.weight":
unpacked = Q4_0.to_float(torch.from_numpy(gguf_tensor.data.reshape(-1, 18)))
state_dict[fqn] = unpacked.reshape(
pt_model.params.vocab_size, pt_model.params.dim
pt_model.config.vocab_size, pt_model.config.dim
)
return True
return False
Expand All @@ -166,9 +167,9 @@ def _load_by_parameter(pt_model: torch.nn.Module, fqn: str, gguf_tensor: ReaderT
assert isinstance(_fqn_lookup(fqn, pt_model), torch.nn.Parameter)
parent: torch.nn.Module = _fqn_lookup(_fqn_up(fqn), pt_model)

if tensor.tensor_type == GGMLQuantizationType.Q4_0 and isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight":
print(fqn, tensor.shape, tensor.data.shape, parent.weight.shape)
packed = torch.from_numpy(tensor.data).reshape(-1, 18)
if gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight":
print(fqn, gguf_tensor.shape, gguf_tensor.data.shape, parent.weight.shape)
packed = torch.from_numpy(gguf_tensor.data).reshape(-1, 18)
scale = torch.tensor(Q4_0._unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
parent.weight = torch.nn.Parameter(
Q4_0.GGMLInt4LinearWeight(packed, scale, parent.weight.shape)
Expand All @@ -179,12 +180,13 @@ def _load_by_parameter(pt_model: torch.nn.Module, fqn: str, gguf_tensor: ReaderT


def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor]) -> None:
loaded_by_state_dict: Set[str] = {}
loaded_by_parameter: Set[str] = {}
loaded_by_state_dict: Set[str] = set()
loaded_by_parameter: Set[str] = set()

# state_dict pass
logger.info("Loading weights by state_dict.")
state_dict = {}
for fqn, _ model.state_dict():
for fqn in pt_model.state_dict():
if fqn not in weight_map:
continue
tensor = weight_map[fqn]
Expand All @@ -196,6 +198,7 @@ def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor]
pt_model.load_state_dict(state_dict, strict=False)

# parameter pass
logger.info("Loading weights by parameter.")
for fqn, param in pt_model.named_parameters():
if fqn not in weight_map:
continue
Expand All @@ -206,13 +209,20 @@ def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor]

# Sanity checks
for fqn in loaded_by_state_dict:
assert fqn not in loaded_by_parameter, f"{fqn} was loaded by both state_dict and parameter"
# assert fqn not in loaded_by_parameter, f"{fqn} was loaded by both state_dict and parameter"
if not(fqn not in loaded_by_parameter):
print(f"{fqn} was loaded by both state_dict and parameter")


for fqn in weight_map:
assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in weight_map was not loaded"
# assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in weight_map was not loaded"
if not (fqn in (loaded_by_state_dict | loaded_by_parameter)):
print(f"{fqn} in weight_map was not loaded")

for fqn, _ model.state_dict():
assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in model.state_dict() was not loaded"
for fqn in pt_model.state_dict():
# assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in model.state_dict() was not loaded"
if not (fqn in (loaded_by_state_dict | loaded_by_parameter)):
print(f"{fqn} in model.state_dict() was not loaded")


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
Expand Down Expand Up @@ -244,22 +254,28 @@ def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
if not Path(gguf_file).is_file():
raise ValueError(f"Could not find file {gguf_file}")

logger.info("Parsing GGUF metadata.")
reader = gguf.GGUFReader(gguf_file, "r")

metadata = _get_metadata(reader)
model_args = _build_model_args(metadata)
assert (
gguf_model_args.arch == "llama"
model_args.arch == "llama"
), "Only LLaMa models are supported by this converter."

logger.info("Creating initial PT model.")
pt_model = _create_pt_model(model_args)


logger.info("Reading GGUF weights.")
gguf_weights = GGUFWeights(tensors=reader.tensors)
pt_model = _create_pt_model(gguf_model_args)

logger.info("Building GGUF weight map.")
# map from fqn in pt_model to gguf tensor
weight_map = {
_convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
for tensor in gguf_weights.tensors
}
logger.info("Loading GGUF weights into PT model.")
_load_weights(pt_model, weight_map)

return pt_model
return pt_model, weight_map

0 comments on commit 7893225

Please sign in to comment.