make --device fast the default (pytorch#515)
* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind <[email protected]>

* Support llama3 in chat in run.cpp (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
Remove support for lm-eval 0.3 to reduce the number of options we have to support.

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind <[email protected]>

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin <[email protected]>
Co-authored-by: metascroy <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: lucylq <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
7 people committed Jul 17, 2024
1 parent 62d4041 commit 6732127
Showing 3 changed files with 18 additions and 4 deletions.
18 changes: 16 additions & 2 deletions build/utils.py
@@ -156,13 +156,27 @@ def state_dict_device(d, device="cpu") -> Dict:
#########################################################################
### move state dict to specified device ###

def is_mps_available() -> bool:
    if not torch.backends.mps.is_available():
        return False

    # Our system may report that MPS is available even though it is not
    # usable on VMs, so allocate a small MPS tensor and see if that works:
    try:
        mps_tensor = torch.zeros(1024, dtype=torch.float16, device="mps")
    except Exception:
        return False

    # The allocation succeeded, so MPS is genuinely usable.
    return True


def get_device_str(device) -> str:
    if isinstance(device, str) and device == "fast":
        return (
            "cuda"
            if torch.cuda.is_available()
-           else "mps" if torch.backends.mps.is_available() else "cpu"
+           else "mps" if is_mps_available() else "cpu"
        )
    else:
        return str(device)
@@ -173,6 +187,6 @@ def get_device(device) -> str:
        device = (
            "cuda"
            if torch.cuda.is_available()
-           else "mps" if torch.backends.mps.is_available() else "cpu"
+           else "mps" if is_mps_available() else "cpu"
        )
    return torch.device(device)
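
Taken together, a minimal usage sketch of the new "fast" resolution (assuming build/utils.py is importable as build.utils, as the cli.py import below suggests):

from build.utils import get_device, get_device_str

# "fast" resolves to the best available backend: CUDA first, then an
# MPS device that passed the VM-safe allocation check, then CPU.
print(get_device_str("fast"))   # -> "cuda", "mps", or "cpu"
print(get_device("fast"))       # -> the corresponding torch.device

# Explicit device strings pass through unchanged.
print(get_device_str("cpu"))    # -> "cpu"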
2 changes: 1 addition & 1 deletion cli.py
@@ -12,7 +12,7 @@
from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
from download import download_and_convert, is_model_downloaded

default_device = "cpu"
default_device = "fast"


# Handle CLI arguments that are common to a majority of subcommands.
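
For illustration only, a hedged sketch of how the new default might surface on the command line (the argparse wiring below is an assumption, not the repository's actual cli.py; only default_device, get_device_str, and the "fast" alias come from this diff):

import argparse

from build.utils import get_device_str

default_device = "fast"  # from the diff above

# Hypothetical flag wiring: --device now defaults to "fast", which is
# resolved to a concrete backend only when a device is actually needed.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=default_device)

args = parser.parse_args([])            # no --device given -> "fast"
print(get_device_str(args.device))      # -> "cuda", "mps", or "cpu"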
2 changes: 1 addition & 1 deletion generate.py
@@ -210,7 +210,7 @@ def decode_n_tokens(
):
    new_tokens, new_probs = [], []
    encountered_eos = False
-   for i in range(
+   for _i in range(
        num_new_tokens - 1
    ):  # -1 to save space to run an EoS if we don't generate it naturally
        # Actually better for Inductor to codegen attention here
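
For context, a hedged sketch of the pattern decode_n_tokens follows: the loop runs num_new_tokens - 1 times so one slot stays free for an explicit end-of-sequence token. The names generate_one and eos_id below are hypothetical stand-ins, not identifiers from generate.py:

def decode_n_tokens_sketch(cur_token, num_new_tokens, generate_one, eos_id):
    # Illustrative only: generate_one(token) -> next token stands in for
    # the model's single-step decode.
    new_tokens = []
    encountered_eos = False
    for _i in range(num_new_tokens - 1):  # -1 reserves room for an EoS
        cur_token = generate_one(cur_token)
        new_tokens.append(cur_token)
        if cur_token == eos_id:
            encountered_eos = True
            break
    if not encountered_eos:
        # Use the reserved final slot to end the sequence explicitly.
        new_tokens.append(eos_id)
    return new_tokens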
