Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable lintrunner across the project #216

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Top-level lintrunner setting: the revision to merge-base against.
# NOTE(review): presumably this makes lintrunner lint only files changed
# relative to the merge base with origin/main by default — confirm against
# the lintrunner configuration docs.
merge_base_with = "origin/main"

# flake8, registered under the linter code 'FLAKE8' and applied to every
# file matching **/*.py.
[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
# Runs the flake8_linter adapter through `python -m lintrunner_adapters run`.
# '@{{PATHSFILE}}' is a template placeholder; presumably lintrunner replaces
# it with a file listing the paths to lint — confirm against lintrunner docs.
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'flake8_linter',
'--',
'@{{PATHSFILE}}'
]
# Setup step: runs the pip_init adapter against requirements-lintrunner.txt.
# '{{DRYRUN}}' is a template placeholder passed to --dry-run; presumably
# filled in by lintrunner at init time — confirm.
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt',
]

# Black + usort, registered under the linter code 'UFMT' and applied to
# Python sources (**/*.py) and stub files (**/*.pyi).
[[linter]]
code = 'UFMT'
include_patterns = [
'**/*.py',
'**/*.pyi',
]
# Runs the ufmt_linter adapter through `python -m lintrunner_adapters run`.
# '@{{PATHSFILE}}' is a template placeholder; presumably lintrunner replaces
# it with a file listing the paths to process — confirm against lintrunner docs.
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'ufmt_linter',
'--',
'@{{PATHSFILE}}'
]
# Setup step: runs the pip_init adapter against requirements-lintrunner.txt,
# additionally passing --no-black-binary.
# NOTE(review): the flag name suggests it avoids a prebuilt black binary —
# confirm against the pip_init adapter's help text.
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'--requirement=requirements-lintrunner.txt',
]
# Marks this entry as a formatter rather than a plain linter.
# NOTE(review): presumably this is what lets lintrunner's format-only mode
# select it — confirm against lintrunner docs.
is_formatter = true
119 changes: 66 additions & 53 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,19 @@
import itertools
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config

from quantize import (
quantize_model, name_to_dtype, set_precision, get_precision
)
from cli import cli_args
from dataclasses import dataclass
from typing import Union, Optional

from quantize import get_precision, name_to_dtype, quantize_model, set_precision

from sentencepiece import SentencePieceProcessor

from build.model import Transformer


Expand All @@ -40,43 +38,50 @@ class BuilderArgs:

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file()) or
(self.checkpoint_dir and self.checkpoint_path.is_dir()) or
(self.gguf_path and self.gguf_path.is_file()) or
(self.dso_path and Path(self.dso_path).is_file()) or
(self.pte_path and Path(self.pte_path).is_file())
(self.checkpoint_path and self.checkpoint_path.is_file())
or (self.checkpoint_dir and self.checkpoint_path.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if (self.dso_path and self.pte_path):
if self.dso_path and self.pte_path:
raise RuntimeError("specify either DSO path or PTE path, but not both")

if (self.checkpoint_path and (self.dso_path or self.pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")

if self.checkpoint_path and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
)
if self.checkpoint_dir and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
)
if self.gguf_path and (self.dso_path or self.pte_path):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
)

@classmethod
def from_args(cls, args): # -> BuilderArgs:
def from_args(cls, args): # -> BuilderArgs:
return cls(
checkpoint_path = args.checkpoint_path,
checkpoint_dir = args.checkpoint_dir,
params_path = args.params_path,
params_table = args.params_table,
gguf_path = args.gguf_path,
dso_path = args.dso_path,
pte_path = args.pte_path,
device = args.device,
precision = name_to_dtype(args.dtype),
setup_caches = (args.output_dso_path or args.output_pte_path),
use_tp = False,
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
dso_path=args.dso_path,
pte_path=args.pte_path,
device=args.device,
precision=name_to_dtype(args.dtype),
setup_caches=(args.output_dso_path or args.output_pte_path),
use_tp=False,
)

@classmethod
def from_speculative_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
Expand All @@ -94,7 +99,7 @@ class TokenizerArgs:
is_TikToken: bool = False

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
def from_args(cls, args): # -> TokenizerArgs:
is_SentencePiece = True
is_TikToken = False

Expand All @@ -108,7 +113,7 @@ def from_args(cls, args): # -> TokenizerArgs:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

if args.tiktoken:
is_SentencePiece = False
Expand All @@ -117,9 +122,10 @@ def from_args(cls, args): # -> TokenizerArgs:
return cls(
tokenizer_path=tokenizer_path,
is_SentencePiece=is_SentencePiece,
is_TikToken=is_TikToken
is_TikToken=is_TikToken,
)


def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
if tokenizer_args.is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
Expand Down Expand Up @@ -147,6 +153,7 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


def _load_model(builder_args):
if builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)
Expand All @@ -160,9 +167,8 @@ def _load_model(builder_args):
else:
return _load_model_not_gguf(builder_args)

def _load_model_not_gguf(
builder_args
):

def _load_model_not_gguf(builder_args):
assert not builder_args.gguf_path

with torch.device("meta"):
Expand Down Expand Up @@ -200,7 +206,12 @@ def _load_model_not_gguf(
else:
checkpoint[key] = cps[0][key]
else:
checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)
checkpoint = torch.load(
builder_args.checkpoint_path,
map_location=builder_args.device,
mmap=True,
weights_only=True,
)

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]
Expand All @@ -218,21 +229,21 @@ def _load_model_not_gguf(


def _initialize_model(
builder_args,
quantize,
builder_args,
quantize,
):
print("Loading model ...")
t0 = time.time()
model_ = _load_model(
builder_args
)
model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if builder_args.dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
Expand All @@ -241,15 +252,20 @@ def _initialize_model(
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
model.forward = torch._export.aot_load(
str(builder_args.dso_path.absolute()), builder_args.device
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
elif builder_args.pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from build.model_et import PTEModel

model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
Expand All @@ -265,10 +281,7 @@ def _initialize_model(
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
)
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

model.to(dtype=builder_args.precision)

Expand Down
Loading
Loading