diff --git a/gguf_util/loader.py b/gguf_util/loader.py index f0528dcb6..e557866bd 100644 --- a/gguf_util/loader.py +++ b/gguf_util/loader.py @@ -11,7 +11,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Mapping +from typing import Any, Mapping, Dict import logging import gguf @@ -100,6 +100,8 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs: + arch = metadata["general.architecture"] + assert arch == "llama", f"Only general.architecture=llama is supported, but got general.architecture={arch}" return GGUFModelArgs( arch=arch, embedding_length=metadata[f"{arch}.embedding_length"], @@ -147,16 +149,15 @@ def _fqn_last(fqn: str) -> str: def _load_by_state_dict(pt_model: torch.nn.Module, state_dict: Dict[str, Any], fqn: str, gguf_tensor: ReaderTensor) -> bool: - assert fqn in state_dict - if gguf.tensor_type in (GGMLQuantizationType.F32, GGMLQuantizationType.F16): - reversed_shape = tensor.shape[::-1] - new_tensor = tensor.data.reshape(reversed_shape) + if gguf_tensor.tensor_type in (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16): + reversed_shape = gguf_tensor.shape[::-1] + new_tensor = gguf_tensor.data.reshape(reversed_shape) state_dict[fqn] = torch.from_numpy(new_tensor) return True - elif tensor.tensor_type == GGMLQuantizationType.Q4_0 and tensor.name == "token_embd.weight": - unpacked = Q4_0.to_float(torch.from_numpy(tensor.data.reshape(-1, 18))) + elif gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and gguf_tensor.name == "token_embd.weight": + unpacked = Q4_0.to_float(torch.from_numpy(gguf_tensor.data.reshape(-1, 18))) state_dict[fqn] = unpacked.reshape( - pt_model.params.vocab_size, pt_model.params.dim + pt_model.config.vocab_size, pt_model.config.dim ) return True return False @@ -166,9 +167,9 @@ def _load_by_parameter(pt_model: torch.nn.Module, fqn: str, gguf_tensor: ReaderT assert isinstance(_fqn_lookup(fqn, pt_model), torch.nn.Parameter) parent: torch.nn.Module = _fqn_lookup(_fqn_up(fqn), pt_model) - if tensor.tensor_type == GGMLQuantizationType.Q4_0 and isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight": - print(fqn, tensor.shape, tensor.data.shape, parent.weight.shape) - packed = torch.from_numpy(tensor.data).reshape(-1, 18) + if gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 and isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight": + print(fqn, gguf_tensor.shape, gguf_tensor.data.shape, parent.weight.shape) + packed = torch.from_numpy(gguf_tensor.data).reshape(-1, 18) scale = torch.tensor(Q4_0._unpack_two_uint8(packed[:, :2]), dtype=torch.float16) parent.weight = torch.nn.Parameter( Q4_0.GGMLInt4LinearWeight(packed, scale, parent.weight.shape) @@ -179,12 +180,13 @@ def _load_by_parameter(pt_model: torch.nn.Module, fqn: str, gguf_tensor: ReaderT def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor]) -> None: - loaded_by_state_dict: Set[str] = {} - loaded_by_parameter: Set[str] = {} + loaded_by_state_dict: Set[str] = set() + loaded_by_parameter: Set[str] = set() # state_dict pass + logger.info("Loading weights by state_dict.") state_dict = {} - for fqn, _ model.state_dict(): + for fqn in pt_model.state_dict(): if fqn not in weight_map: continue tensor = weight_map[fqn] @@ -196,6 +198,7 @@ def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor] pt_model.load_state_dict(state_dict, strict=False) # parameter pass + logger.info("Loading weights by parameter.") for fqn, param in pt_model.named_parameters(): if fqn not in weight_map: continue @@ -206,13 +209,20 @@ def _load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor] # Sanity checks for fqn in loaded_by_state_dict: - assert fqn not in loaded_by_parameter, f"{fqn} was loaded by both state_dict and parameter" + # assert fqn not in loaded_by_parameter, f"{fqn} was loaded by both state_dict and parameter" + if not(fqn not in loaded_by_parameter): + print(f"{fqn} was loaded by both state_dict and parameter") + for fqn in weight_map: - assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in weight_map was not loaded" + # assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in weight_map was not loaded" + if not (fqn in (loaded_by_state_dict | loaded_by_parameter)): + print(f"{fqn} in weight_map was not loaded") - for fqn, _ model.state_dict(): - assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in model.state_dict() was not loaded" + for fqn in pt_model.state_dict(): + # assert fqn in (loaded_by_state_dict | loaded_by_parameter), f"{fqn} in model.state_dict() was not loaded" + if not (fqn in (loaded_by_state_dict | loaded_by_parameter)): + print(f"{fqn} in model.state_dict() was not loaded") def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]: @@ -244,22 +254,28 @@ def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module: if not Path(gguf_file).is_file(): raise ValueError(f"Could not find file {gguf_file}") + logger.info("Parsing GGUF metadata.") reader = gguf.GGUFReader(gguf_file, "r") - metadata = _get_metadata(reader) model_args = _build_model_args(metadata) assert ( - gguf_model_args.arch == "llama" + model_args.arch == "llama" ), "Only LLaMa models are supported by this converter." + logger.info("Creating initial PT model.") + pt_model = _create_pt_model(model_args) + + + logger.info("Reading GGUF weights.") gguf_weights = GGUFWeights(tensors=reader.tensors) - pt_model = _create_pt_model(gguf_model_args) + logger.info("Building GGUF weight map.") # map from fqn in pt_model to gguf tensor weight_map = { _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor for tensor in gguf_weights.tensors } + logger.info("Loading GGUF weights into PT model.") _load_weights(pt_model, weight_map) - return pt_model + return pt_model, weight_map