Updated README, created an inference script with argument parsing
1 parent 83f889f · commit 1baa477
Showing 5 changed files with 86 additions and 16 deletions.
@@ -1,11 +1,43 @@
# GPT2-Haliax
The goal of this project is to write a clean implementation of GPT-2 in Haliax.

## Requirements
This project requires Python 3.10 or newer. It uses JAX, Equinox, and Haliax for running the neural network,
and Pytree2Safetensors for loading the weights.

## Installation
First, clone this repository:
```sh
git clone git@github.com:cooljoseph1/gpt2-haliax.git
```

Next, go to the newly created directory:
```sh
cd gpt2-haliax
```

Then, I recommend setting up a virtual environment. If you have conda installed (either Miniconda or Anaconda), you
can do this with
```sh
conda create -n gpt2-haliax python=3.10
```
Your Python version needs to be at least Python 3.10.
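
If you created the conda environment above, activate it before installing anything (standard conda usage; the name matches the environment created in the previous step):
```sh
conda activate gpt2-haliax
```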

Finally, install the requirements:
```sh
pip install -r requirements.txt
```

## Running
To do inference, run the command
```sh
python3 inference.py --prompt "<prompt>"
```
where `<prompt>` is your text prompt. There are more options, which can be printed out using the `--help` flag.

You don't have to provide a prompt; if you omit it, the script will instead read the prompt from standard input.
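
For example (the `--num-tokens` and `--seed` flags come from the argument parser added in this commit):
```sh
# Pipe the prompt through stdin instead of passing --prompt,
# generate 50 tokens, and fix the random seed
echo "The sun is the center of the solar system." | python3 inference.py --num-tokens 50 --seed 0
```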

## TODO
- [x] Get inference working
- [x] Speed up inference (right now it is recomputing attention for everything every time--instead, it should only compute it for the next token; a generic sketch of this idea follows this list)
- [x] Add `jax.jit` to make some things faster.
- [x] Clean up positional axis logic in `run/infer.py`
- [x] Add dropout layers where appropriate
- [ ] Get inference to go longer than 1024 tokens. (GPT-2 was only trained with 1024 positional embeddings. This might not be possible to do efficiently.)
- [ ] Add training in a train/ folder (right now it has inference in a run/ folder)
- [ ] Figure out a better way to load safetensors/do serialization? See `run/load_gpt2` and `run/gpt2_skeleton` for how it is currently done.
- [ ] Figure out a better way to load safetensors? Right now I'm using Pytree2Safetensors, which is a not-very-polished library I made in a few hours.
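
As a rough, generic sketch of the incremental-attention idea in the second item (not this repository's actual implementation; the function and argument names are made up for illustration):
```python
import jax
import jax.numpy as jnp

def attend_next_token(q_new, k_new, v_new, k_cache, v_cache):
    """Single-token attention against cached keys/values.

    Rather than recomputing attention over the whole sequence on every step,
    earlier tokens' keys and values are kept in a cache and only the new
    token's query attends over them.
    """
    k_cache = jnp.concatenate([k_cache, k_new[None, :]], axis=0)  # (seq+1, d)
    v_cache = jnp.concatenate([v_cache, v_new[None, :]], axis=0)  # (seq+1, d)
    scores = k_cache @ q_new / jnp.sqrt(q_new.shape[-1])          # (seq+1,)
    weights = jax.nn.softmax(scores)
    return weights @ v_cache, (k_cache, v_cache)                  # output, updated cache
```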
@@ -0,0 +1,4 @@
#!/usr/bin/env python3
from run.main import main

main()
@@ -0,0 +1,4 @@
jax
equinox
haliax
pytree2safetensors
@@ -1,11 +1,3 @@
-import sys
-import jax
-from .infer import infer_bytes
+from .main import main

-for token_bytes in infer_bytes(
-    "The sun is the center of the solar system.",
-    num_tokens=-1,
-    key=jax.random.key(5)
-):
-    sys.stdout.buffer.write(token_bytes)
-    sys.stdout.flush()
+main()
@@ -0,0 +1,38 @@
import sys
import argparse


def main():
    parser = argparse.ArgumentParser(
        description="""
        Run GPT2 inference using JAX and Haliax.
        It reads a prompt from stdin (unless --prompt is supplied) and prints the output to stdout.
        """
    )
    parser.add_argument("--prompt", type=str, default="", help="The prompt. If not provided, the program will read from stdin")
    parser.add_argument("--num-tokens", type=int, default=-1, help="The number of tokens to generate. Use -1 for an infinite number of tokens")
    parser.add_argument("--seed", type=int, default=0, help="The seed to the random number generator")

    args = parser.parse_args()

    prompt, num_tokens, seed = args.prompt, args.num_tokens, args.seed

    if prompt == "":
        prompt = sys.stdin.read()
        sys.stdin.close()  # Not needed anymore--can close the pipe

    if prompt == "":  # If no prompt is given, either as an arg or from stdin, print out the help and exit with an error
        parser.print_help(file=sys.stderr)  # Print to stderr because it is an *error* to not provide a prompt
        sys.exit(2)

    import jax
    from .infer import infer_bytes

    for token_bytes in infer_bytes(
        prompt,
        num_tokens=num_tokens,
        key=jax.random.key(seed)
    ):
        sys.stdout.buffer.write(token_bytes)
        sys.stdout.flush()
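
A quick way to exercise the no-prompt error path from the shell (assuming the `inference.py` entry script added earlier in this commit dispatches to this `main()`):
```sh
# With no --prompt and an empty stdin, the help text goes to stderr
# and the process exits with status 2 (from sys.exit(2) above)
python3 inference.py < /dev/null; echo "exit status: $?"
```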