Updated README, created an inference script with argument parsing
cooljoseph1 committed Aug 28, 2024
1 parent 83f889f commit 1baa477
Showing 5 changed files with 86 additions and 16 deletions.
44 changes: 38 additions & 6 deletions README.md
@@ -1,11 +1,43 @@
# GPT2-Haliax
The goal of this project is to write a clean implementation of GPT-2 in Haliax.

## Requirements
This project requires Python 3.10 or newer. It uses JAX, Equinox, and Haliax for running the neural network,
and Pytree2Safetensors for loading the weights.

## Installation
First, clone this repository:
```sh
git clone git@github.com:cooljoseph1/gpt2-haliax.git
```

Next, go to the newly created directory:
```sh
cd gpt2-haliax
```

Then, I recommend setting up a virtual environment. If you have conda installed (either Miniconda or Anaconda), you
can do this with
```sh
conda create -n gpt2-haliax python=3.10
```
Whichever environment you use, the Python version needs to be at least 3.10.
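
If you go the conda route, remember to activate the environment before installing or running anything (standard conda usage, nothing specific to this repository):
```sh
conda activate gpt2-haliax
```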

Finally, install the requirements:
```sh
pip install -r requirements.txt
```

## Running
To do inference, run the command
```sh
python3 infer.py --prompt "<prompt>"
```
where `<prompt>` is your text prompt. There are more options, which you can list with the `--help` flag.

You don't have to pass `--prompt`; if you omit it, the script reads the prompt from standard input instead.
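
For example (the prompt text below is just a sample; the flags are the ones defined in `run/main.py`):
```sh
# Prompt on the command line, 50 tokens, fixed seed:
python3 infer.py --prompt "The sun is the center of the solar system." --num-tokens 50 --seed 42

# Prompt piped in on standard input:
echo "The sun is the center of the solar system." | python3 infer.py --num-tokens 50

# List every option:
python3 infer.py --help
```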

## TODO
- [x] get inference working
- [x] speed up inference (right now it is recomputing attention for everything every time--instead, it should only compute it for the next token; see the sketch after this list)
- [x] Add `jax.jit` to make some things faster
- [x] clean up positional axis logic in `run/infer.py`
- [x] Add dropout layers where appropriate
- [ ] Get inference to go longer than 1024 tokens. (GPT2 was only trained with 1024 positional embeddings. This might not be possible to do efficiently.)
- [ ] Add training in a train/ folder (right now inference lives in the run/ folder)
- [ ] Figure out a better way to load safetensors/do serialization? Right now I'm using Pytree2Safetensors, a not-very-polished library I made in a few hours; see `run/load_gpt2` and `run/gpt2_skeleton` for how it is currently done.
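
For the attention-caching item above, here is a minimal single-head sketch of the idea in plain JAX. The function name, shapes, and single-head simplification are made up for illustration; this is not the actual code in `run/infer.py`:
```python
import jax
import jax.numpy as jnp

def attend_next_token(q_new, k_new, v_new, k_cache, v_cache):
    """Attend with only the newest query, reusing cached keys/values.

    q_new, k_new, v_new: (d,) vectors for the token currently being generated.
    k_cache, v_cache:    (t, d) arrays for all previously seen tokens.
    """
    # Grow the cache by one row instead of recomputing K and V for every position.
    k_cache = jnp.concatenate([k_cache, k_new[None, :]], axis=0)
    v_cache = jnp.concatenate([v_cache, v_new[None, :]], axis=0)
    # Attention scores for the new token against everything seen so far.
    scores = (k_cache @ q_new) / jnp.sqrt(q_new.shape[-1])
    weights = jax.nn.softmax(scores)
    # Weighted sum of cached values is the attention output for the new token.
    return weights @ v_cache, k_cache, v_cache
```
Under `jax.jit`, a cache that grows each call would trigger recompilation, so a real implementation would preallocate a fixed-size cache and write into one slot per step; that detail is omitted here.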
4 changes: 4 additions & 0 deletions infer.py
@@ -0,0 +1,4 @@
#!/usr/bin/env python3
from run.main import main

if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
jax
equinox
haliax
pytree2safetensors
12 changes: 2 additions & 10 deletions run/__main__.py
@@ -1,11 +1,3 @@
-import sys
-import jax
-from .infer import infer_bytes
+from .main import main

-for token_bytes in infer_bytes(
-    "The sun is the center of the solar system.",
-    num_tokens=-1,
-    key=jax.random.key(5)
-):
-    sys.stdout.buffer.write(token_bytes)
-    sys.stdout.flush()
+main()
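
With `run/__main__.py` reduced to delegating to `main()`, the package can also be run as a module and goes through the same argument parser as `infer.py`. A quick sanity check (assuming the repository root is the working directory) might look like:
```sh
python3 -m run --prompt "The sun is the center of the solar system." --num-tokens 20
```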
38 changes: 38 additions & 0 deletions run/main.py
@@ -0,0 +1,38 @@
import sys
import argparse

def main():
    parser = argparse.ArgumentParser(
        description="""
        Run GPT2 inference using JAX and Haliax.
        It reads a prompt from stdin (unless --prompt is supplied) and prints the output to stdout.
        """
    )
    parser.add_argument("--prompt", type=str, default="", help="The prompt. If not provided, the program will read from stdin")
    parser.add_argument("--num-tokens", type=int, default=-1, help="The number of tokens to generate. Use -1 for an infinite number of tokens")
    parser.add_argument("--seed", type=int, default=0, help="The seed to the random number generator")

    args = parser.parse_args()

    prompt, num_tokens, seed = args.prompt, args.num_tokens, args.seed

    if prompt == "":
        prompt = sys.stdin.read()
        sys.stdin.close() # Not needed anymore--can close the pipe


    if prompt == "": # If no prompt is given, either as an arg or from stdin, print out the help and exit with an error
        parser.print_help(file=sys.stderr) # Print to stderr because it is an *error* to not provide a prompt
        sys.exit(2)


    import jax
    from .infer import infer_bytes

    for token_bytes in infer_bytes(
        prompt,
        num_tokens=num_tokens,
        key=jax.random.key(seed)
    ):
        sys.stdout.buffer.write(token_bytes)
        sys.stdout.flush()
