diff --git a/build/builder.py b/build/builder.py
index 8c156a911..674a58afc 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -154,21 +154,13 @@ def device_sync(device):
 sys.path.append(str(wd))
 
 
-def _load_model(builder_args):
-    if builder_args.gguf_path:
-        model = Transformer.from_gguf(builder_args.gguf_path)
-
-        # TODO: to take advantage of mmap, maybe we write converted gguf to file
-        # and read back in?
-        # TODO: should we add check that builder_args.precision is aligned with quant scheme, e.g., bfloat16
-        # is needed for int4
-        model = model.to(device=builder_args.device, dtype=builder_args.precision)
-        return model.eval()
-    else:
-        return _load_model_not_gguf(builder_args)
+def _load_model_gguf(builder_args):
+    assert builder_args.gguf_path
+    model = Transformer.from_gguf(builder_args.gguf_path)
+    return model
 
 
-def _load_model_not_gguf(builder_args):
+def _load_model_default(builder_args):
     assert not builder_args.gguf_path
 
     with torch.device("meta"):
@@ -218,9 +210,17 @@ def _load_model_not_gguf(builder_args):
 
     model.load_state_dict(checkpoint, assign=True, strict=False)
 
+    return model
+
+
+def _load_model(builder_args):
+    if builder_args.gguf_path:
+        model = _load_model_gguf(builder_args)
+    else:
+        model = _load_model_default(builder_args)
+
     if builder_args.use_tp:
         from tp import apply_tp
-
         print("Applying tensor parallel to model ...")
         apply_tp(model)
 
diff --git a/build/gguf_loader.py b/build/gguf_loader.py
index 033cec212..f98e326da 100644
--- a/build/gguf_loader.py
+++ b/build/gguf_loader.py
@@ -17,6 +17,9 @@
 import torch
 import torch.nn as nn
 
+wd = Path(__file__).parent.resolve()
+sys.path.append(str(wd))
+
 from gguf import GGUFValueType, ReaderTensor
 from quantize import (
     group_dequantize_tensor_from_qparams,
@@ -24,62 +27,12 @@
     WeightOnlyInt4Linear,
 )
 
-from build.gguf_util import F16, F32, Q4_0, Q6_K
-
-wd = Path(__file__).parent.resolve()
-sys.path.append(str(wd))
-
+from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
 from model import ModelArgs, Transformer
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-@dataclass
-class AttentionArgs:
-    head_count: int
-    head_count_kv: int
-    layer_norm_rms_epsilon: float
-
-
-@dataclass
-class RopeArgs:
-    dimension_count: int | None = None
-    freq_base: float | None = None
-
-
-@dataclass
-class GGUFModelArgs:
-    arch: str
-    embedding_length: int
-    block_count: int
-    feed_forward_length: int
-    vocab_size: int
-    attention: AttentionArgs
-    rope: RopeArgs
-
-
-@dataclass
-class GGUFWeights:
-    tensors: list[ReaderTensor]
-
-
-def _create_pt_model(
-    gguf_model_args: GGUFModelArgs,
-) -> nn.Module:
-    llama_model_args = ModelArgs(
-        dim=gguf_model_args.embedding_length,
-        n_layers=gguf_model_args.block_count,
-        n_heads=gguf_model_args.attention.head_count,
-        n_local_heads=gguf_model_args.attention.head_count_kv,
-        vocab_size=gguf_model_args.vocab_size,
-        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
-        hidden_dim=gguf_model_args.feed_forward_length,
-    )
-    pt_model = Transformer(llama_model_args)
-    pt_model.eval()
-    return pt_model
-
-
 _name_replacements = [
     ("blk", "layers"),
     ("token_embd", "tok_embeddings"),
@@ -102,29 +55,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
     return result
 
 
-def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
-    arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), f"Only general.architecture=llama is supported, but got general.architecture={arch}"
-    return GGUFModelArgs(
-        arch=arch,
-        embedding_length=metadata[f"{arch}.embedding_length"],
-        block_count=metadata[f"{arch}.block_count"],
-        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
-        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
-        attention=AttentionArgs(
-            head_count=metadata[f"{arch}.attention.head_count"],
-            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
-            layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
-        ),
-        rope=RopeArgs(
-            freq_base=metadata.get(f"{arch}.rope.freq_base", None),
-            dimension_count=metadata.get(f"{arch}.rope.dimension_count", None),
-        ),
-    )
-
-
 def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
     if fqn == "":
         return module
@@ -153,74 +83,6 @@ def _fqn_last(fqn: str) -> str:
     return atoms[-1]
 
 
-def load_weights(
-    pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles=8
-) -> None:
-    fqns = []
-    for fqn in pt_model.state_dict():
-        assert _fqn_last(fqn) == "weight"
-        fqns.append(_fqn_up(fqn))
-
-    state_dict = {}
-    for fqn in fqns:
-        mod = _fqn_lookup(fqn, pt_model)
-
-        t = weight_map[f"{fqn}.weight"]
-
-        if (
-            isinstance(mod, torch.nn.Linear)
-            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
-        ):
-            assert not mod.bias
-            out_features = mod.out_features
-            in_features = mod.in_features
-            assert all(t.shape == (in_features, out_features))
-
-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
-
-            parent = _fqn_lookup(_fqn_up(fqn), pt_model)
-            setattr(
-                parent,
-                _fqn_last(fqn),
-                WeightOnlyInt4Linear(
-                    "cpu",  # TODO: should --device work for gguf load? (yes?!)
-                    in_features,
-                    out_features,
-                    bias=False,
-                    groupsize=Q4_0.groupsize,
-                    inner_k_tiles=inner_k_tiles,
-                ),
-            )
-        else:
-            # All other weights are dequantized to float
-            if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.F16:
-                as_float = F16.unpack(t)
-            elif t.tensor_type == gguf.GGMLQuantizationType.F32:
-                as_float = F32.unpack(t)
-            else:
-                raise ValueError(f"Unsupported tensor type {t.tensor_type}")
-
-            state_dict[f"{fqn}.weight"] = as_float.to("cpu")
-
-    pt_model.load_state_dict(state_dict)
-    return pt_model
-
-
 def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     metadata: dict[str, Any] = {}
 
@@ -244,34 +106,105 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     return metadata
 
 
-def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+def load_model(gguf_file: str) -> torch.nn.Module:
     """
-    Load a LLaMa model from a GGUF file and return a PT nn.Module.
+    Parses the GGUF file and returns an nn.Module on meta device.
     """
-    if not Path(gguf_file).is_file():
-        raise ValueError(f"Could not find file {gguf_file}")
 
     logger.info("Parsing GGUF metadata.")
     reader = gguf.GGUFReader(gguf_file, "r")
     metadata = _get_metadata(reader)
-    model_args = _build_model_args(metadata)
+
+    arch = metadata["general.architecture"]
     assert (
-        model_args.arch == "llama"
+        arch == "llama"
     ), "Only LLaMa models are supported by this converter."
 
-    logger.info("Creating initial PT model.")
-    pt_model = _create_pt_model(model_args)
+    model_args = ModelArgs(
+        dim=metadata[f"{arch}.embedding_length"],
+        n_layers=metadata[f"{arch}.block_count"],
+        n_heads=metadata[f"{arch}.attention.head_count"],
+        n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
+        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
+        norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+        hidden_dim=metadata[f"{arch}.feed_forward_length"],
+    )
 
-    logger.info("Reading GGUF weights.")
-    gguf_weights = GGUFWeights(tensors=reader.tensors)
+    # TODO: what to do with rope args like
+    # metadata.get(f"{arch}.rope.freq_base", None)
+    # metadata.get(f"{arch}.rope.dimension_count", None)
 
-    logger.info("Building GGUF weight map.")
-    # map from fqn in pt_model to gguf tensor
+    with torch.device("meta"):
+        model = Transformer(model_args)
+    return model
+
+
+def load_model_and_state_dict(
+    gguf_file: str, load_as_quantized: bool, *, inner_k_tiles=8
+) -> tuple[torch.nn.Module, dict[str, Any]]:
+    """
+    Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
+    that can be loaded into it.
+
+    When load_as_quantized, the method tries to preserve the GGUF quantization when it
+    is natively supported by PyTorch, otherwise it converts quantized tensors to FP32.
+    """
+
+    model = load_model(gguf_file)
+
+    reader = gguf.GGUFReader(gguf_file, "r")
     weight_map = {
         _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
-        for tensor in gguf_weights.tensors
+        for tensor in reader.tensors
     }
 
-    logger.info("Loading weights into state_dict")
-    pt_model = load_weights(pt_model, weight_map, inner_k_tiles=8)
-    return pt_model
+    state_dict = {}
+    for fqn in weight_map:
+        assert _fqn_last(fqn) == "weight"
+        fqn = _fqn_up(fqn)
+
+        mod = _fqn_lookup(fqn, model)
+        t = weight_map[f"{fqn}.weight"]
+
+        if (
+            isinstance(mod, torch.nn.Linear)
+            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
+            and load_as_quantized
+        ):
+            assert not mod.bias
+            out_features = mod.out_features
+            in_features = mod.in_features
+            assert all(t.shape == (in_features, out_features))
+
+            q, s, z = Q4_0.unpack(t)
+            scales_and_zeros = pack_scales_and_zeros(s, z)
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                q, inner_k_tiles
+            )
+
+            state_dict[f"{fqn}.weight"] = weight_int4pack
+            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+
+            parent = _fqn_lookup(_fqn_up(fqn), model)
+            setattr(
+                parent,
+                _fqn_last(fqn),
+                WeightOnlyInt4Linear(
+                    "meta",
+                    in_features,
+                    out_features,
+                    bias=False,
+                    groupsize=Q4_0.groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                ),
+            )
+        else:
+            state_dict[f"{fqn}.weight"] = to_float(t)
+
+    return model, state_dict
+
+
+def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+    model, state_dict = load_model_and_state_dict(gguf_file, load_as_quantized=True)
+    model.load_state_dict(state_dict, assign=True)
+    return model
diff --git a/build/model.py b/build/model.py
index 4786434f8..c3f0af512 100644
--- a/build/model.py
+++ b/build/model.py
@@ -247,9 +247,9 @@ def from_params(cls, params_path: str):
 
     @classmethod
     def from_gguf(cls, gguf_path: str):
-        from build.gguf_loader import load_llama_from_gguf_file
-
-        model = load_llama_from_gguf_file(gguf_path)
+        from build.gguf_loader import load_model_and_state_dict
+        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
+        model.load_state_dict(state_dict, assign=True)
         return model
 
 