diff --git a/README.md b/README.md index fbbb25d..fd83854 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,43 @@ # GPT2-Haliax The goal of this project is to write a clean implementation of GPT 2 in Haliax. +## Requirements +This project requires Python 3.10 or newer. It uses JAX, Equinox, and Haliax for running the neural network, +and Pytree2Safetensors for loading the weights. + +## Installation +First, clone this repository: +```sh +git clone git@github.com:cooljoseph1/gpt2-haliax.git +``` + +Next, go to the newly created directory: +```sh +cd gpt2-haliax +``` + +Then, I recommend setting up a virtual environment. If you have conda installed (either Miniconda or Anaconda), you +can do this with +```sh +conda create -n gpt2-haliax python=3.10 +``` +Your Python version needs to be at least Python 3.10. + +Finally, install the requirements: +```sh +pip install -r requirements.txt +``` + +## Running +To do inference, run the command +```sh +python3 infer.py --prompt "<prompt>" +``` +where `<prompt>` is your text prompt. There are more options, which can be printed out using the `--help` flag. + +You don't have to provide a prompt; if you don't provide a prompt, it will instead read standard input for the prompt. + ## TODO -- [x] get inference working -- [x] speed up inference (right now it is recomputing attention for everything every time--instead, it should only compute it for the next token) -- [x] Added `jax.jit` to make some things faster. -- [x] clean up positional axis logic in `run/infer.py`` -- [x] Add dropout layers where appropriate +- [ ] Get inference to go longer than 1024 tokens. (GPT2 was only trained with 1024 positional embeddings. This might not be possible to do efficiently.) - [ ] Add training in a train/ folder (right now it has inference in a run/ folder) -- [ ] Figure out a better way to load safetensors/do serialization? See `run/load_gpt2` and `run/gpt2_skeleton` for how it is currently done. 
\ No newline at end of file
+- [ ] Figure out a better way to load safetensors? Right now I'm using Pytree2Safetensors, which is a not-very-polished library I made in a few hours.
\ No newline at end of file
diff --git a/infer.py b/infer.py
new file mode 100755
index 0000000..33c492f
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+from run.main import main
+
+main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3b97936
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+jax
+equinox
+haliax
+pytree2safetensors
\ No newline at end of file
diff --git a/run/__main__.py b/run/__main__.py
index 65beea1..7186d05 100644
--- a/run/__main__.py
+++ b/run/__main__.py
@@ -1,11 +1,3 @@
-import sys
-import jax
-from .infer import infer_bytes
+from .main import main
 
-for token_bytes in infer_bytes(
-    "The sun is the center of the solar system.",
-    num_tokens=-1,
-    key=jax.random.key(5)
-):
-    sys.stdout.buffer.write(token_bytes)
-    sys.stdout.flush()
\ No newline at end of file
+main()
\ No newline at end of file
diff --git a/run/main.py b/run/main.py
new file mode 100644
index 0000000..328e21b
--- /dev/null
+++ b/run/main.py
@@ -0,0 +1,53 @@
+import os
+import sys
+import argparse
+
+def main():
+    """Parse the command line and stream GPT2 output to stdout.
+
+    The prompt comes from --prompt or, when that is empty, from standard input.
+    If both are empty, the help text goes to stderr and the process exits with status 2.
+    """
+    parser = argparse.ArgumentParser(
+        description="""
+        Run GPT2 inference using JAX and Haliax.
+        It reads a prompt from stdin (unless --prompt is supplied) and prints the output to stdout.
+        """
+    )
+    parser.add_argument("--prompt", type=str, default="", help="The prompt. If not provided, the program will read from stdin")
+    parser.add_argument("--num-tokens", type=int, default=-1, help="The number of tokens to generate. Use -1 for an infinite number of tokens")
+    parser.add_argument("--seed", type=int, default=0, help="The seed to the random number generator")
+
+    args = parser.parse_args()
+
+    prompt, num_tokens, seed = args.prompt, args.num_tokens, args.seed
+
+    if not prompt:
+        prompt = sys.stdin.read()
+        sys.stdin.close() # Not needed anymore--can close the pipe
+
+    if not prompt: # If no prompt is given, either as an arg or from stdin, print out the help and exit with an error
+        parser.print_help(file=sys.stderr) # Print to stderr because it is an *error* to not provide a prompt
+        sys.exit(2)
+
+    # These imports sit below the argument handling, so --help and the missing-prompt exit never reach them
+    import jax
+    from .infer import infer_bytes
+
+    try:
+        for token_bytes in infer_bytes(
+            prompt,
+            num_tokens=num_tokens,
+            key=jax.random.key(seed)
+        ):
+            sys.stdout.buffer.write(token_bytes)
+            sys.stdout.flush()
+    except KeyboardInterrupt:
+        # With the default --num-tokens=-1 the loop has no end of its own, so Ctrl-C is a normal way to stop
+        sys.exit(130)
+    except BrokenPipeError:
+        # Whatever was reading stdout has gone away (e.g. `| head`). Point stdout at devnull so the
+        # interpreter's flush at exit does not raise a second BrokenPipeError.
+        devnull = os.open(os.devnull, os.O_WRONLY)
+        os.dup2(devnull, sys.stdout.fileno())
+        sys.exit(1)
\ No newline at end of file