diff --git a/build/builder.py b/build/builder.py
index f62c42131..63bfd226c 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -3,20 +3,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import itertools
+
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._dynamo.config
 import torch._inductor.config
 
-from quantize import (
-    quantize_model, name_to_dtype, set_precision, get_precision
-)
-from cli import cli_args
+from quantize import quantize_model, name_to_dtype, set_precision, get_precision
 from dataclasses import dataclass
 from typing import Union, Optional
@@ -40,43 +37,50 @@ class BuilderArgs:
 
     def __post_init__(self):
         if not (
-            (self.checkpoint_path and self.checkpoint_path.is_file()) or
-            (self.checkpoint_dir and self.checkpoint_path.is_dir()) or
-            (self.gguf_path and self.gguf_path.is_file()) or
-            (self.dso_path and Path(self.dso_path).is_file()) or
-            (self.pte_path and Path(self.pte_path).is_file())
+            (self.checkpoint_path and self.checkpoint_path.is_file())
+            or (self.checkpoint_dir and self.checkpoint_path.is_dir())
+            or (self.gguf_path and self.gguf_path.is_file())
+            or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.pte_path and Path(self.pte_path).is_file())
         ):
-            raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")
+            raise RuntimeError(
+                "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
+            )
 
-        if (self.dso_path and self.pte_path):
+        if self.dso_path and self.pte_path:
             raise RuntimeError("specify either DSO path or PTE path, but not both")
 
-        if (self.checkpoint_path and (self.dso_path or self.pte_path)):
-            print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
-        if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
-            print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
-        if (self.gguf_path and (self.dso_path or self.pte_path)):
-            print("Warning: GGUF path ignored because an exported DSO or PTE path specified")
-
+        if self.checkpoint_path and (self.dso_path or self.pte_path):
+            print(
+                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+            )
+        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+            print(
+                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+            )
+        if self.gguf_path and (self.dso_path or self.pte_path):
+            print(
+                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+            )
 
     @classmethod
-    def from_args(cls, args): # -> BuilderArgs:
+    def from_args(cls, args):  # -> BuilderArgs:
         return cls(
-            checkpoint_path = args.checkpoint_path,
-            checkpoint_dir = args.checkpoint_dir,
-            params_path = args.params_path,
-            params_table = args.params_table,
-            gguf_path = args.gguf_path,
-            dso_path = args.dso_path,
-            pte_path = args.pte_path,
-            device = args.device,
-            precision = name_to_dtype(args.dtype),
-            setup_caches = (args.output_dso_path or args.output_pte_path),
-            use_tp = False,
+            checkpoint_path=args.checkpoint_path,
+            checkpoint_dir=args.checkpoint_dir,
+            params_path=args.params_path,
+            params_table=args.params_table,
+            gguf_path=args.gguf_path,
+            dso_path=args.dso_path,
+            pte_path=args.pte_path,
+            device=args.device,
+            precision=name_to_dtype(args.dtype),
+            setup_caches=(args.output_dso_path or args.output_pte_path),
+            use_tp=False,
         )
 
     @classmethod
-    def from_speculative_args(cls, args): # -> BuilderArgs:
+    def from_speculative_args(cls, args):  # -> BuilderArgs:
         speculative_builder_args = BuilderArgs.from_args(args)
         # let's limit multi-checkpoint to checker
         speculative_builder_args.checkpoint_dir = None
@@ -94,7 +98,7 @@ class TokenizerArgs:
     is_TikToken: bool = False
 
     @classmethod
-    def from_args(cls, args): # -> TokenizerArgs:
+    def from_args(cls, args):  # -> TokenizerArgs:
         is_SentencePiece = True
         is_TikToken = False
 
@@ -108,7 +112,7 @@ def from_args(cls, args): # -> TokenizerArgs:
             raise RuntimeError(f"cannot find tokenizer model")
 
         if not tokenizer_path.is_file():
-            raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
+            raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
 
         if args.tiktoken:
             is_SentencePiece = False
@@ -117,9 +121,10 @@ def from_args(cls, args): # -> TokenizerArgs:
         return cls(
             tokenizer_path=tokenizer_path,
             is_SentencePiece=is_SentencePiece,
-            is_TikToken=is_TikToken
+            is_TikToken=is_TikToken,
         )
 
+
 def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
     if tokenizer_args.is_SentencePiece:
         return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
@@ -147,6 +152,7 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
+
 def _load_model(builder_args):
     if builder_args.gguf_path:
         model = Transformer.from_gguf(builder_args.gguf_path)
@@ -160,9 +166,8 @@ def _load_model(builder_args):
     else:
         return _load_model_not_gguf(builder_args)
 
-def _load_model_not_gguf(
-    builder_args
-):
+
+def _load_model_not_gguf(builder_args):
     assert not builder_args.gguf_path
 
     with torch.device("meta"):
@@ -200,7 +205,12 @@ def _load_model_not_gguf(
             else:
                 checkpoint[key] = cps[0][key]
     else:
-        checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)
+        checkpoint = torch.load(
+            builder_args.checkpoint_path,
+            map_location=builder_args.device,
+            mmap=True,
+            weights_only=True,
+        )
 
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
@@ -218,21 +228,21 @@ def _load_model_not_gguf(
 
 
 def _initialize_model(
-    builder_args,
-    quantize,
+    builder_args,
+    quantize,
 ):
     print("Loading model ...")
     t0 = time.time()
-    model_ = _load_model(
-        builder_args
-    )
+    model_ = _load_model(builder_args)
     device_sync(device=builder_args.device)
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     if builder_args.dso_path:
         # make sure user did not try to set dtype
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
-        assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
+        assert (
+            quantize is None or quantize == "{ }"
+        ), f"quantize not valid for exported DSO model. Specify quantization during export."
         try:
             model = model_
             # Replace model forward with the AOT-compiled forward
@@ -241,15 +251,20 @@ def _initialize_model(
             # attributes will NOT be seen on by AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
            # AOTI compiled and maintained model buffers such as kv_cache.
-            model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
+            model.forward = torch._export.aot_load(
+                str(builder_args.dso_path.absolute()), builder_args.device
+            )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
     elif builder_args.pte_path:
         # make sure user did not try to set dtype
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
-        assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
+        assert (
+            quantize is None or quantize == "{ }"
+        ), f"quantize not valid for exported PTE model. Specify quantization during export."
         try:
             from build.model_et import PTEModel
+
             model = PTEModel(model_.config, builder_args.pte_path)
         except Exception as e:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
@@ -265,10 +280,7 @@ def _initialize_model(
         if builder_args.setup_caches:
             max_seq_length = 350
             with torch.device(builder_args.device):
-                model.setup_caches(
-                    max_batch_size=1,
-                    max_seq_length=max_seq_length
-                )
+                model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
         model.to(dtype=builder_args.precision)
 
diff --git a/cli.py b/cli.py
index b14f0a944..99e1155a1 100644
--- a/cli.py
+++ b/cli.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import time
 import os
 from pathlib import Path
@@ -39,12 +40,17 @@ def check_args(args, command_name: str):
         else:
             print(f"Warning: {text}")
 
-
-def cli_args():
-    import argparse
-    parser = argparse.ArgumentParser(description="Your CLI description.")
+def add_arguments_for_generate(parser):
+    _add_arguments_common(parser)
+
+def add_arguments_for_eval(parser):
+    _add_arguments_common(parser)
+
+def add_arguments_for_export(parser):
+    _add_arguments_common(parser)
 
+def _add_arguments_common(parser):
     parser.add_argument(
         "--seed",
         type=int,
@@ -59,26 +65,11 @@ def cli_args():
         action="store_true",
         help="Whether to use tiktoken tokenizer.",
     )
-    parser.add_argument(
-        "--export",
-        action="store_true",
-        help="Use torchat to export a model.",
-    )
-    parser.add_argument(
-        "--eval",
-        action="store_true",
-        help="Use torchat to eval a model.",
-    )
-    parser.add_argument(
-        "--generate",
-        action="store_true",
-        help="Use torchat to generate a sequence using a model.",
-    )
     parser.add_argument(
         "--chat",
         action="store_true",
         help="Use torchat to for an interactive chat session.",
-    )
+    )
     parser.add_argument(
         "--gui",
         action="store_true",
@@ -231,13 +222,15 @@ def cli_args():
         default=None,
         help='maximum length sequence to evaluate')
 
-    args = parser.parse_args()
-
+
+def arg_init(args):
+    if (Path(args.quantize).is_file()):
         with open(args.quantize, "r") as f:
             args.quantize = json.loads(f.read())
 
     if args.seed:
-        torch.manual_seed(args.seed)
-
+        torch.manual_seed(args.seed)
+    return args
+
+
diff --git a/eval.py b/eval.py
index 8ac8c457f..94e099ef2 100644
--- a/eval.py
+++ b/eval.py
@@ -3,9 +3,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sys
+import argparse
 import time
-from pathlib import Path
 from typing import Optional
 
 import torch
@@ -18,32 +17,39 @@
 torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
 
-from cli import cli_args
-from quantize import name_to_dtype, set_precision
+from cli import add_arguments_for_eval, arg_init
+from quantize import set_precision
 from build.model import Transformer
 
 try:
     import lm_eval
+
     lm_eval_available = True
 except:
     lm_eval_available = False
 
-from build.builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
+from build.builder import (
+    _initialize_model,
+    _initialize_tokenizer,
+    BuilderArgs,
+    TokenizerArgs,
+)
 from generate import encode_tokens, model_forward
 
 if lm_eval_available:
-    try: # lm_eval version 0.4
+    try:  # lm_eval version 0.4
         from lm_eval.models.huggingface import HFLM as eval_wrapper
         from lm_eval.tasks import get_task_dict
         from lm_eval.evaluator import evaluate
-    except: #lm_eval version 0.3
+    except:  # lm_eval version 0.3
         from lm_eval import base
         from lm_eval import tasks
         from lm_eval import evaluator
-        eval_wrapper=base.BaseLM
-        get_task_dict=tasks.get_task_dict
-        evaluate=evaluator.evaluate
+
+        eval_wrapper = base.BaseLM
+        get_task_dict = tasks.get_task_dict
+        evaluate = evaluator.evaluate
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
@@ -84,20 +90,22 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
 
     return seq, input_pos, max_seq_length
 
+
 class GPTFastEvalWrapper(eval_wrapper):
     """
     A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
+
     def __init__(
         self,
         model: Transformer,
         tokenizer,
-        max_seq_length: Optional[int]=None,
+        max_seq_length: Optional[int] = None,
     ):
         super().__init__()
         self._model = model
         self._tokenizer = tokenizer
-        self._device = torch.device('cuda')
+        self._device = torch.device("cuda")
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
 
     @property
@@ -121,8 +129,7 @@ def device(self):
         return self._device
 
     def tok_encode(self, string: str, **kwargs):
-        encoded = encode_tokens(self._tokenizer,
-            string, bos=True, device=self._device)
+        encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device)
         # encoded is a pytorch tensor, but some internal logic in the
         # eval harness expects it to be a list instead
         # TODO: verify this for multi-batch as well
@@ -138,19 +145,20 @@ def _model_call(self, inps):
         inps = inps.squeeze(0)
 
         max_new_tokens = 1
-        seq, input_pos, max_seq_length = \
+        seq, input_pos, max_seq_length = (
             setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
                 self._model,
                 inps,
                 max_new_tokens,
                 self.max_length,
             )
+        )
         x = seq.index_select(0, input_pos).view(1, -1)
         logits = model_forward(self._model, x, input_pos)
         return logits
 
     def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception('unimplemented')
+        raise Exception("unimplemented")
 
 
 @torch.no_grad()
@@ -185,8 +193,8 @@ def eval(
         except:
             pass
 
-    if 'hendrycks_test' in tasks:
-        tasks.remove('hendrycks_test')
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
         tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]
     task_dict = get_task_dict(tasks)
 
@@ -212,7 +220,7 @@ def main(args) -> None:
 
     builder_args = BuilderArgs.from_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
-
+
     checkpoint_path = args.checkpoint_path
     checkpoint_dir = args.checkpoint_dir
     params_path = args.params_path
@@ -223,12 +231,12 @@ def main(args) -> None:
     pte_path = args.pte_path
     quantize = args.quantize
     device = args.device
-    model_dtype = args.dtype
+    model_dtype = args.dtype
     tasks = args.tasks
     limit = args.limit
     max_seq_length = args.max_seq_length
     use_tiktoken = args.tiktoken
-
+
     print(f"Using device={device}")
     set_precision(buildeer_args.precision)
@@ -240,9 +248,13 @@ def main(args) -> None:
     )
 
     if compile:
-        assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model"
+        assert not (
+            builder_args.dso_path or builder_args.pte_path
+        ), "cannot compile exported model"
         global model_forward
-        model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
+        model_forward = torch.compile(
+            model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
+        )
         torch._inductor.config.coordinate_descent_tuning = True
 
     t1 = time.time()
@@ -268,11 +280,10 @@ def main(args) -> None:
     for task, res in result["results"].items():
         print(f"{task}: {res}")
 
-if __name__ == '__main__':
-    def cli():
-        args = cli_args()
-        main(args)
-
+
 if __name__ == "__main__":
-    cli()
+    parser = argparse.ArgumentParser(description="Export specific CLI.")
+    add_arguments_for_eval(parser)
+    args = parser.parse_args()
+    args = arg_init(args)
+    main(args)
diff --git a/export.py b/export.py
index 1e5fb5d37..3b02a89ae 100644
--- a/export.py
+++ b/export.py
@@ -4,16 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import time
+import argparse
 import os
-from pathlib import Path
 
 import torch
-import torch.nn as nn
-from torch.export import Dim, export
+from cli import add_arguments_for_export, arg_init, check_args
 
-from quantize import quantize_model, name_to_dtype, set_precision, get_precision
-from cli import cli_args
+from quantize import set_precision
 
 try:
     executorch_export_available = True
@@ -22,14 +19,9 @@
     executorch_exception = f"ET EXPORT EXCEPTION: {e}"
     executorch_export_available = False
 
+from build.builder import _initialize_model, BuilderArgs
 from export_aoti import export_model as export_model_aoti
 
-from build.model import Transformer
-from build.builder import _initialize_model, BuilderArgs, TokenizerArgs
-from generate import decode_one_token
-from quantize import quantize_model, name_to_dtype
-from torch._export import capture_pre_autograd_graph
-
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -42,10 +34,8 @@ def device_sync(device):
         print(f"device={device} is not yet suppported")
 
 
-
 def main(args):
     builder_args = BuilderArgs.from_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
     quantize = args.quantize
 
     print(f"Using device={builder_args.device}")
@@ -70,7 +60,9 @@ def main(args):
             print(f"Exporting model using Executorch to {output_pte_path}")
             export_model_et(model, builder_args.device, args.output_pte_path, args)
         else:
-            print(f"Export with executorch requested but Executorch could not be loaded")
+            print(
+                f"Export with executorch requested but Executorch could not be loaded"
+            )
             print(executorch_exception)
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
@@ -78,9 +70,10 @@ def main(args):
     export_model_aoti(model, builder_args.device, output_dso_path, args)
 
 
-def cli():
-    args = cli_args()
-    main(args)
-
 if __name__ == "__main__":
-    cli()
+    parser = argparse.ArgumentParser(description="Export specific CLI.")
+    add_arguments_for_export(parser)
+    args = parser.parse_args()
+    check_args(args, "export")
+    args = arg_init(args)
+    main(args)
diff --git a/generate.py b/generate.py
index ad6085582..68a393232 100644
--- a/generate.py
+++ b/generate.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import itertools
 import sys
 import os
@@ -15,37 +16,43 @@
 import torch._dynamo.config
 import torch._inductor.config
 
-from build.builder import _load_model, _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
+from build.builder import (
+    _initialize_model,
+    _initialize_tokenizer,
+    BuilderArgs,
+    TokenizerArgs,
+)
 from build.model import Transformer
-from quantize import quantize_model, name_to_dtype, set_precision, get_precision
-from cli import cli_args
+from quantize import set_precision
+from cli import add_arguments_for_generate, arg_init, check_args
+
 
 @dataclass
 class GeneratorArgs:
     prompt: str = "torchat is pronounced torch-chat and is so cool because"
-    chat: bool = False,
-    gui: bool = False,
-    num_samples: int =1,
-    max_new_tokens: int = 200,
-    top_k: int = 200,
-    temperature: int = 0, # deterministic argmax
-    compile: bool = False,
-    compile_prefill: bool = False,
-    speculate_k: int = 5,
+    chat: bool = (False,)
+    gui: bool = (False,)
+    num_samples: int = (1,)
+    max_new_tokens: int = (200,)
+    top_k: int = (200,)
+    temperature: int = (0,)  # deterministic argmax
+    compile: bool = (False,)
+    compile_prefill: bool = (False,)
+    speculate_k: int = (5,)
 
     @classmethod
-    def from_args(cls, args): # -> GeneratorArgs:
+    def from_args(cls, args):  # -> GeneratorArgs:
         return cls(
-            prompt = args.prompt,
-            chat = args.chat,
-            gui = args.gui,
-            num_samples = args.num_samples,
-            max_new_tokens = args.max_new_tokens,
-            top_k = args.top_k,
-            temperature = args.temperature,
-            compile = args.compile,
-            compile_prefill = args.compile_prefill,
-            speculate_k = args.speculate_k,
+            prompt=args.prompt,
+            chat=args.chat,
+            gui=args.gui,
+            num_samples=args.num_samples,
+            max_new_tokens=args.max_new_tokens,
+            top_k=args.top_k,
+            temperature=args.temperature,
+            compile=args.compile,
+            compile_prefill=args.compile_prefill,
+            speculate_k=args.speculate_k,
         )
 
@@ -152,6 +159,7 @@ def decode_n_tokens(
 #     except:
 #         print("compiled model load not successful, running eager model")
 
+
 def model_forward(model, x, input_pos):
     return model(x, input_pos)
 
@@ -336,20 +344,20 @@ def _main(
     is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
     if is_chat:
-        raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!")
+        raise RuntimeError(
+            "need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!"
+        )
 
     tokenizer = _initialize_tokenizer(tokenizer_args)
 
     builder_args.setup_caches = False
-    model = _initialize_model(
-        builder_args,
-        quantize
-    )
+    model = _initialize_model(builder_args, quantize)
 
     # will add a version of _initialize_model in future
     # (need additional args)
     if is_speculative:
         from builder import _load_model
+
         speculative_builder_args = builder_args
 
         draft_model = _load_model(
@@ -359,7 +367,7 @@ def _main(
         draft_model = None
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device)
-    print (encoded)
+    print(encoded)
     prompt_length = encoded.size(0)
 
     model_size = sum(
         [
@@ -369,7 +377,9 @@ def _main(
         ]
     )
     if compile:
-        if is_speculative and builder_args.use_tp: # and ("cuda" in builder_args.device):
+        if (
+            is_speculative and builder_args.use_tp
+        ):  # and ("cuda" in builder_args.device):
             torch._inductor.config.triton.cudagraph_trees = (
                 False  # Bug with cudagraph trees in this case
             )
@@ -401,7 +411,9 @@ def _main(
             prompt = input("What is your prompt? ")
             if is_chat:
                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-            encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device)
+            encoded = encode_tokens(
+                tokenizer, prompt, bos=True, device=builder_args.device
+            )
 
         if chat_mode and i >= 0:
             buffer = []
@@ -503,10 +515,11 @@ def main(args):
         args.quantize,
     )
 
-def cli():
-    args = cli_args()
-    main(args)
-
 if __name__ == "__main__":
-    cli()
+    parser = argparse.ArgumentParser(description="Generate specific CLI.")
+    add_arguments_for_generate(parser)
+    args = parser.parse_args()
+    check_args(args, "generate")
+    args = arg_init(args)
+    main(args)
diff --git a/torchat.py b/torchat.py
index 4b720b8dd..504cd8a17 100644
--- a/torchat.py
+++ b/torchat.py
@@ -4,34 +4,52 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import time
-import os
-from pathlib import Path
+import argparse
 
-import torch
-import torch.nn as nn
-from torch.export import Dim, export
+from cli import (
+    add_arguments_for_eval,
+    add_arguments_for_export,
+    add_arguments_for_generate,
+    arg_init,
+    check_args,
+)
+from eval import main as eval_main
 from export import main as export_main
 from generate import main as generate_main
-from eval import main as eval_main
-from cli import cli_args, check_args
 
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
 
+
 def cli():
-    args = cli_args()
-
-    if args.generate or args.chat:
+    parser = argparse.ArgumentParser(description="Top-level command")
+    subparsers = parser.add_subparsers(
+        dest="subcommand", help="Subcommands include generate, eval, export"
+    )
+
+    parser_generate = subparsers.add_parser("generate")
+    add_arguments_for_generate(parser_generate)
+
+    parser_eval = subparsers.add_parser("eval")
+    add_arguments_for_eval(parser_eval)
+
+    parser_export = subparsers.add_parser("export")
+    add_arguments_for_export(parser_export)
+
+    args = parser.parse_args()
+    args = arg_init(args)
+
+    if args.subcommand == "generate":
         check_args(args, "generate")
         generate_main(args)
-    elif args.eval:
+    elif args.subcommand == "eval":
         eval_main(args)
-    elif args.export:
+    elif args.subcommand == "export":
         check_args(args, "export")
         export_main(args)
     else:
-        raise RuntimeError("must specify either --generate or --export")
-
+        raise RuntimeError("Must specify valid subcommands: generate, export, eval")
+
+
 if __name__ == "__main__":
     cli()
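
A minimal sketch, not part of the patch above, of how the shared argument helpers that cli.py now exposes might be reused from a standalone script. It assumes the script sits at the repository root so `cli` is importable and that flag defaults are acceptable as parsed; the helper names match those introduced in the diff.

```python
# Hypothetical standalone driver that reuses the new cli.py helpers.
import argparse

from cli import add_arguments_for_generate, arg_init, check_args

# Build a parser carrying the same flags as the "generate" subcommand.
parser = argparse.ArgumentParser(description="Standalone generate CLI (sketch).")
add_arguments_for_generate(parser)

args = parser.parse_args()
check_args(args, "generate")  # validate args for this subcommand, per check_args in cli.py
args = arg_init(args)         # seeds the RNG and expands --quantize when it names a JSON file
print(args)
```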