Skip to content

Commit

Permalink
Use lintrunner across the project
Browse files Browse the repository at this point in the history
  • Loading branch information
mergennachin committed Apr 16, 2024
1 parent 34699a6 commit d1d94ee
Show file tree
Hide file tree
Showing 20 changed files with 786 additions and 515 deletions.
51 changes: 51 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
merge_base_with = "origin/main"

[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'flake8_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt',
]

# Black + usort
[[linter]]
code = 'UFMT'
include_patterns = [
'**/*.py',
'**/*.pyi',
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'ufmt_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'--requirement=requirements-lintrunner.txt',
]
is_formatter = true
119 changes: 66 additions & 53 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,19 @@
import itertools
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config

from quantize import (
quantize_model, name_to_dtype, set_precision, get_precision
)
from cli import cli_args
from dataclasses import dataclass
from typing import Union, Optional

from quantize import get_precision, name_to_dtype, quantize_model, set_precision

from sentencepiece import SentencePieceProcessor

from build.model import Transformer


Expand All @@ -40,43 +38,50 @@ class BuilderArgs:

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file()) or
(self.checkpoint_dir and self.checkpoint_path.is_dir()) or
(self.gguf_path and self.gguf_path.is_file()) or
(self.dso_path and Path(self.dso_path).is_file()) or
(self.pte_path and Path(self.pte_path).is_file())
(self.checkpoint_path and self.checkpoint_path.is_file())
or (self.checkpoint_dir and self.checkpoint_path.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if (self.dso_path and self.pte_path):
if self.dso_path and self.pte_path:
raise RuntimeError("specify either DSO path or PTE path, but not both")

if (self.checkpoint_path and (self.dso_path or self.pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")

if self.checkpoint_path and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
)
if self.checkpoint_dir and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
)
if self.gguf_path and (self.dso_path or self.pte_path):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
)

@classmethod
def from_args(cls, args): # -> BuilderArgs:
def from_args(cls, args): # -> BuilderArgs:
return cls(
checkpoint_path = args.checkpoint_path,
checkpoint_dir = args.checkpoint_dir,
params_path = args.params_path,
params_table = args.params_table,
gguf_path = args.gguf_path,
dso_path = args.dso_path,
pte_path = args.pte_path,
device = args.device,
precision = name_to_dtype(args.dtype),
setup_caches = (args.output_dso_path or args.output_pte_path),
use_tp = False,
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
dso_path=args.dso_path,
pte_path=args.pte_path,
device=args.device,
precision=name_to_dtype(args.dtype),
setup_caches=(args.output_dso_path or args.output_pte_path),
use_tp=False,
)

@classmethod
def from_speculative_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
Expand All @@ -94,7 +99,7 @@ class TokenizerArgs:
is_TikToken: bool = False

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
def from_args(cls, args): # -> TokenizerArgs:
is_SentencePiece = True
is_TikToken = False

Expand All @@ -108,7 +113,7 @@ def from_args(cls, args): # -> TokenizerArgs:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

if args.tiktoken:
is_SentencePiece = False
Expand All @@ -117,9 +122,10 @@ def from_args(cls, args): # -> TokenizerArgs:
return cls(
tokenizer_path=tokenizer_path,
is_SentencePiece=is_SentencePiece,
is_TikToken=is_TikToken
is_TikToken=is_TikToken,
)


def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
if tokenizer_args.is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
Expand Down Expand Up @@ -147,6 +153,7 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


def _load_model(builder_args):
if builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)
Expand All @@ -160,9 +167,8 @@ def _load_model(builder_args):
else:
return _load_model_not_gguf(builder_args)

def _load_model_not_gguf(
builder_args
):

def _load_model_not_gguf(builder_args):
assert not builder_args.gguf_path

with torch.device("meta"):
Expand Down Expand Up @@ -200,7 +206,12 @@ def _load_model_not_gguf(
else:
checkpoint[key] = cps[0][key]
else:
checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)
checkpoint = torch.load(
builder_args.checkpoint_path,
map_location=builder_args.device,
mmap=True,
weights_only=True,
)

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]
Expand All @@ -218,21 +229,21 @@ def _load_model_not_gguf(


def _initialize_model(
builder_args,
quantize,
builder_args,
quantize,
):
print("Loading model ...")
t0 = time.time()
model_ = _load_model(
builder_args
)
model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if builder_args.dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
Expand All @@ -241,15 +252,20 @@ def _initialize_model(
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
model.forward = torch._export.aot_load(
str(builder_args.dso_path.absolute()), builder_args.device
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
elif builder_args.pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from build.model_et import PTEModel

model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
Expand All @@ -265,10 +281,7 @@ def _initialize_model(
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
)
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

model.to(dtype=builder_args.precision)

Expand Down
Loading

0 comments on commit d1d94ee

Please sign in to comment.