diff --git a/cli.py b/cli.py index b14f0a944..e374618fb 100644 --- a/cli.py +++ b/cli.py @@ -15,22 +15,23 @@ strict = False + def check_args(args, command_name: str): global strict # chat and generate support the same options - if command_name in ["generate", "chat", "gui"]: + if command_name in ["generate", "chat", "gui"]: # examples, can add more. Note that attributes convert dash to _ - disallowed_args = ['output_pte_path', 'output_dso_path' ] + disallowed_args = ["output_pte_path", "output_dso_path"] elif command_name == "export": # examples, can add more. Note that attributes convert dash to _ - disallowed_args = ['pte_path', 'dso_path' ] + disallowed_args = ["pte_path", "dso_path"] elif command_name == "eval": # TBD disallowed_args = [] else: raise RuntimeError(f"{command_name} is not a valid command") - + for disallowed in disallowed_args: if hasattr(args, disallowed): text = f"command {command_name} does not support option {disallowed.replace('_', '-')}" @@ -39,7 +40,7 @@ def check_args(args, command_name: str): else: print(f"Warning: {text}") - + def cli_args(): import argparse @@ -48,8 +49,8 @@ def cli_args(): parser.add_argument( "--seed", type=int, - default=1234, # set None for release - help="Initialize torch seed" + default=1234, # set None for release + help="Initialize torch seed", ) parser.add_argument( "--prompt", type=str, default="Hello, my name is", help="Input prompt." @@ -78,55 +79,31 @@ def cli_args(): "--chat", action="store_true", help="Use torchat to for an interactive chat session.", - ) + ) parser.add_argument( "--gui", action="store_true", help="Use torchat to for an interactive gui-chat session.", - ) - parser.add_argument( - "--num-samples", - type=int, - default=1, - help="Number of samples.") - parser.add_argument( - "--max-new-tokens", - type=int, - default=200, - help="Maximum number of new tokens." ) + parser.add_argument("--num-samples", type=int, default=1, help="Number of samples.") parser.add_argument( - "--top-k", - type=int, - default=200, - help="Top-k for sampling.") + "--max-new-tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument("--top-k", type=int, default=200, help="Top-k for sampling.") parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Temperature for sampling." + "--temperature", type=float, default=0.8, help="Temperature for sampling." ) parser.add_argument( - "--compile", - action="store_true", - help="Whether to compile the model." + "--compile", action="store_true", help="Whether to compile the model." ) parser.add_argument( "--compile-prefill", action="store_true", help="Whether to compile the prefill (improves prefill perf, but higher compile times)", ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") parser.add_argument( - "--profile", - type=Path, - default=None, - help="Profile path." - ) - parser.add_argument( - "--speculate-k", - type=int, - default=5, - help="Speculative execution depth." + "--speculate-k", type=int, default=5, help="Speculative execution depth." 
) parser.add_argument( "--draft-checkpoint-path", @@ -163,31 +140,18 @@ def cli_args(): type=Path, default=None, help="Model checkpoint path.", - ) - parser.add_argument( - "--output-pte-path", - type=str, - default=None, - help="Filename" ) + parser.add_argument("--output-pte-path", type=str, default=None, help="Filename") + parser.add_argument("--output-dso-path", type=str, default=None, help="Filename") parser.add_argument( - "--output-dso-path", - type=str, - default=None, - help="Filename" - ) - parser.add_argument( - "--dso-path", - type=Path, - default=None, - help="Use the specified AOTI DSO model." + "--dso-path", type=Path, default=None, help="Use the specified AOTI DSO model." ) parser.add_argument( "--pte-path", type=Path, default=None, - help="Use the specified Executorch PTE model." - ) + help="Use the specified Executorch PTE model.", + ) parser.add_argument( "-d", "--dtype", @@ -196,48 +160,36 @@ def cli_args(): ) parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( - "--quantize", - type=str, - default="{ }", - help="Quantization options." + "--quantize", type=str, default="{ }", help="Quantization options." ) parser.add_argument( - "--device", - type=str, - default=default_device, - help="Device to use" - ) - parser.add_argument( - "--params-table", - type=str, - default=None, - help="Device to use" + "--device", type=str, default=default_device, help="Device to use" ) + parser.add_argument("--params-table", type=str, default=None, help="Device to use") parser.add_argument( - '--tasks', - nargs='+', + "--tasks", + nargs="+", type=str, default=["hellaswag"], - help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2' + help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( - '--limit', type=int, - default=None, - help='number of samples to evaluate' + "--limit", type=int, default=None, help="number of samples to evaluate" ) parser.add_argument( - '--max-seq-length', + "--max-seq-length", type=int, default=None, - help='maximum length sequence to evaluate') - + help="maximum length sequence to evaluate", + ) + args = parser.parse_args() - if (Path(args.quantize).is_file()): + if Path(args.quantize).is_file(): with open(args.quantize, "r") as f: args.quantize = json.loads(f.read()) if args.seed: - torch.manual_seed(args.seed) + torch.manual_seed(args.seed) return args diff --git a/eval.py b/eval.py index 8ac8c457f..82bdba06c 100644 --- a/eval.py +++ b/eval.py @@ -25,25 +25,32 @@ try: import lm_eval + lm_eval_available = True except: lm_eval_available = False -from build.builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs +from build.builder import ( + _initialize_model, + _initialize_tokenizer, + BuilderArgs, + TokenizerArgs, +) from generate import encode_tokens, model_forward if lm_eval_available: - try: # lm_eval version 0.4 + try: # lm_eval version 0.4 from lm_eval.models.huggingface import HFLM as eval_wrapper from lm_eval.tasks import get_task_dict from lm_eval.evaluator import evaluate - except: #lm_eval version 0.3 + except: # lm_eval version 0.3 from lm_eval import base from lm_eval import tasks from lm_eval import evaluator - eval_wrapper=base.BaseLM - get_task_dict=tasks.get_task_dict - evaluate=evaluator.evaluate + + eval_wrapper = base.BaseLM + get_task_dict = tasks.get_task_dict + evaluate = evaluator.evaluate def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( @@ -84,20 +91,22 @@ def 
setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( return seq, input_pos, max_seq_length + class GPTFastEvalWrapper(eval_wrapper): """ A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. """ + def __init__( self, model: Transformer, tokenizer, - max_seq_length: Optional[int]=None, + max_seq_length: Optional[int] = None, ): super().__init__() self._model = model self._tokenizer = tokenizer - self._device = torch.device('cuda') + self._device = torch.device("cuda") self._max_seq_length = 2048 if max_seq_length is None else max_seq_length @property @@ -121,8 +130,7 @@ def device(self): return self._device def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens(self._tokenizer, - string, bos=True, device=self._device) + encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device) # encoded is a pytorch tensor, but some internal logic in the # eval harness expects it to be a list instead # TODO: verify this for multi-batch as well @@ -138,19 +146,20 @@ def _model_call(self, inps): inps = inps.squeeze(0) max_new_tokens = 1 - seq, input_pos, max_seq_length = \ + seq, input_pos, max_seq_length = ( setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( self._model, inps, max_new_tokens, self.max_length, ) + ) x = seq.index_select(0, input_pos).view(1, -1) logits = model_forward(self._model, x, input_pos) return logits def _model_generate(self, context, max_length, eos_token_id): - raise Exception('unimplemented') + raise Exception("unimplemented") @torch.no_grad() @@ -185,8 +194,8 @@ def eval( except: pass - if 'hendrycks_test' in tasks: - tasks.remove('hendrycks_test') + if "hendrycks_test" in tasks: + tasks.remove("hendrycks_test") tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] task_dict = get_task_dict(tasks) @@ -212,7 +221,7 @@ def main(args) -> None: builder_args = BuilderArgs.from_args(args) tokenizer_args = TokenizerArgs.from_args(args) - + checkpoint_path = args.checkpoint_path checkpoint_dir = args.checkpoint_dir params_path = args.params_path @@ -223,12 +232,12 @@ def main(args) -> None: pte_path = args.pte_path quantize = args.quantize device = args.device - model_dtype = args.dtype + model_dtype = args.dtype tasks = args.tasks limit = args.limit max_seq_length = args.max_seq_length use_tiktoken = args.tiktoken - + print(f"Using device={device}") set_precision(buildeer_args.precision) @@ -240,9 +249,13 @@ def main(args) -> None: ) if compile: - assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model" + assert not ( + builder_args.dso_path or builder_args.pte_path + ), "cannot compile exported model" global model_forward - model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) + model_forward = torch.compile( + model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True + ) torch._inductor.config.coordinate_descent_tuning = True t1 = time.time() @@ -268,11 +281,13 @@ def main(args) -> None: for task, res in result["results"].items(): print(f"{task}: {res}") -if __name__ == '__main__': + +if __name__ == "__main__": + def cli(): args = cli_args() main(args) if __name__ == "__main__": - cli() + cli() diff --git a/export.py b/export.py index 1e5fb5d37..812cf1fd0 100644 --- a/export.py +++ b/export.py @@ -42,7 +42,6 @@ def device_sync(device): print(f"device={device} is not yet suppported") - def main(args): builder_args = BuilderArgs.from_args(args) tokenizer_args = 
TokenizerArgs.from_args(args) @@ -70,7 +69,9 @@ def main(args): print(f"Exporting model using Executorch to {output_pte_path}") export_model_et(model, builder_args.device, args.output_pte_path, args) else: - print(f"Export with executorch requested but Executorch could not be loaded") + print( + f"Export with executorch requested but Executorch could not be loaded" + ) print(executorch_exception) if output_dso_path: output_dso_path = str(os.path.abspath(output_dso_path)) @@ -82,5 +83,6 @@ def cli(): args = cli_args() main(args) + if __name__ == "__main__": cli() diff --git a/export_aoti.py b/export_aoti.py index 6501b9e98..93c9cf504 100644 --- a/export_aoti.py +++ b/export_aoti.py @@ -33,11 +33,11 @@ def device_sync(device): def export_model(model: nn.Module, device, output_path, args=None): max_seq_length = 350 -# with torch.device(device): -# model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + # with torch.device(device): + # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) input = ( - torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), + torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) diff --git a/export_et.py b/export_et.py index 030fd0b6c..59c413400 100644 --- a/export_et.py +++ b/export_et.py @@ -14,7 +14,10 @@ from generate import decode_one_token from quantize import ( - quantize_model, name_to_dtype, set_precision, get_precision, + quantize_model, + name_to_dtype, + set_precision, + get_precision, ) from build.model import Transformer from build.model import Transformer @@ -22,10 +25,12 @@ from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackPartitioner, ) + # from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( # XnnpackDynamicallyQuantizedPartitioner, -#) +# ) from executorch_portable_utils import export_to_edge + # TODO: change back to executorch.examples.portable.utils # when executorch installs correctly @@ -99,7 +104,7 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901 _skip_type_promotion=bool(target_precision == torch.float16), ) - if target_precision == torch.float16: # or args.quantization_mode=="int4": + if target_precision == torch.float16: # or args.quantization_mode=="int4": if state_dict_dtype != torch.float16: print("model.to torch.float16") model = model.to(dtype=torch.float16) @@ -111,11 +116,11 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901 else: raise ValueError(f"Unsupported dtype for ET export: {target_precision}") - with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]), torch.no_grad(): + with torch.nn.attention.sdpa_kernel( + [torch.nn.attention.SDPBackend.MATH] + ), torch.no_grad(): m = capture_pre_autograd_graph( - export_model, - input, - dynamic_shapes=dynamic_shapes + export_model, input, dynamic_shapes=dynamic_shapes ) edge_manager = export_to_edge( diff --git a/generate.py b/generate.py index ad6085582..1703126fc 100644 --- a/generate.py +++ b/generate.py @@ -15,37 +15,44 @@ import torch._dynamo.config import torch._inductor.config -from build.builder import _load_model, _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs +from build.builder import ( + _load_model, + _initialize_model, + _initialize_tokenizer, + BuilderArgs, + TokenizerArgs, +) from build.model import Transformer from quantize import quantize_model, name_to_dtype, set_precision, get_precision 
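Aside on the export_et.py hunk above: it wraps graph capture in torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) so tracing sees the reference math attention kernel rather than a fused one. A minimal, self-contained sketch of that pattern, assuming a recent PyTorch (>= 2.3) where torch.nn.attention exists; the tensor shapes below are made up purely for illustration:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Toy (batch, heads, seq_len, head_dim) tensors, illustrative only.
q = torch.randn(1, 4, 16, 8)
k = torch.randn(1, 4, 16, 8)
v = torch.randn(1, 4, 16, 8)

# Pin scaled-dot-product attention to the math backend while running (or
# tracing) the model, mirroring the context manager used above before
# capture_pre_autograd_graph in export_et.py.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 16, 8])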
from cli import cli_args + @dataclass class GeneratorArgs: prompt: str = "torchat is pronounced torch-chat and is so cool because" - chat: bool = False, - gui: bool = False, - num_samples: int =1, - max_new_tokens: int = 200, - top_k: int = 200, - temperature: int = 0, # deterministic argmax - compile: bool = False, - compile_prefill: bool = False, - speculate_k: int = 5, + chat: bool = (False,) + gui: bool = (False,) + num_samples: int = (1,) + max_new_tokens: int = (200,) + top_k: int = (200,) + temperature: int = (0,) # deterministic argmax + compile: bool = (False,) + compile_prefill: bool = (False,) + speculate_k: int = (5,) @classmethod - def from_args(cls, args): # -> GeneratorArgs: + def from_args(cls, args): # -> GeneratorArgs: return cls( - prompt = args.prompt, - chat = args.chat, - gui = args.gui, - num_samples = args.num_samples, - max_new_tokens = args.max_new_tokens, - top_k = args.top_k, - temperature = args.temperature, - compile = args.compile, - compile_prefill = args.compile_prefill, - speculate_k = args.speculate_k, + prompt=args.prompt, + chat=args.chat, + gui=args.gui, + num_samples=args.num_samples, + max_new_tokens=args.max_new_tokens, + top_k=args.top_k, + temperature=args.temperature, + compile=args.compile, + compile_prefill=args.compile_prefill, + speculate_k=args.speculate_k, ) @@ -152,6 +159,7 @@ def decode_n_tokens( # except: # print("compiled model load not successful, running eager model") + def model_forward(model, x, input_pos): return model(x, input_pos) @@ -336,20 +344,20 @@ def _main( is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path)) if is_chat: - raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!") + raise RuntimeError( + "need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!" + ) tokenizer = _initialize_tokenizer(tokenizer_args) builder_args.setup_caches = False - model = _initialize_model( - builder_args, - quantize - ) + model = _initialize_model(builder_args, quantize) # will add a version of _initialize_model in future # (need additional args) if is_speculative: from builder import _load_model + speculative_builder_args = builder_args draft_model = _load_model( @@ -359,7 +367,7 @@ def _main( draft_model = None encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device) - print (encoded) + print(encoded) prompt_length = encoded.size(0) model_size = sum( @@ -369,7 +377,9 @@ def _main( ] ) if compile: - if is_speculative and builder_args.use_tp: # and ("cuda" in builder_args.device): + if ( + is_speculative and builder_args.use_tp + ): # and ("cuda" in builder_args.device): torch._inductor.config.triton.cudagraph_trees = ( False # Bug with cudagraph trees in this case ) @@ -401,7 +411,9 @@ def _main( prompt = input("What is your prompt? 
") if is_chat: prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device) + encoded = encode_tokens( + tokenizer, prompt, bos=True, device=builder_args.device + ) if chat_mode and i >= 0: buffer = [] @@ -503,10 +515,11 @@ def main(args): args.quantize, ) + def cli(): args = cli_args() main(args) if __name__ == "__main__": - cli() + cli() diff --git a/quantize.py b/quantize.py index 4890c4b42..45b8dec17 100644 --- a/quantize.py +++ b/quantize.py @@ -15,7 +15,6 @@ import quantized_ops - try: from GPTQ import GenericGPTQRunner, InputRecorder from eval import get_task_dict, evaluate, lm_eval @@ -27,34 +26,39 @@ precision = torch.float + def set_precision(dtype): global precision precision = dtype + def get_precision(): global precision return precision + def name_to_dtype(name): if name in name_to_dtype_dict: return name_to_dtype_dict[name] else: raise RuntimeError(f"unsupported dtype name {name} specified") + name_to_dtype_dict = { - "fp32" : torch.float, - "fp16" : torch.float16, - "bf16" : torch.bfloat16, - "float" : torch.float, - "half" : torch.float16, - "float32" : torch.float, - "float16" : torch.float16, - "bfloat16" : torch.bfloat16, + "fp32": torch.float, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "float": torch.float, + "half": torch.float16, + "float32": torch.float, + "float16": torch.float16, + "bfloat16": torch.bfloat16, } ########################################################################## ### process quantization dictionary ### + def quantize_model(model: nn.Module, device, quantize_options): """ Quantize the specified model using the quantizers described by @@ -73,46 +77,34 @@ def quantize_model(model: nn.Module, device, quantize_options): for quantizer, q_kwargs in quantize_options.items(): if quantizer == "embedding": model = EmbeddingOnlyInt8QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif linears_quantized: - assert 0==1, "can only specify one linear quantizer" + assert 0 == 1, "can only specify one linear quantizer" elif quantizer == "linear:int8": linears_quantized = True model = WeightOnlyInt8QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:int4": linears_quantized = True model = WeightOnlyInt4QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:a8w4dq": linears_quantized = True model = Int8DynActInt4WeightQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:gptq": linears_quantized = True model = WeightOnlyInt4GPTQQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:hqq": linears_quantized = True model = WeightOnlyInt4HqqQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "precision": model.to(**q_kwargs) @@ -123,6 +115,7 @@ def quantize_model(model: nn.Module, device, quantize_options): ######################################################################### ##### Quantization Primitives ###### + def dynamically_quantize_per_channel( x, quant_min, @@ -217,8 +210,7 @@ def dynamically_quantize_per_channel( return quant, scales, zero_points - -def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype= torch.float): +def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float): # 
needed for GPTQ with padding if groupsize > w.shape[-1]: groupsize = w.shape[-1] @@ -324,6 +316,7 @@ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): w_int32, scales, zeros, n_bit, groupsize ) + ######################################################################### ### QuantHandler API definition ### @@ -349,9 +342,11 @@ def quantized_model(self) -> nn.Module: ##### Weight-only int8 per-channel quantized code ###### -def replace_linear_weight_only_int8_per_channel(module, device, node_type, groupsize=None): +def replace_linear_weight_only_int8_per_channel( + module, device, node_type, groupsize=None +): if groupsize is not None and groupsize != 0: - pass # groupsize = 2 ** groupsize + pass # groupsize = 2 ** groupsize for name, child in module.named_children(): # print(f"name: {name}") @@ -367,10 +362,14 @@ def replace_linear_weight_only_int8_per_channel(module, device, node_type, group setattr( module, name, - WeightOnlyInt8Linear(device, child.in_features, child.out_features, groupsize), + WeightOnlyInt8Linear( + device, child.in_features, child.out_features, groupsize + ), ) else: - replace_linear_weight_only_int8_per_channel(child, device, node_type, groupsize) + replace_linear_weight_only_int8_per_channel( + child, device, node_type, groupsize + ) class WeightOnlyInt8QuantHandler(QuantHandler): @@ -443,7 +442,9 @@ def create_quantized_state_dict(self) -> Dict: return cur_state_dict def convert_for_runtime(self) -> nn.Module: - replace_linear_weight_only_int8_per_channel(self.mod, self.device, self.node_type, self.groupsize) + replace_linear_weight_only_int8_per_channel( + self.mod, self.device, self.node_type, self.groupsize + ) return self.mod def quantized_model(self) -> nn.Module: @@ -474,14 +475,19 @@ def __init__( self.in_features = in_features self.out_features = out_features self.register_buffer( - "weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device) + "weight", + torch.empty((out_features, in_features), dtype=torch.int8, device=device), ) - dtype=get_precision() + dtype = get_precision() if groupsize is None or (groupsize == 0): - self.register_buffer("scales", torch.ones(out_features, dtype=dtype, device=device)) + self.register_buffer( + "scales", torch.ones(out_features, dtype=dtype, device=device) + ) else: groups = (in_features + groupsize - 1) // groupsize - self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype, device=device)) + self.register_buffer( + "scales", torch.ones(out_features, groups, dtype=dtype, device=device) + ) def forward(self, input: torch.Tensor) -> torch.Tensor: scales = self.scales @@ -496,7 +502,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if scales.shape[1] == 1: return F.linear(input, weight.to(dtype=input.dtype)) * self.scales else: - return F.linear(input, (weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) * scales.view(weight.shape[0], no_groups, -1)).view(weight.shape[0], -1)) + return F.linear( + input, + ( + weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) + * scales.view(weight.shape[0], no_groups, -1) + ).view(weight.shape[0], -1), + ) ######################################################################### @@ -504,7 +516,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def replace_embedding_weight_only_grouped_int8_per_channel( - module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False + module, device, bitwidth: int = 8, groupsize: Optional[int] = None, 
packed=False ): for name, child in module.named_children(): # print(f"name: {name}") @@ -529,9 +541,17 @@ def replace_embedding_weight_only_grouped_int8_per_channel( class EmbeddingOnlyInt8QuantHandler(QuantHandler): - def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False): + def __init__( + self, + mod, + device, + *, + bitwidth: int = 8, + groupsize: Optional[int] = None, + packed=False, + ): if isinstance(packed, str): - packed = (packed == "True") + packed = packed == "True" self.mod = mod self.device = device self.groupsize = groupsize @@ -540,7 +560,6 @@ def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = if (bitwidth != 4) and packed: raise RuntimeError("pack only works with bitsize 4") - @torch.no_grad() def create_quantized_state_dict(self, packed=False) -> Dict: cur_state_dict = self.mod.state_dict() @@ -555,9 +574,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict: raise ValueError(f"Unsupported bitwidth {self.bitwidth}") for fqn, mod in self.mod.named_modules(): - if ( - isinstance(mod, nn.Embedding) - ): + if isinstance(mod, nn.Embedding): # print("****") # print(f"Embedding identified: {fqn, mod}") # print(f"weights size: {mod.weight.size()}") @@ -576,17 +593,15 @@ def create_quantized_state_dict(self, packed=False) -> Dict: ) if packed: - if weight.shape[-1] %2 != 0: + if weight.shape[-1] % 2 != 0: raise RuntimeError("automatic padding not implemented yet") weight_range_shifted = weight.add(8).view(torch.uint8) weight_view = weight_range_shifted.view( - weight.shape[0], - weight.shape[1] //2, - 2 - ) - weight_even = weight_view[:,:,0] * 16 # left shift 4 - weight_odd = weight_view[:,:,1] + weight.shape[0], weight.shape[1] // 2, 2 + ) + weight_even = weight_view[:, :, 0] * 16 # left shift 4 + weight_odd = weight_view[:, :, 1] weight_packed = weight_even + weight_odd weight = weight_packed @@ -630,16 +645,25 @@ def __init__( self.packed = packed if not packed: self.register_buffer( - "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device) + "weight", + torch.empty( + (vocab_size, embedding_dim), dtype=torch.int8, device=device + ), ) - else: # packed + else: # packed self.register_buffer( - "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8, device=device) + "weight", + torch.empty( + (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device + ), ) groups_per_row = (embedding_dim + groupsize - 1) // groupsize if groups_per_row > 1: self.register_buffer( - "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device) + "scales", + torch.ones( + (vocab_size, groups_per_row), dtype=torch.float16, device=device + ), ) else: self.register_buffer( @@ -648,17 +672,16 @@ def __init__( @torch.no_grad() def forward(self, indices: torch.Tensor) -> torch.Tensor: - if False: # Used for Executorch + if False: # Used for Executorch return torch.ops.llama_quantized.embedding_byte.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) - # result_weights = self.weight.index_select(0, indices.view(-1)) # result_scales = self.scales.index_select(0, indices.view(-1)) if self.packed: - weight_even = self.weight.div(16, rounding_mode='trunc') + weight_even = self.weight.div(16, rounding_mode="trunc") weight_odd = self.weight.remainder(16) weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1) weight = weight_unpacked.view(self.weight.shape[0], -1) @@ -671,8 +694,22 @@ def forward(self, indices: 
torch.Tensor) -> torch.Tensor: result_weights = F.embedding(indices, weight) result_scales = F.embedding(indices, scales) - rw_view = result_weights.to(dtype=result_scales.dtype).view(tuple(result_weights.shape[:-1] + (scales.shape[1], -1, ))) - rs_view = result_scales.view(tuple(result_scales.shape[:-1]) + (scales.shape[1], 1, )) + rw_view = result_weights.to(dtype=result_scales.dtype).view( + tuple( + result_weights.shape[:-1] + + ( + scales.shape[1], + -1, + ) + ) + ) + rs_view = result_scales.view( + tuple(result_scales.shape[:-1]) + + ( + scales.shape[1], + 1, + ) + ) # print(f"rw_view {rw_view.shape}") # print(f"rs_view {rs_view.shape}") @@ -685,17 +722,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: ######################################################################### ##### weight only int4 per channel groupwise quantized code ###### -def _int4_prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + +def _int4_prepare_int4_weight_and_scales_and_zeros( + weight_bf16, groupsize, inner_k_tiles +): weight_int32, scales_and_zeros = group_quantize_tensor( weight_bf16, n_bit=4, groupsize=groupsize ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) return weight_int4pack, scales_and_zeros + def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1): from build.model import find_multiple + return find_multiple(k, 1024) + def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) @@ -705,31 +750,41 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou if "mps" in str(x.device): new_shape = origin_x_size[:-1] + (out_features,) return torch.zeros(new_shape, dtype=x.dtype, device=x.device) - + c = torch.ops.aten._weight_int4pack_mm( - x.to(torch.bfloat16), # TODO: should probably make a warning if x is not already bfloat16 + x.to( + torch.bfloat16 + ), # TODO: should probably make a warning if x is not already bfloat16 weight_int4pack, groupsize, - scales_and_zeros.to(torch.bfloat16), # TODO: should probably make a warning if not already bfloat16 - ).to(x.dtype) # cast back to x.dtype + scales_and_zeros.to( + torch.bfloat16 + ), # TODO: should probably make a warning if not already bfloat16 + ).to( + x.dtype + ) # cast back to x.dtype new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c -def _int4_check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): +def _int4_check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + def replace_linear_int4( - module, - device, - groupsize, - inner_k_tiles, - padding_allowed, + module, + device, + groupsize, + inner_k_tiles, + padding_allowed, ): for name, child in module.named_children(): if isinstance(child, nn.Linear): - if _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + if ( + _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) + or padding_allowed + ): setattr( module, name, @@ -740,7 +795,8 @@ def replace_linear_int4( bias=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles, - )) + ), + ) else: replace_linear_int4( child, device, groupsize, inner_k_tiles, padding_allowed @@ -748,7 +804,9 @@ def replace_linear_int4( class WeightOnlyInt4QuantHandler(QuantHandler): - def __init__(self, mod, 
device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True): + def __init__( + self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True + ): self.mod = mod self.device = device self.groupsize = groupsize @@ -769,19 +827,30 @@ def create_quantized_state_dict(self): print(f"linear: {fqn}, in={in_features}, out={out_features}") weight = mod.weight.data - if not _int4_check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if not _int4_check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): if self.padding_allowed: from build.model import find_multiple import torch.nn.functional as F - print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) padded_in_features = find_multiple(in_features, 1024) - weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) else: - print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + - "and that groupsize and inner_k_tiles*16 evenly divide into it") + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) continue - weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros( - weight.to(torch.float), self.groupsize, self.inner_k_tiles + weight_int4pack, scales_and_zeros = ( + _int4_prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.float), self.groupsize, self.inner_k_tiles + ) ) weight_int4pack = weight_int4pack.to(device=self.device) scales_and_zeros = scales_and_zeros.to(device=self.device) @@ -790,9 +859,14 @@ def create_quantized_state_dict(self): return cur_state_dict - def convert_for_runtime(self): - replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed) + replace_linear_int4( + self.mod, + self.device, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) return self.mod def quantized_model(self) -> nn.Module: @@ -803,25 +877,28 @@ def quantized_model(self) -> nn.Module: class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor def __init__( - self, - device: str, - in_features: int, - out_features: int, - bias=True, - dtype=None, - groupsize: int = 128, - inner_k_tiles: int = 8, + self, + device: str, + in_features: int, + out_features: int, + bias=True, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, ) -> None: super().__init__() - self.padding = not _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles) + self.padding = not _int4_check_linear_int4_k( + in_features, groupsize, inner_k_tiles + ) if self.padding: from build.model import find_multiple + self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -832,14 +909,21 @@ def __init__( self.inner_k_tiles = inner_k_tiles assert out_features % 8 == 0, "require out_features % 8 == 0" - assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", torch.empty( - (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), + 
( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), dtype=torch.int32, device=device, - ) + ), ) # MKG: torch.float self.register_buffer( @@ -848,7 +932,7 @@ def __init__( (in_features // groupsize, out_features, 2), dtype=get_precision(), device=device, - ) + ), ) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -856,18 +940,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # input = input.to(torch.float) if self.padding: import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) return linear_forward_int4( - input, - self.weight, - self.scales_and_zeros, - self.out_features, - self.groupsize + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) + ######################################################################### ##### Int8 Dynamic Activations 4 Bit Weights ##### + def prepare_int4_weight_and_scales_and_zeros(weight, groupsize, precision): weight_int8, scales, zeros = group_quantize_tensor_symmetric( weight, @@ -924,6 +1007,7 @@ def find_multiple(n: int, *args: Tuple[int]) -> int: def _check_linear_int4_k(k, groupsize=1): return k % groupsize == 0 + def _calc_padded_size_linear_int4(k, groupsize=1): return find_multiple(k, groupsize) @@ -965,7 +1049,7 @@ def __init__( self, mod, device, - * , + *, groupsize=256, padding_allowed=False, precision=torch.float32, @@ -1130,6 +1214,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ######################################################################### ##### GPTQ ##### + class GPTQQuantHandler(QuantHandler): """ This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. @@ -1192,6 +1277,7 @@ class GPTQQuantHandler(QuantHandler): names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the corresponding quantized weights and qparams. 
""" + def __init__(self): assert self.mod is not None assert self.get_qparams_func is not None @@ -1201,7 +1287,14 @@ def __init__(self): assert self.make_names_and_values_dict_func is not None @staticmethod - def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + def get_inputs( + model, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "MultiInput": input_recorder = InputRecorder( model, tokenizer, @@ -1223,9 +1316,9 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati ) inputs = input_recorder.get_recorded_inputs() assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, "+ - f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ - f"{calibration_seq_length})" + f"No inputs were collected, use a task other than {calibration_tasks}, " + + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " + + f"{calibration_seq_length})" ) print(f"Obtained {len(inputs[0].values)} calibration samples") return inputs @@ -1242,7 +1335,14 @@ def create_quantized_state_dict( calibration_seq_length, pad_calibration_inputs, ) -> "StateDict": - inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + inputs = GPTQQuantHandler.get_inputs( + self.mod, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) print("Tracing model for GPTQ") GPTQ_runner = GenericGPTQRunner( self.mod, @@ -1256,7 +1356,7 @@ def create_quantized_state_dict( self.dequantize_func, self.combine_qparams_list_func, self.make_names_and_values_dict_func, - self.skip_layer_func + self.skip_layer_func, ) print("Applying GPTQ to weights") @@ -1270,40 +1370,52 @@ def convert_for_runtime(self) -> "nn.Module": class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True): from build.model import find_multiple + self.mod = mod self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.padding = padding self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) - self.quantize_func = lambda w, qparams: \ - group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) - self.dequantize_func = lambda q, qparams: \ - group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() - self.combine_qparams_list_func = lambda qparams_list: \ - [torch.cat(x, dim=1) for x in zip(*qparams_list)] + self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], 4, groupsize + ) + self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams( + q, qparams[0], qparams[1], 4, groupsize + ).float() + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] # skip unless padding=True or its correctly sized self.skip_layer_func = lambda linear_weight: not ( - _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) + or padding ) + # we need to do the padding here, both for q and the qparams if necessary def make_names_and_values_dict_func(q, qparams): k = q.shape[1] new_k = 
find_multiple(k, 1024) # how much we need to pad the weight delta_k = new_k - q.shape[1] - final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales_and_zeros = pack_scales_and_zeros(*qparams) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] - final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + final_s_and_z = F.pad( + scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1 + ) return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func super().__init__() - def convert_for_runtime(self): - replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding) + replace_linear_int4( + self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding + ) return self.mod def quantized_model(self) -> nn.Module: @@ -1313,7 +1425,6 @@ def quantized_model(self) -> nn.Module: return self.mod - # class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler): # def __init__( # self, @@ -1388,6 +1499,7 @@ def quantized_model(self) -> nn.Module: ################################################################## ### WIP: HQQ ### + class WeightOnlyInt4HqqQuantHandler: def __init__(self, mod, device, *, groupsize): self.mod = mod @@ -1397,7 +1509,6 @@ def __init__(self, mod, device, *, groupsize): def create_quantized_state_dict(self): from hqq.core.quantize import Quantizer # TODO maybe torchao - for m in self.mod.modules(): for name, child in m.named_children(): if isinstance(child, torch.nn.Linear): diff --git a/quantized_ops.py b/quantized_ops.py index 7ac39b85e..e01cdf3af 100644 --- a/quantized_ops.py +++ b/quantized_ops.py @@ -10,15 +10,13 @@ import torch.nn.functional as F from torch.library import impl, impl_abstract -torchat_lib = torch.library.Library( - "torchat", "DEF" -) +torchat_lib = torch.library.Library("torchat", "DEF") torchat_lib.define( - "embedding_int8(Tensor input, Tensor weight, " - "Tensor scales) -> Tensor", + "embedding_int8(Tensor input, Tensor weight, " "Tensor scales) -> Tensor", ) + @impl(torchat_lib, "embedding_int8", "CompositeExplicitAutograd") def embedding_int8( input: torch.Tensor, @@ -27,9 +25,7 @@ def embedding_int8( ) -> torch.Tensor: indices = input # embedding_byte_weight_checks(weight, weight_scales, weight_zero_points) - groupsize = weight.size(1) // ( - scales.size(1) if scales.dim() == 2 else 1 - ) + groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1) # ET definition if False: weight_zero_points = None @@ -45,68 +41,82 @@ def embedding_int8( ) return torch.ops.aten.embedding.default(weight, indices) - scales = scales.view(weight.shape[0], -1) + scales = scales.view(weight.shape[0], -1) result_weights = F.embedding(indices, weight) result_scales = F.embedding(indices, scales) - rw_view = result_weights.to(dtype=result_scales.dtype).view(tuple(result_weights.shape[:-1]) + (scales.shape[1], -1, )) - rs_view = result_scales.view(tuple(result_scales.shape[:-1]) + (scales.shape[1], 1, )) + rw_view = result_weights.to(dtype=result_scales.dtype).view( + tuple(result_weights.shape[:-1]) + + ( + scales.shape[1], + -1, + ) + ) + rs_view = result_scales.view( + tuple(result_scales.shape[:-1]) + + ( + scales.shape[1], + 1, + ) + ) # print(f"rw_view {rw_view.shape}") # print(f"rs_view {rs_view.shape}") r = 
rw_view * rs_view return r.view(indices.size() + (-1,)) - - + + torchat_lib.define( "linear_int8(Tensor input, Tensor weight, Tensor scales, " "Tensor bias = None) -> Tensor", ) + @impl(torchat_lib, "linear_int8", "CompositeExplicitAutograd") def linear_int8( - input: torch.Tensor, - weight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor] = None, + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert bias is None, "bias != None not implemented" - + scales = scales.view(scales.shape[0], -1) no_groups = scales.shape[1] # for now, we special-case channel-wise, because we know how to - # make that fast with Triton + # make that fast with Triton if scales.shape[1] == 1: return F.linear(input, weight.to(dtype=input.dtype)) * scales else: return F.linear( input, - (weight.to(dtype=input.dtype).view(weight.shape[0],no_groups, -1) - * scales.view(weight.shape[0], no_groups, -1) - ).view(weight.shape[0], -1) + ( + weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) + * scales.view(weight.shape[0], no_groups, -1) + ).view(weight.shape[0], -1), ) - torchat_lib.define( "linear_int4(Tensor input, Tensor weight, Tensor scales_and_zeros, " "Tensor bias=None, *, int groupsize, int origin_in_features, " "int int_features, int out_features, bool padding = True) -> Tensor", ) + @impl(torchat_lib, "linear_int4", "CompositeExplicitAutograd") def linear_int4( - input: torch.Tensor, - weight: torch.Tensor, - scales_and_zeros: torch.Tensor, - bias: torch.Tensor, - *, - groupsize: int, - origin_in_features: int, - in_features: int, - out_features: int, - padding: bool = True, + input: torch.Tensor, + weight: torch.Tensor, + scales_and_zeros: torch.Tensor, + bias: torch.Tensor, + *, + groupsize: int, + origin_in_features: int, + in_features: int, + out_features: int, + padding: bool = True, ) -> torch.Tensor: assert bias is None, "bias != None not implemented" @@ -116,7 +126,7 @@ def linear_int4( # the weight is in int4pack format # rename to remind ourselves of that weight_int4pack = weight - + origin_input_size = input.size() input = input.reshape(-1, origin_input_size[-1]) c = torch.ops.aten._weight_int4pack_mm( @@ -136,10 +146,9 @@ def linear_int4( "dtype precision) -> Tensor", ) + @impl(torchat_lib, "linear_a8w4dq", "CompositeExplicitAutograd") -def linear_a8w4dq( - input, weight, scales, zeros, out_features, groupsize, precision -): +def linear_a8w4dq(input, weight, scales, zeros, out_features, groupsize, precision): x = per_token_dynamic_quant(input) weight_int8 = weight # TODO: verify and remove following reshape code diff --git a/runner/run.cpp b/runner/run.cpp index 27c16cc95..32233cba2 100644 --- a/runner/run.cpp +++ b/runner/run.cpp @@ -13,7 +13,7 @@ #include #endif -#if defined(__AOTI_MODEL__) || (defined (__ET_MODEL__) && defined(USE_ATENLIB)) +#if defined(__AOTI_MODEL__) || (defined(__ET_MODEL__) && defined(USE_ATENLIB)) #include #endif @@ -27,67 +27,81 @@ #include #include -using torch::executor::Module; -using torch::executor::ManagedTensor; -using torch::executor::EValue; using exec_aten::ScalarType; +using torch::executor::EValue; +using torch::executor::ManagedTensor; +using torch::executor::Module; using torch::executor::Result; #endif - // ---------------------------------------------------------------------------- // Transformer model -typedef struct { +typedef struct +{ int vocab_size; // vocabulary size, usually 256 (byte-level) - int seq_len; // max sequence length + int 
seq_len; // max sequence length } Config; -typedef struct { +typedef struct +{ float *logits; // output logits - int64_t* toks; // tokens seen so far; no kv-cache :( + int64_t *toks; // tokens seen so far; no kv-cache :( } RunState; -typedef struct { - Config config; // the hyperparameters of the architecture (the blueprint) +typedef struct +{ + Config config; // the hyperparameters of the architecture (the blueprint) RunState state; // buffers for the "wave" of activations in the forward pass #ifdef __AOTI_MODEL__ torch::inductor::AOTIModelContainerRunnerCpu *runner; #else // __ET_MODEL__ - Module* runner; + Module *runner; #endif } Transformer; -void malloc_run_state(RunState* s, Config* p) { +void malloc_run_state(RunState *s, Config *p) +{ // we calloc instead of malloc to keep valgrind happy - s->logits = (float *) calloc(p->vocab_size, sizeof(float)); - s->toks = (int64_t *) calloc(p->seq_len, sizeof(int64_t)); - if (!s->logits || !s->toks) { + s->logits = (float *)calloc(p->vocab_size, sizeof(float)); + s->toks = (int64_t *)calloc(p->seq_len, sizeof(int64_t)); + if (!s->logits || !s->toks) + { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); } } -void free_run_state(RunState* s) { +void free_run_state(RunState *s) +{ free(s->logits); free(s->toks); } -void read_checkpoint(char* checkpoint, Config* config) { +void read_checkpoint(char *checkpoint, Config *config) +{ FILE *file = fopen(checkpoint, "rb"); - if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); } + if (!file) + { + fprintf(stderr, "Couldn't open file %s\n", checkpoint); + exit(EXIT_FAILURE); + } // read in the config header - if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); } + if (fread(config, sizeof(Config), 1, file) != 1) + { + exit(EXIT_FAILURE); + } // negative vocab size is hacky way of signaling unshared weights. bit yikes. int shared_weights = config->vocab_size > 0 ? 
1 : 0; config->vocab_size = abs(config->vocab_size); } -void build_transformer(Transformer *t, char* checkpoint_path, int vocab_size, int seq_len) { +void build_transformer(Transformer *t, char *checkpoint_path, int vocab_size, int seq_len) +{ // read in the Config and the Weights from the checkpoint - //read_checkpoint(checkpoint_path, &t->config); + // read_checkpoint(checkpoint_path, &t->config); // allocate the RunState buffers t->config.vocab_size = vocab_size; t->config.seq_len = seq_len; @@ -95,19 +109,17 @@ void build_transformer(Transformer *t, char* checkpoint_path, int vocab_size, in #ifdef __AOTI_MODEL__ t->runner = new torch::inductor::AOTIModelContainerRunnerCpu( - /* path to model DSO */ checkpoint_path, - /* thread pool size */ 1 - ); + /* path to model DSO */ checkpoint_path, + /* thread pool size */ 1); #else //__ET_MODEL__ t->runner = new Module( - /* path to PTE model */ checkpoint_path, - /* PTE mmap settings */ Module::MlockConfig::UseMlockIgnoreErrors - ); + /* path to PTE model */ checkpoint_path, + /* PTE mmap settings */ Module::MlockConfig::UseMlockIgnoreErrors); #endif - } -void free_transformer(Transformer* t) { +void free_transformer(Transformer *t) +{ // free the RunState buffers free_run_state(&t->state); delete t->runner; @@ -116,29 +128,35 @@ void free_transformer(Transformer* t) { // ---------------------------------------------------------------------------- // neural net blocks; the dynamics of the Transformer -void softmax(float* x, int size) { +void softmax(float *x, int size) +{ // find max value (for numerical stability) float max_val = x[0]; - for (int i = 1; i < size; i++) { - if (x[i] > max_val) { + for (int i = 1; i < size; i++) + { + if (x[i] > max_val) + { max_val = x[i]; } } // exp and sum float sum = 0.0f; - for (int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) + { x[i] = expf(x[i] - max_val); sum += x[i]; } // normalize - for (int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) + { x[i] /= sum; } } -float* forward(Transformer* transformer, int token, int pos) { - Config* p = &transformer->config; - RunState* s = &transformer->state; +float *forward(Transformer *transformer, int token, int pos) +{ + Config *p = &transformer->config; + RunState *s = &transformer->state; s->toks[pos] = token; long token_buffer[1] = {token}; long pos_buffer[1] = {pos}; @@ -157,13 +175,13 @@ float* forward(Transformer* transformer, int token, int pos) { #else // __ET_MODEL__ ManagedTensor pos_managed( - pos_buffer, sizeof(int64_t), { 1 }, ScalarType::Long); + pos_buffer, sizeof(int64_t), {1}, ScalarType::Long); #ifndef __KV_CACHE__ // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds - ManagedTensor tokens_managed(&(s->toks[pos]), /*ignored*/sizeof(int64_t)*(pos+1), {1, 1}, ScalarType::Long); + ManagedTensor tokens_managed(&(s->toks[pos]), /*ignored*/ sizeof(int64_t) * (pos + 1), {1, 1}, ScalarType::Long); #else // __KV_CACHE__ ManagedTensor tokens_managed( - token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long); + token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long); #endif std::vector inputs; auto tmp1 = EValue(tokens_managed.get_aliasing_tensor()); @@ -172,7 +190,8 @@ float* forward(Transformer* transformer, int token, int pos) { inputs.push_back(tmp1); inputs.push_back(tmp2); Result> outputs_res = transformer->runner->forward(inputs); - if (!outputs_res.ok()) { + if (!outputs_res.ok()) + { fprintf(stderr, "Executorch forward() failed."); exit(EXIT_FAILURE); } @@ -187,100 +206,152 @@ float* forward(Transformer* 
transformer, int token, int pos) { // ---------------------------------------------------------------------------- // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens -typedef struct { +typedef struct +{ const char *str; int id; } TokenIndex; -typedef struct { - char** vocab; - float* vocab_scores; +typedef struct +{ + char **vocab; + float *vocab_scores; TokenIndex *sorted_vocab; int vocab_size; unsigned int max_token_length; unsigned char byte_pieces[512]; // stores all single-byte strings } Tokenizer; -int compare_tokens(const void *a, const void *b) { - return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +int compare_tokens(const void *a, const void *b) +{ + return strcmp(((TokenIndex *)a)->str, ((TokenIndex *)b)->str); } -void build_tokenizer(Tokenizer* t, const char* tokenizer_path, int vocab_size) { +void build_tokenizer(Tokenizer *t, const char *tokenizer_path, int vocab_size) +{ // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; // malloc space to hold the scores and the strings - t->vocab = (char**)malloc(vocab_size * sizeof(char*)); - t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); + t->vocab = (char **)malloc(vocab_size * sizeof(char *)); + t->vocab_scores = (float *)malloc(vocab_size * sizeof(float)); t->sorted_vocab = NULL; // initialized lazily - for (int i = 0; i < 256; i++) { + for (int i = 0; i < 256; i++) + { t->byte_pieces[i * 2] = (unsigned char)i; t->byte_pieces[i * 2 + 1] = '\0'; } // read in the file FILE *file = fopen(tokenizer_path, "rb"); - if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } - if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + if (!file) + { + fprintf(stderr, "couldn't load %s\n", tokenizer_path); + exit(EXIT_FAILURE); + } + if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) + { + fprintf(stderr, "failed read\n"); + exit(EXIT_FAILURE); + } int len; - for (int i = 0; i < vocab_size; i++) { - if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);} - if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + for (int i = 0; i < vocab_size; i++) + { + if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) + { + fprintf(stderr, "failed read\n"); + exit(EXIT_FAILURE); + } + if (fread(&len, sizeof(int), 1, file) != 1) + { + fprintf(stderr, "failed read\n"); + exit(EXIT_FAILURE); + } t->vocab[i] = (char *)malloc(len + 1); - if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + if (fread(t->vocab[i], len, 1, file) != 1) + { + fprintf(stderr, "failed read\n"); + exit(EXIT_FAILURE); + } t->vocab[i][len] = '\0'; // add the string terminating token } fclose(file); } -void free_tokenizer(Tokenizer* t) { - for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } +void free_tokenizer(Tokenizer *t) +{ + for (int i = 0; i < t->vocab_size; i++) + { + free(t->vocab[i]); + } free(t->vocab); free(t->vocab_scores); free(t->sorted_vocab); } -char* decode(Tokenizer* t, int prev_token, int token) { +char *decode(Tokenizer *t, int prev_token, int token) +{ char *piece = t->vocab[token]; // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - if (prev_token == 1 && piece[0] == ' ') { piece++; } + if (prev_token == 1 && piece[0] == ' ') + { + piece++; + } // 
careful, some tokens designate raw bytes, and look like e.g. '<0x01>' // parse this and convert and return the actual byte unsigned char byte_val; - if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { - piece = (char*)t->byte_pieces + byte_val * 2; + if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) + { + piece = (char *)t->byte_pieces + byte_val * 2; } return piece; } -void safe_printf(char *piece) { +void safe_printf(char *piece) +{ // piece might be a raw byte token, and we only want to print printable chars or whitespace // because some of the other bytes can be various control codes, backspace, etc. - if (piece == NULL) { return; } - if (piece[0] == '\0') { return; } - if (piece[1] == '\0') { + if (piece == NULL) + { + return; + } + if (piece[0] == '\0') + { + return; + } + if (piece[1] == '\0') + { unsigned char byte_val = piece[0]; - if (!(isprint(byte_val) || isspace(byte_val))) { + if (!(isprint(byte_val) || isspace(byte_val))) + { return; // bad byte, don't print it } } printf("%s", piece); } -int str_lookup(const char *str, TokenIndex *sorted_vocab, int vocab_size) { +int str_lookup(const char *str, TokenIndex *sorted_vocab, int vocab_size) +{ // efficiently find the perfect match for str in vocab, return its index or -1 if not found - TokenIndex tok = { .str = str }; // acts as the key to search for - TokenIndex *res = (TokenIndex *) bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + TokenIndex tok = {.str = str}; // acts as the key to search for + TokenIndex *res = (TokenIndex *)bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); return res != NULL ? res->id : -1; } -void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { +void encode(Tokenizer *t, const char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) +{ // encode the string text (input) into an upper-bound preallocated tokens[] array // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) - if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } + if (text == NULL) + { + fprintf(stderr, "cannot encode NULL text\n"); + exit(EXIT_FAILURE); + } - if (t->sorted_vocab == NULL) { + if (t->sorted_vocab == NULL) + { // lazily malloc and sort the vocabulary t->sorted_vocab = (TokenIndex *)malloc(t->vocab_size * sizeof(TokenIndex)); - for (int i = 0; i < t->vocab_size; i++) { + for (int i = 0; i < t->vocab_size; i++) + { t->sorted_vocab[i].str = t->vocab[i]; t->sorted_vocab[i].id = i; } @@ -289,21 +360,23 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens, // create a temporary buffer that will store merge candidates of always two consecutive tokens // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1) - const int str_buffer_len = t->max_token_length*2 + 1 + 2; - char* str_buffer = (char *)malloc(str_buffer_len * sizeof(char)); + const int str_buffer_len = t->max_token_length * 2 + 1 + 2; + char *str_buffer = (char *)malloc(str_buffer_len * sizeof(char)); size_t str_len = 0; // start at 0 tokens *n_tokens = 0; // add optional BOS (=1) token, if desired - if (bos) tokens[(*n_tokens)++] = 1; + if (bos) + tokens[(*n_tokens)++] = 1; // add_dummy_prefix is true by default // so prepend a dummy prefix token to the input string, but only if text != "" // TODO: pretty sure this isn't correct in the general case but I don't have the // energy to read more of the sentencepiece code to figure out what it's doing - 
-    if (text[0] != '\0') {
+    if (text[0] != '\0')
+    {
         int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
         tokens[(*n_tokens)++] = dummy_prefix;
     }
@@ -317,14 +390,16 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens,
     // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx

     // process the raw (UTF-8) byte sequence of the input string
-    for (const char *c = text; *c != '\0'; c++) {
+    for (const char *c = text; *c != '\0'; c++)
+    {
         // reset buffer if the current byte is ASCII or a leading byte
         // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
         // 0x80 is 10000000
         // in UTF-8, all continuation bytes start with "10" in first two bits
         // so in English this is: "if this byte is not a continuation byte"
-        if ((*c & 0xC0) != 0x80) {
+        if ((*c & 0xC0) != 0x80)
+        {
             // this byte must be either a leading byte (11...) or an ASCII char (0x...)
             // => reset our location, as we're starting a new UTF-8 codepoint
             str_len = 0;
@@ -336,21 +411,26 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens,
         // while the next character is a continuation byte, continue appending
         // but if there are too many of them, just stop to avoid overrunning str_buffer size.
-        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
+        if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4)
+        {
             continue;
         }

         // ok c+1 is not a continuation byte, so we've read in a full codepoint
         int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
-        if (id != -1) {
+        if (id != -1)
+        {
             // we found this codepoint in vocab, add it as a token
             tokens[(*n_tokens)++] = id;
-        } else {
+        }
+        else
+        {
             // byte_fallback encoding: just encode each byte as a token
             // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
             // so the individual bytes only start at index 3
-            for (int i=0; i < str_len; i++) {
+            for (int i = 0; i < str_len; i++)
+            {
                 tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
             }
         }
@@ -358,16 +438,19 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens,
     // merge the best consecutive pair each iteration, according to the scores in vocab_scores
-    while (1) {
+    while (1)
+    {
         float best_score = -1e10;
         int best_id = -1;
         int best_idx = -1;

-        for (int i=0; i < (*n_tokens-1); i++) {
+        for (int i = 0; i < (*n_tokens - 1); i++)
+        {
             // check if we can merge the pair (tokens[i], tokens[i+1])
-            snprintf(str_buffer, str_buffer_len, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+            snprintf(str_buffer, str_buffer_len, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i + 1]]);
             int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
-            if (id != -1 && t->vocab_scores[id] > best_score) {
+            if (id != -1 && t->vocab_scores[id] > best_score)
+            {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
@@ -375,21 +458,24 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens,
             }
         }

-        if (best_idx == -1) {
+        if (best_idx == -1)
+        {
             break; // we couldn't find any more pairs to merge, so we're done
         }

         // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
         tokens[best_idx] = best_id;
         // delete token at position best_idx+1, shift the entire sequence back 1
-        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
-            tokens[i] = tokens[i+1];
+        for (int i = best_idx + 1; i < (*n_tokens - 1); i++)
+        {
+            tokens[i] = tokens[i + 1];
         }
         (*n_tokens)--; // token length decreased
     }

     // add optional EOS (=2) token, if desired
-    if (eos) tokens[(*n_tokens)++] = 2;
+    if (eos)
+        tokens[(*n_tokens)++] = 2;

     free(str_buffer);
 }

@@ -398,25 +484,30 @@ void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens,
 // The Sampler, which takes logits and returns a sampled token
 // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

-typedef struct {
+typedef struct
+{
     float prob;
     int index;
 } ProbIndex; // struct used when sorting probabilities during top-p sampling

-typedef struct {
+typedef struct
+{
     int vocab_size;
-    ProbIndex* probindex; // buffer used in top-p sampling
+    ProbIndex *probindex; // buffer used in top-p sampling
     float temperature;
     float topp;
     unsigned long long rng_state;
 } Sampler;

-int sample_argmax(float* probabilities, int n) {
+int sample_argmax(float *probabilities, int n)
+{
     // return the index that has the highest probability
     int max_i = 0;
     float max_p = probabilities[0];
-    for (int i = 1; i < n; i++) {
-        if (probabilities[i] > max_p) {
+    for (int i = 1; i < n; i++)
+    {
+        if (probabilities[i] > max_p)
+        {
             max_i = i;
             max_p = probabilities[i];
         }
@@ -424,28 +515,35 @@ int sample_argmax(float* probabilities, int n) {
     return max_i;
 }

-int sample_mult(float* probabilities, int n, float coin) {
+int sample_mult(float *probabilities, int n, float coin)
+{
     // sample index from probabilities (they must sum to 1!)
     // coin is a random number in [0, 1), usually from random_f32()
     float cdf = 0.0f;
-    for (int i = 0; i < n; i++) {
+    for (int i = 0; i < n; i++)
+    {
         cdf += probabilities[i];
-        if (coin < cdf) {
+        if (coin < cdf)
+        {
             return i;
         }
     }
     return n - 1; // in case of rounding errors
 }

-int compare(const void* a, const void* b) {
-    ProbIndex* a_ = (ProbIndex*) a;
-    ProbIndex* b_ = (ProbIndex*) b;
-    if (a_->prob > b_->prob) return -1;
-    if (a_->prob < b_->prob) return 1;
+int compare(const void *a, const void *b)
+{
+    ProbIndex *a_ = (ProbIndex *)a;
+    ProbIndex *b_ = (ProbIndex *)b;
+    if (a_->prob > b_->prob)
+        return -1;
+    if (a_->prob < b_->prob)
+        return 1;
     return 0;
 }

-int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
+int sample_topp(float *probabilities, int n, float topp, ProbIndex *probindex, float coin)
+{
     // top-p sampling (or "nucleus sampling") samples from the smallest set of
     // tokens that exceed probability topp. This way we never sample tokens that
     // have very low probabilities and are less likely to go "off the rails".
@@ -456,8 +554,10 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, f
     // values smaller than (1 - topp) / (n - 1) cannot be part of the result
     // so for efficiency we crop these out as candidates before sorting
     const float cutoff = (1.0f - topp) / (n - 1);
-    for (int i = 0; i < n; i++) {
-        if (probabilities[i] >= cutoff) {
+    for (int i = 0; i < n; i++)
+    {
+        if (probabilities[i] >= cutoff)
+        {
             probindex[n0].index = i;
             probindex[n0].prob = probabilities[i];
             n0++;
@@ -468,9 +568,11 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, f
     // truncate the list where cumulative probability exceeds topp
     float cumulative_prob = 0.0f;
     int last_idx = n0 - 1; // in case of rounding errors consider all elements
-    for (int i = 0; i < n0; i++) {
+    for (int i = 0; i < n0; i++)
+    {
         cumulative_prob += probindex[i].prob;
-        if (cumulative_prob > topp) {
+        if (cumulative_prob > topp)
+        {
             last_idx = i;
             break; // we've exceeded topp by including last_idx
         }
@@ -479,57 +581,73 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, f
     // sample from the truncated list
     float r = coin * cumulative_prob;
     float cdf = 0.0f;
-    for (int i = 0; i <= last_idx; i++) {
+    for (int i = 0; i <= last_idx; i++)
+    {
         cdf += probindex[i].prob;
-        if (r < cdf) {
+        if (r < cdf)
+        {
             return probindex[i].index;
         }
     }
     return probindex[last_idx].index; // in case of rounding errors
 }

-void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
+void build_sampler(Sampler *sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed)
+{
     sampler->vocab_size = vocab_size;
     sampler->temperature = temperature;
     sampler->topp = topp;
     sampler->rng_state = rng_seed;
     // buffer only used with nucleus sampling; may not need but it's ~small
-    sampler->probindex = (ProbIndex *) malloc(sampler->vocab_size * sizeof(ProbIndex));
+    sampler->probindex = (ProbIndex *)malloc(sampler->vocab_size * sizeof(ProbIndex));
 }

-void free_sampler(Sampler* sampler) {
+void free_sampler(Sampler *sampler)
+{
     free(sampler->probindex);
 }

-unsigned int random_u32(unsigned long long *state) {
+unsigned int random_u32(unsigned long long *state)
+{
     // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
     *state ^= *state >> 12;
     *state ^= *state << 25;
     *state ^= *state >> 27;
     return (*state * 0x2545F4914F6CDD1Dull) >> 32;
 }

-float random_f32(unsigned long long *state) { // random float32 in [0,1)
+float random_f32(unsigned long long *state)
+{ // random float32 in [0,1)
     return (random_u32(state) >> 8) / 16777216.0f;
 }

-int sample(Sampler* sampler, float* logits) {
+int sample(Sampler *sampler, float *logits)
+{
     // sample the token given the logits and some hyperparameters
     int next;
-    if (sampler->temperature == 0.0f) {
+    if (sampler->temperature == 0.0f)
+    {
         // greedy argmax sampling: take the token with the highest probability
         next = sample_argmax(logits, sampler->vocab_size);
-    } else {
+    }
+    else
+    {
        // apply the temperature to the logits
-        for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
+        for (int q = 0; q < sampler->vocab_size; q++)
+        {
+            logits[q] /= sampler->temperature;
+        }
        // apply softmax to the logits to get the probabilities for next token
        softmax(logits, sampler->vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = random_f32(&sampler->rng_state);
        // we sample from this distribution to get the next token
-        if (sampler->topp <= 0 || sampler->topp >= 1) {
+        if (sampler->topp <= 0 || sampler->topp >= 1)
+        {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, sampler->vocab_size, coin);
-        } else {
+        }
+        else
+        {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
        }
@@ -540,7 +658,8 @@ int sample(Sampler* sampler, float* logits) {
 // ----------------------------------------------------------------------------
 // utilities: time

-long time_in_ms() {
+long time_in_ms()
+{
     // return time in milliseconds, for benchmarking the model speed
     struct timespec time;
     clock_gettime(CLOCK_REALTIME, &time);
@@ -550,75 +669,94 @@ long time_in_ms() {
 // ----------------------------------------------------------------------------
 // generation loop

-void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, const char *prompt, int steps) {
+void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, const char *prompt, int steps)
+{
     const char *default_prompt = "Once upon a time";
-    if (prompt == NULL) { prompt = default_prompt; }
+    if (prompt == NULL)
+    {
+        prompt = default_prompt;
+    }

     // encode the (string) prompt into tokens sequence
     int num_prompt_tokens = 0;
-    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
+    int *prompt_tokens = (int *)malloc((strlen(prompt) + 3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
     encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
-    if (num_prompt_tokens < 1) {
+    if (num_prompt_tokens < 1)
+    {
         fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
         exit(EXIT_FAILURE);
     }

-    #ifdef DEBUG
+#ifdef DEBUG
     std::cerr << "# " << num_prompt_tokens << "\n";
-    for(int i = 0; i < num_prompt_tokens; i++)
-        std::cerr << "[" << i << "] " << prompt_tokens[i];
+    for (int i = 0; i < num_prompt_tokens; i++)
+        std::cerr << "[" << i << "] " << prompt_tokens[i];
     std::cerr << "\n";
-    #endif
+#endif

     // start the main loop
-    long start = 0;  // used to time our code, only initialized after first iteration
-    int next;        // will store the next token in the sequence
+    long start = 0;               // used to time our code, only initialized after first iteration
+    int next;                     // will store the next token in the sequence
     int token = prompt_tokens[0]; // kick off with the first token in the prompt
-    int pos = 0;     // position in the sequence
-    while (pos < steps) {
+    int pos = 0;                  // position in the sequence
+    while (pos < steps)
+    {
         // forward the transformer to get logits for the next token
-        float* logits = forward(transformer, token, pos);
+        float *logits = forward(transformer, token, pos);

         // advance the state machine
-        if (pos < num_prompt_tokens - 1) {
+        if (pos < num_prompt_tokens - 1)
+        {
             // if we are still processing the input prompt, force the next prompt token
             next = prompt_tokens[pos + 1];
-        } else {
+        }
+        else
+        {
             // otherwise sample the next token from the logits
             next = sample(sampler, logits);
         }
         pos++;

         // data-dependent terminating condition: the BOS (=1) token delimits sequences
-        if (next == 1) { break; }
+        if (next == 1)
+        {
+            break;
+        }

         // print the token as string, decode it with the Tokenizer object
-        char* piece = decode(tokenizer, token, next);
+        char *piece = decode(tokenizer, token, next);
         safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
         fflush(stdout);
         token = next;

         // init the timer here because the first iteration can be slower
-        if (start == 0) { start = time_in_ms(); }
+        if (start == 0)
+        {
+            start = time_in_ms();
+        }
     }
     printf("\n");

     // report achieved tok/s (pos-1 because the timer starts after first iteration)
-    if (pos > 1) {
+    if (pos > 1)
+    {
         long end = time_in_ms();
-        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
+        fprintf(stderr, "achieved tok/s: %f\n", (pos - 1) / (double)(end - start) * 1000);
     }

     free(prompt_tokens);
 }

-void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+void read_stdin(const char *guide, char *buffer, size_t bufsize)
+{
     // read a line from stdin, up to but not including \n
     printf("%s", guide);
-    if (fgets(buffer, bufsize, stdin) != NULL) {
+    if (fgets(buffer, bufsize, stdin) != NULL)
+    {
         size_t len = strlen(buffer);
-        if (len > 0 && buffer[len - 1] == '\n') {
+        if (len > 0 && buffer[len - 1] == '\n')
+        {
             buffer[len - 1] = '\0'; // strip newline
         }
     }
@@ -631,7 +769,8 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
 // is not safely implemented, it's more a proof of concept atm.

 void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
-          const char *cli_user_prompt, const char *cli_system_prompt, int steps) {
+          const char *cli_user_prompt, const char *cli_system_prompt, int steps)
+{

     // buffers for reading the system prompt and user prompt from stdin
     // you'll notice they are somewhat haphazardly and unsafely set atm
@@ -639,43 +778,55 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
     char user_prompt[512];
     char rendered_prompt[1152];
     int num_prompt_tokens = 0;
-    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
+    int *prompt_tokens = (int *)malloc(1152 * sizeof(int));
     int user_idx;

     // start the main loop
     int8_t user_turn = 1; // user starts
-    int next;        // will store the next token in the sequence
-    int token;       // stores the current token to feed into the transformer
+    int next;             // will store the next token in the sequence
+    int token;            // stores the current token to feed into the transformer
     int prev_token;
-    int pos = 0;     // position in the sequence
-    while (pos < steps) {
+    int pos = 0;          // position in the sequence
+    while (pos < steps)
+    {

         // when it is the user's turn to contribute tokens to the dialog...
-        if (user_turn) {
+        if (user_turn)
+        {
             // get the (optional) system prompt at position 0
-            if (pos == 0) {
+            if (pos == 0)
+            {
                 // at position 0, the user can also contribute a system prompt
-                if (cli_system_prompt == NULL) {
+                if (cli_system_prompt == NULL)
+                {
                     // system prompt was not passed in, attempt to get it from stdin
                     read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
-                } else {
+                }
+                else
+                {
                     // system prompt was passed in, use it
                     strcpy(system_prompt, cli_system_prompt);
                 }
             }
             // get the user prompt
-            if (pos == 0 && cli_user_prompt != NULL) {
+            if (pos == 0 && cli_user_prompt != NULL)
+            {
                 // user prompt for position 0 was passed in, use it
                 strcpy(user_prompt, cli_user_prompt);
-            } else {
+            }
+            else
+            {
                 // otherwise get user prompt from stdin
                 read_stdin("User: ", user_prompt, sizeof(user_prompt));
             }
             // render user/system prompts into the Llama 2 Chat schema
-            if (pos == 0 && system_prompt[0] != '\0') {
+            if (pos == 0 && system_prompt[0] != '\0')
+            {
                 char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
                 snprintf(rendered_prompt, 1151, system_template, system_prompt, user_prompt);
-            } else {
+            }
+            else
+            {
                 char user_template[] = "[INST] %s [/INST]";
                 snprintf(rendered_prompt, 1151, user_template, user_prompt);
             }
@@ -687,39 +838,49 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
         }

         // determine the token to pass into the transformer next
-        if (user_idx < num_prompt_tokens) {
+        if (user_idx < num_prompt_tokens)
+        {
             // if we are still processing the input prompt, force the next prompt token
             token = prompt_tokens[user_idx++];
-        } else {
+        }
+        else
+        {
             // otherwise use the next token sampled from previous turn
             token = next;
         }
         // EOS (=2) token ends the Assistant turn
-        if (token == 2) { user_turn = 1; }
+        if (token == 2)
+        {
+            user_turn = 1;
+        }

         // forward the transformer to get logits for the next token
-        float* logits = forward(transformer, token, pos);
+        float *logits = forward(transformer, token, pos);
         next = sample(sampler, logits);
         pos++;

-        if (user_idx >= num_prompt_tokens && next != 2) {
+        if (user_idx >= num_prompt_tokens && next != 2)
+        {
             // the Assistant is responding, so print its output
-            char* piece = decode(tokenizer, token, next);
+            char *piece = decode(tokenizer, token, next);
             safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
             fflush(stdout);
         }
-        if (next == 2) { printf("\n"); }
+        if (next == 2)
+        {
+            printf("\n");
+        }
     }
     printf("\n");
     free(prompt_tokens);
 }
-
 // ----------------------------------------------------------------------------
 // CLI, include only if not testing
 #ifndef TESTING

-void error_usage() {
+void error_usage()
+{
     fprintf(stderr, "Usage: run [options]\n");
     fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
     fprintf(stderr, "Options:\n");
@@ -734,45 +895,97 @@ void error_usage() {
     exit(EXIT_FAILURE);
 }

-int main(int argc, char *argv[]) {
+int main(int argc, char *argv[])
+{
     // default parameters
-    char *checkpoint_path = NULL; // e.g. out/model.bin
+    char *checkpoint_path = NULL;     // e.g. out/model.bin
     const char *tokenizer_path = "tokenizer.bin";
-    float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
-    float topp = 0.9f;        // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+    float temperature = 1.0f;         // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+    float topp = 0.9f;                // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
     int vocab_size = 32000;
-    int steps = 256;           // number of steps to run for
-    const char *prompt = NULL; // prompt string
+    int steps = 256;                  // number of steps to run for
+    const char *prompt = NULL;        // prompt string
     unsigned long long rng_seed = 0; // seed rng with time by default
-    const char *mode = "generate"; // generate|chat
-    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
+    const char *mode = "generate";    // generate|chat
+    char *system_prompt = NULL;       // the (optional) system prompt to use in chat mode

     // poor man's C argparse so we can override the defaults above from the command line
-    if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
-    for (int i = 2; i < argc; i+=2) {
+    if (argc >= 2)
+    {
+        checkpoint_path = argv[1];
+    }
+    else
+    {
+        error_usage();
+    }
+    for (int i = 2; i < argc; i += 2)
+    {
         // do some basic validation
-        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
-        if (argv[i][0] != '-') { error_usage(); } // must start with dash
-        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
+        if (i + 1 >= argc)
+        {
+            error_usage();
+        } // must have arg after flag
+        if (argv[i][0] != '-')
+        {
+            error_usage();
+        } // must start with dash
+        if (strlen(argv[i]) != 2)
+        {
+            error_usage();
+        } // must be -x (one dash, one letter)
         // read in the args
-        if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
-        else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
-        else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
-        else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
-        else if (argv[i][1] == 'v') { vocab_size = atoi(argv[i + 1]); }
-        else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
-        else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
-        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
-        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
-        else { error_usage(); }
+        if (argv[i][1] == 't')
+        {
+            temperature = atof(argv[i + 1]);
+        }
+        else if (argv[i][1] == 'p')
+        {
+            topp = atof(argv[i + 1]);
+        }
+        else if (argv[i][1] == 's')
+        {
+            rng_seed = atoi(argv[i + 1]);
+        }
+        else if (argv[i][1] == 'n')
+        {
+            steps = atoi(argv[i + 1]);
+        }
+        else if (argv[i][1] == 'v')
+        {
+            vocab_size = atoi(argv[i + 1]);
+        }
+        else if (argv[i][1] == 'i')
+        {
+            prompt = argv[i + 1];
+        }
+        else if (argv[i][1] == 'z')
+        {
+            tokenizer_path = argv[i + 1];
+        }
+        else if (argv[i][1] == 'm')
+        {
+            mode = argv[i + 1];
+        }
+        else if (argv[i][1] == 'y')
+        {
+            system_prompt = argv[i + 1];
+        }
+        else
+        {
+            error_usage();
+        }
     }

     // parameter validation/overrides
-    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
-    if (temperature < 0.0) temperature = 0.0;
-    if (topp < 0.0 || 1.0 < topp) topp = 0.9;
-    if (steps < 0) steps = 0;
+    if (rng_seed <= 0)
+        rng_seed = (unsigned int)time(NULL);
+    if (temperature < 0.0)
+        temperature = 0.0;
+    if (topp < 0.0 || 1.0 < topp)
+        topp = 0.9;
+    if (steps < 0)
+        steps = 0;

     // build the Transformer via the model .bin file
     Transformer transformer;
@@ -787,11 +1000,16 @@ int main(int argc, char *argv[]) {
     build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);

     // run!
-    if (strcmp(mode, "generate") == 0) {
+    if (strcmp(mode, "generate") == 0)
+    {
         generate(&transformer, &tokenizer, &sampler, prompt, steps);
-    } else if (strcmp(mode, "chat") == 0) {
+    }
+    else if (strcmp(mode, "chat") == 0)
+    {
         chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
-    } else {
+    }
+    else
+    {
         fprintf(stderr, "unknown mode: %s\n", mode);
         error_usage();
     }
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 428c4a733..6ac1dc3e1 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -22,7 +22,9 @@
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
-    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
+    checkpoint_dir: Path = Path(
+        "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"
+    ),
     model_name: Optional[str] = None,
 ) -> None:
     if model_name is None:
@@ -45,8 +47,8 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
@@ -66,13 +68,15 @@ def permute(w, n_heads):
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+        state_dict = torch.load(
+            str(file), map_location="cpu", mmap=True, weights_only=True
+        )
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
+            abstract_key = re.sub(r"(\d+)", "{}", key)
+            layer_num = re.search(r"\d+", key).group(0)
             new_key = weight_map[abstract_key]
             if new_key is None:
                 continue
@@ -96,11 +100,17 @@ def permute(w, n_heads):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
-    parser.add_argument('--checkpoint-dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
-    parser.add_argument('--model-name', type=str, default=None)
+
+    parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=Path,
+        default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"),
+    )
+    parser.add_argument("--model-name", type=str, default=None)
     args = parser.parse_args()

     convert_hf_checkpoint(
diff --git a/scripts/download.py b/scripts/download.py
index 849095ddf..387c41243 100644
--- a/scripts/download.py
+++ b/scripts/download.py
@@ -11,6 +11,7 @@
 def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
+
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
         snapshot_download(
@@ -18,18 +19,30 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
             local_dir=f"checkpoints/{repo_id}",
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*")
+            ignore_patterns="*safetensors*",
+        )
     except HTTPError as e:
         if e.response.status_code == 401:
-            print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
+            print(
+                "You need to pass a valid `--hf_token=...` to download private checkpoints."
+            )
         else:
             raise e

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.')
-    parser.add_argument('--repo-id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
-    parser.add_argument('--hf-token', type=str, default=None, help='HuggingFace API token.')
+
+    parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="checkpoints/meta-llama/llama-2-7b-chat-hf",
+        help="Repository ID to download from.",
+    )
+    parser.add_argument(
+        "--hf-token", type=str, default=None, help="HuggingFace API token."
+    )
     args = parser.parse_args()
     hf_download(args.repo_id, args.hf_token)
diff --git a/torchat.py b/torchat.py
index 4b720b8dd..914d71770 100644
--- a/torchat.py
+++ b/torchat.py
@@ -19,9 +19,10 @@
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'

+
 def cli():
     args = cli_args()
-
+
     if args.generate or args.chat:
         check_args(args, "generate")
         generate_main(args)
@@ -32,6 +33,7 @@ def cli():
         export_main(args)
     else:
         raise RuntimeError("must specify either --generate or --export")
-
+
+
 if __name__ == "__main__":
     cli()
diff --git a/utils/tokenizer.py b/utils/tokenizer.py
index f3c0cc324..eef0a72d8 100644
--- a/utils/tokenizer.py
+++ b/utils/tokenizer.py
@@ -9,7 +9,8 @@
 from sentencepiece import SentencePieceProcessor

-TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model
+TOKENIZER_MODEL = "tokenizer.model"  # the llama sentencepiece tokenizer model
+
 class Tokenizer:
     def __init__(self, tokenizer_model=None):
@@ -23,7 +24,7 @@ def __init__(self, tokenizer_model=None):
         self.bos_id: int = self.sp_model.bos_id()
         self.eos_id: int = self.sp_model.eos_id()
         self.pad_id: int = self.sp_model.pad_id()
-        #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
+        # print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
@@ -48,11 +49,11 @@ def export(self):
             t = self.sp_model.id_to_piece(i)
             s = self.sp_model.get_score(i)
             if i == self.bos_id:
-                t = '\n<s>\n'
+                t = "\n<s>\n"
             elif i == self.eos_id:
-                t = '\n</s>\n'
-            t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
-            b = t.encode('utf-8') # bytes of this token, utf-8 encoded
+                t = "\n</s>\n"
+            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
+            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

             tokens.append(b)
             scores.append(s)
@@ -62,16 +63,19 @@ def export(self):
         # write to a binary file
         # the tokenizer.bin file is the same as .model file, but .bin
-        tokenizer_bin = self.model_path.replace('.model', '.bin')
-        with open(tokenizer_bin, 'wb') as f:
+        tokenizer_bin = self.model_path.replace(".model", ".bin")
+        with open(tokenizer_bin, "wb") as f:
             f.write(struct.pack("I", max_token_length))
             for bytes, score in zip(tokens, scores):
                 f.write(struct.pack("fI", score, len(bytes)))
                 f.write(bytes)

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
+    parser.add_argument(
+        "-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer "
+    )
     args = parser.parse_args()

     t = Tokenizer(args.tokenizer_model)