Skip to content

Commit

Permalink
Create top-level torchat.py CLI binary
Browse files Browse the repository at this point in the history
  • Loading branch information
mergennachin committed Apr 16, 2024
1 parent d5cf1c8 commit b5b4c2f
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 121 deletions.
16 changes: 8 additions & 8 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools

import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Optional, Union

import torch
import torch._dynamo.config
import torch._inductor.config
from cli import cli_args

from quantize import get_precision, name_to_dtype, quantize_model, set_precision
from quantize import name_to_dtype, quantize_model

from sentencepiece import SentencePieceProcessor

Expand Down Expand Up @@ -110,7 +110,7 @@ def from_args(cls, args): # -> TokenizerArgs:
elif args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")
raise RuntimeError("cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
Expand Down Expand Up @@ -243,7 +243,7 @@ def _initialize_model(
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported DSO model. Specify quantization during export."
), "quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
Expand All @@ -262,12 +262,12 @@ def _initialize_model(
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported PTE model. Specify quantization during export."
), "quantize not valid for exported PTE model. Specify quantization during export."
try:
from build.model_et import PTEModel

model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
else:
model = model_
Expand Down
42 changes: 17 additions & 25 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import json
from pathlib import Path

import torch
import torch.nn as nn


default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

Expand Down Expand Up @@ -41,11 +40,19 @@ def check_args(args, command_name: str):
print(f"Warning: {text}")


def cli_args():
import argparse
def add_arguments_for_generate(parser):
_add_arguments_common(parser)


def add_arguments_for_eval(parser):
_add_arguments_common(parser)


def add_arguments_for_export(parser):
_add_arguments_common(parser)

parser = argparse.ArgumentParser(description="Your CLI description.")

def _add_arguments_common(parser):
parser.add_argument(
"--seed",
type=int,
Expand All @@ -60,25 +67,10 @@ def cli_args():
action="store_true",
help="Whether to use tiktoken tokenizer.",
)
parser.add_argument(
"--export",
action="store_true",
help="Use torchchat to export a model.",
)
parser.add_argument(
"--eval",
action="store_true",
help="Use torchchat to eval a model.",
)
parser.add_argument(
"--generate",
action="store_true",
help="Use torchchat to generate a sequence using a model.",
)
parser.add_argument(
"--chat",
action="store_true",
help="Use torchchat to for an interactive chat session.",
help="Use torchat to for an interactive chat session.",
)
parser.add_argument(
"--gui",
Expand Down Expand Up @@ -162,10 +154,10 @@ def cli_args():
parser.add_argument(
"--quantize", type=str, default="{ }", help="Quantization options."
)
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
parser.add_argument(
"--device", type=str, default=default_device, help="Device to use"
)
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
parser.add_argument(
"--tasks",
nargs="+",
Expand All @@ -183,13 +175,13 @@ def cli_args():
help="maximum length sequence to evaluate",
)

args = parser.parse_args()

def arg_init(args):

if Path(args.quantize).is_file():
with open(args.quantize, "r") as f:
args.quantize = json.loads(f.read())

if args.seed:
torch.manual_seed(args.seed)

return args
54 changes: 21 additions & 33 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,33 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import argparse
import time
from pathlib import Path
from typing import Optional

import torch
import torch._dynamo.config
import torch._inductor.config

from build.builder import (
_initialize_model,
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)

from build.model import Transformer
from cli import add_arguments_for_eval, arg_init
from generate import encode_tokens, model_forward

from quantize import set_precision

torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

from build.model import Transformer
from cli import cli_args
from quantize import name_to_dtype, set_precision

try:
import lm_eval
Expand All @@ -29,13 +38,6 @@
except:
lm_eval_available = False

from build.builder import (
_initialize_model,
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)
from generate import encode_tokens, model_forward

if lm_eval_available:
try: # lm_eval version 0.4
Expand Down Expand Up @@ -218,30 +220,19 @@ def main(args) -> None:

builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)

checkpoint_path = args.checkpoint_path
checkpoint_dir = args.checkpoint_dir
params_path = args.params_path
params_table = args.params_table
gguf_path = args.gguf_path
tokenizer_path = args.tokenizer_path
dso_path = args.dso_path
pte_path = args.pte_path
quantize = args.quantize
device = args.device
model_dtype = args.dtype
tasks = args.tasks
limit = args.limit
max_seq_length = args.max_seq_length
use_tiktoken = args.tiktoken

print(f"Using device={device}")
set_precision(buildeer_args.precision)
set_precision(builder_args.precision)

tokenizer = _initialize_tokenizer(tokenizer_args)
builder_args.setup_caches = False
model = _initialize_model(
buildeer_args,
builder_args,
quantize,
)

Expand Down Expand Up @@ -280,11 +271,8 @@ def main(args) -> None:


if __name__ == "__main__":

def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
parser = argparse.ArgumentParser(description="Export specific CLI.")
add_arguments_for_eval(parser)
args = parser.parse_args()
args = arg_init(args)
main(args)
34 changes: 13 additions & 21 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
from cli import cli_args

from quantize import get_precision, name_to_dtype, quantize_model, set_precision
from torch.export import Dim, export
from build.builder import _initialize_model, BuilderArgs
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

from quantize import set_precision

try:
executorch_export_available = True
Expand All @@ -22,13 +22,6 @@
executorch_exception = f"ET EXPORT EXCEPTION: {e}"
executorch_export_available = False

from build.builder import _initialize_model, BuilderArgs, TokenizerArgs

from build.model import Transformer
from export_aoti import export_model as export_model_aoti
from generate import decode_one_token
from quantize import name_to_dtype, quantize_model
from torch._export import capture_pre_autograd_graph

default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

Expand All @@ -44,7 +37,6 @@ def device_sync(device):

def main(args):
builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
quantize = args.quantize

print(f"Using device={builder_args.device}")
Expand All @@ -70,7 +62,7 @@ def main(args):
export_model_et(model, builder_args.device, args.output_pte_path, args)
else:
print(
f"Export with executorch requested but Executorch could not be loaded"
"Export with executorch requested but Executorch could not be loaded"
)
print(executorch_exception)
if output_dso_path:
Expand All @@ -79,10 +71,10 @@ def main(args):
export_model_aoti(model, builder_args.device, output_dso_path, args)


def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
parser = argparse.ArgumentParser(description="Export specific CLI.")
add_arguments_for_export(parser)
args = parser.parse_args()
check_args(args, "export")
args = arg_init(args)
main(args)
23 changes: 10 additions & 13 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import itertools
import os
import sys
Expand All @@ -23,8 +24,8 @@
TokenizerArgs,
)
from build.model import Transformer
from cli import cli_args
from quantize import get_precision, name_to_dtype, quantize_model, set_precision
from cli import add_arguments_for_generate, arg_init, check_args
from quantize import set_precision


@dataclass
Expand Down Expand Up @@ -137,7 +138,7 @@ def decode_n_tokens(
**sampling_kwargs,
):
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
for _ in range(num_new_tokens):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
): # Actually better for Inductor to codegen attention here
Expand Down Expand Up @@ -356,8 +357,6 @@ def _main(
# will add a version of _initialize_model in future
# (need additional args)
if is_speculative:
from builder import _load_model

speculative_builder_args = builder_args

draft_model = _load_model(
Expand Down Expand Up @@ -496,8 +495,6 @@ def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)

_main(
builder_args,
speculative_builder_args,
Expand All @@ -516,10 +513,10 @@ def main(args):
)


def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
parser = argparse.ArgumentParser(description="Generate specific CLI.")
add_arguments_for_generate(parser)
args = parser.parse_args()
check_args(args, "generate")
args = arg_init(args)
main(args)
Loading

0 comments on commit b5b4c2f

Please sign in to comment.