Skip to content

Commit

Permalink
Create top-level torchat.py CLI binary
Browse files Browse the repository at this point in the history
python torchat.py {generate, export, eval} --foo --bar

You can also run each command's script directly:

python export.py --foo --bar

python generate.py --foo --bar

python eval.py --foo --bar
  • Loading branch information
mergennachin committed Apr 16, 2024
1 parent 848ff25 commit bfbe846
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 176 deletions.
116 changes: 64 additions & 52 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools

import sys
import time
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

import torch
import torch._dynamo.config
import torch._inductor.config

from quantize import (
quantize_model, name_to_dtype, set_precision, get_precision
)
from cli import cli_args
from quantize import quantize_model, name_to_dtype, set_precision, get_precision
from dataclasses import dataclass
from typing import Union, Optional

Expand All @@ -40,43 +37,50 @@ class BuilderArgs:

def __post_init__(self):
    """Validate the model-source paths held by this instance.

    Presumably invoked by the dataclass machinery right after ``__init__``
    (the class decorator is outside this view).

    At least one of the following must exist on disk: ``checkpoint_path``
    (file), ``checkpoint_dir`` (directory), ``gguf_path`` (file),
    ``dso_path`` (file) or ``pte_path`` (file).  When an exported DSO/PTE
    is given, any checkpoint/GGUF source that was also supplied is
    reported as ignored.

    Raises:
        RuntimeError: if no valid source is supplied, or if both
            ``dso_path`` and ``pte_path`` are set.
    """
    if not (
        (self.checkpoint_path and Path(self.checkpoint_path).is_file())
        # Check the directory itself.  The previous code called
        # ``self.checkpoint_path.is_dir()`` here, which crashed when only
        # a checkpoint dir was given and tested the wrong path otherwise.
        or (self.checkpoint_dir and Path(self.checkpoint_dir).is_dir())
        or (self.gguf_path and Path(self.gguf_path).is_file())
        or (self.dso_path and Path(self.dso_path).is_file())
        or (self.pte_path and Path(self.pte_path).is_file())
    ):
        raise RuntimeError(
            "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
        )

    if self.dso_path and self.pte_path:
        raise RuntimeError("specify either DSO path or PTE path, but not both")

    # An exported artifact takes precedence; tell the user which of the
    # other sources they passed will be ignored.
    if self.dso_path or self.pte_path:
        ignored_sources = (
            ("checkpoint path", self.checkpoint_path),
            ("checkpoint dir", self.checkpoint_dir),
            ("GGUF path", self.gguf_path),
        )
        for label, value in ignored_sources:
            if value:
                print(
                    f"Warning: {label} ignored because an exported DSO or PTE path specified"
                )

@classmethod
def from_args(cls, args):  # -> BuilderArgs:
    """Construct an instance from an argument namespace.

    ``args`` must expose the attributes read below (``checkpoint_path``,
    ``checkpoint_dir``, ``params_path``, ``params_table``, ``gguf_path``,
    ``dso_path``, ``pte_path``, ``device``, ``dtype``,
    ``output_dso_path`` and ``output_pte_path``).
    """
    # Fields that are copied over from ``args`` unchanged.
    passthrough = (
        "checkpoint_path",
        "checkpoint_dir",
        "params_path",
        "params_table",
        "gguf_path",
        "dso_path",
        "pte_path",
        "device",
    )
    fields = {name: getattr(args, name) for name in passthrough}

    # The dtype name is translated by ``name_to_dtype``.
    fields["precision"] = name_to_dtype(args.dtype)
    # Truthy exactly when an export output path (DSO or PTE) was requested.
    fields["setup_caches"] = args.output_dso_path or args.output_pte_path
    fields["use_tp"] = False
    return cls(**fields)

@classmethod
def from_speculative_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
Expand All @@ -94,7 +98,7 @@ class TokenizerArgs:
is_TikToken: bool = False

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
def from_args(cls, args): # -> TokenizerArgs:
is_SentencePiece = True
is_TikToken = False

Expand All @@ -108,7 +112,7 @@ def from_args(cls, args): # -> TokenizerArgs:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

if args.tiktoken:
is_SentencePiece = False
Expand All @@ -117,9 +121,10 @@ def from_args(cls, args): # -> TokenizerArgs:
return cls(
tokenizer_path=tokenizer_path,
is_SentencePiece=is_SentencePiece,
is_TikToken=is_TikToken
is_TikToken=is_TikToken,
)


def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
if tokenizer_args.is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
Expand Down Expand Up @@ -147,6 +152,7 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


def _load_model(builder_args):
if builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)
Expand All @@ -160,9 +166,8 @@ def _load_model(builder_args):
else:
return _load_model_not_gguf(builder_args)

def _load_model_not_gguf(
builder_args
):

def _load_model_not_gguf(builder_args):
assert not builder_args.gguf_path

with torch.device("meta"):
Expand Down Expand Up @@ -200,7 +205,12 @@ def _load_model_not_gguf(
else:
checkpoint[key] = cps[0][key]
else:
checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)
checkpoint = torch.load(
builder_args.checkpoint_path,
map_location=builder_args.device,
mmap=True,
weights_only=True,
)

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]
Expand All @@ -218,21 +228,21 @@ def _load_model_not_gguf(


def _initialize_model(
builder_args,
quantize,
builder_args,
quantize,
):
print("Loading model ...")
t0 = time.time()
model_ = _load_model(
builder_args
)
model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if builder_args.dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
Expand All @@ -241,15 +251,20 @@ def _initialize_model(
# attributes will NOT be seen by the AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
model.forward = torch._export.aot_load(
str(builder_args.dso_path.absolute()), builder_args.device
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
elif builder_args.pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from build.model_et import PTEModel

model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
Expand All @@ -265,10 +280,7 @@ def _initialize_model(
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
)
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

model.to(dtype=builder_args.precision)

Expand Down
41 changes: 17 additions & 24 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import time
import os
from pathlib import Path
Expand Down Expand Up @@ -39,12 +40,17 @@ def check_args(args, command_name: str):
else:
print(f"Warning: {text}")


def cli_args():
import argparse

parser = argparse.ArgumentParser(description="Your CLI description.")
def add_arguments_for_generate(parser):
    """Add the options for the ``generate`` command to *parser*.

    At present this is exactly the shared option set registered by
    ``_add_arguments_common``.
    """
    _add_arguments_common(parser)

def add_arguments_for_eval(parser):
    """Add the options for the ``eval`` command to *parser*.

    At present this is exactly the shared option set registered by
    ``_add_arguments_common``.
    """
    _add_arguments_common(parser)

def add_arguments_for_export(parser):
    """Add the options for the ``export`` command to *parser*.

    At present this is exactly the shared option set registered by
    ``_add_arguments_common``.
    """
    _add_arguments_common(parser)

def _add_arguments_common(parser):
parser.add_argument(
"--seed",
type=int,
Expand All @@ -59,26 +65,11 @@ def cli_args():
action="store_true",
help="Whether to use tiktoken tokenizer.",
)
parser.add_argument(
"--export",
action="store_true",
help="Use torchat to export a model.",
)
parser.add_argument(
"--eval",
action="store_true",
help="Use torchat to eval a model.",
)
parser.add_argument(
"--generate",
action="store_true",
help="Use torchat to generate a sequence using a model.",
)
parser.add_argument(
"--chat",
action="store_true",
help="Use torchat to for an interactive chat session.",
)
)
parser.add_argument(
"--gui",
action="store_true",
Expand Down Expand Up @@ -231,13 +222,15 @@ def cli_args():
default=None,
help='maximum length sequence to evaluate')

args = parser.parse_args()

def arg_init(args):
    """Finish initialising an argument namespace after parsing.

    * If ``args.quantize`` names an existing file, the file is read as
      JSON and ``args.quantize`` is replaced by the decoded value.  Any
      other value (an inline string, ``None``, an already-decoded object)
      is left untouched.
    * If ``args.seed`` is not ``None``, torch's RNG is seeded with it.

    Returns:
        The same ``args`` object, mutated in place.
    """
    quantize = args.quantize
    # Only string/path-like values can refer to a file; guarding on the
    # type avoids a TypeError from ``Path(None)`` or ``Path({...})``.
    if isinstance(quantize, (str, os.PathLike)) and Path(quantize).is_file():
        with open(quantize, "r") as f:
            args.quantize = json.load(f)

    # Compare against None rather than truthiness so that an explicit
    # seed of 0 is honoured.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    return args


Loading

0 comments on commit bfbe846

Please sign in to comment.