diff --git a/colabs/llms/jax_gpt_dev.ipynb b/colabs/llms/jax_gpt_dev.ipynb new file mode 100644 index 0000000..46fe282 --- /dev/null +++ b/colabs/llms/jax_gpt_dev.ipynb @@ -0,0 +1,3373 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NIbX_6V1ELk2" + }, + "source": [ + "## import" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "id": "a6NK_zR1EN1I" + }, + "outputs": [], + "source": [ + "from jax import device_put\n", + "import time\n", + "\n", + "import jax.numpy as jnp\n", + "from jax import lax\n", + "import jax\n", + "import jax.random as random\n", + "\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "from jax.nn.initializers import normal\n", + "import flax.linen as nn\n", + "import jax.nn as jnn\n", + "\n", + "import optax\n", + "from optax import adam" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wJpXpmjEYC_T" + }, + "source": [ + "## Building a GPT\n", + "\n", + "Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kqorRX1vF3KP" + }, + "source": [ + "### load data" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h5hjCcLDr2WC", + "outputId": "f37218fb-ea5f-4513-e676-eb56b2bedced" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2024-04-01 21:55:14-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1115394 (1.1M) [text/plain]\n", + "Saving to: ‘input.txt.2’\n", + "\n", + "\rinput.txt.2 0%[ ] 0 --.-KB/s \rinput.txt.2 100%[===================>] 1.06M --.-KB/s in 0.03s \n", + "\n", + "2024-04-01 21:55:14 (34.1 MB/s) - ‘input.txt.2’ saved [1115394/1115394]\n", + "\n" + ] + } + ], + "source": [ + "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n", + "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "O6medjfRsLD9" + }, + "outputs": [], + "source": [ + "# read it in to inspect it\n", + "with open('input.txt', 'r', encoding='utf-8') as f:\n", + " text = f.read()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6xWI_VyAsN8F", + "outputId": "ee3ef8e9-2dc5-40a4-8862-2888737fe0b5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "length of dataset in characters: 1115394\n" + ] + } + ], + "source": [ + "print(\"length of dataset in characters: \", len(text))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2c5V0FvqseE0", + "outputId": "9d8131df-1186-4920-d006-16d55cdc2f53" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You are all resolved rather to die than to famish?\n", + "\n", + "All:\n", + "Resolved. resolved.\n", + "\n", + "First Citizen:\n", + "First, you know Caius Marcius is chief enemy to the people.\n", + "\n", + "All:\n", + "We know't, we know't.\n", + "\n", + "First Citizen:\n", + "Let us kill him, and we'll have corn at our own price.\n", + "Is't a verdict?\n", + "\n", + "All:\n", + "No more talking on't; let it be done: away, away!\n", + "\n", + "Second Citizen:\n", + "One word, good citizens.\n", + "\n", + "First Citizen:\n", + "We are accounted poor citizens, the patricians good.\n", + "What authority surfeits on would relieve us: if they\n", + "would yield us but the superfluity, while it were\n", + "wholesome, we might guess they relieved us humanely;\n", + "but they think we are too dear: the leanness that\n", + "afflicts us, the object of our misery, is as an\n", + "inventory to particularise their abundance; our\n", + "sufferance is a gain to them Let us revenge this with\n", + "our pikes, ere we become rakes: for the gods know I\n", + "speak this in hunger for bread, not in thirst for revenge.\n", + "\n", + "\n" + ] + } + ], + "source": [ + "# let's look at the first 1000 characters\n", + "print(text[:1000])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0e-Rbyr8sfM8", + "outputId": "716efcf0-abc2-4b44-d9e6-6dc23415c33c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n", + "65\n" + ] + } + ], + "source": [ + "# here are all the unique characters that occur in this text\n", + "chars = sorted(list(set(text)))\n", + "vocab_size = len(chars)\n", + "print(''.join(chars))\n", + "print(vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qFaINwGqD1Bm", + "outputId": "cc555ac7-a842-48d5-8de8-67a16f5799d2" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 31 + } + ], + "source": [ + "'!' in chars" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Yw1LKNCgwjj1", + "outputId": "f3b63de3-916d-4919-c6e5-9681e5cd288c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n", + "hii there\n" + ] + } + ], + "source": [ + "# create a mapping from characters to integers\n", + "stoi = { ch:i for i,ch in enumerate(chars) }\n", + "itos = { i:ch for i,ch in enumerate(chars) }\n", + "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", + "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", + "\n", + "print(encode(\"hii there\"))\n", + "print(decode(encode(\"hii there\")))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YJb0OXPwzvqg", + "outputId": "e8a896c5-dd1f-4005-f6a8-59874c93fca6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":8: UserWarning: Explicitly requested dtype requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n", + " data = jnp.array(encode(text), dtype=jnp.int64)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(1115394,) int32\n", + "[18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43\n", + " 1 54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43\n", + " 39 56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49\n", + " 6 1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10\n", + " 0 37 53 59 1 39 56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39\n", + " 58 46 43 56 1 58 53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47\n", + " 57 46 12 0 0 13 50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53\n", + " 50 60 43 42 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47\n", + " 56 57 58 6 1 63 53 59 1 49 52 53 61 1 15 39 47 59 57 1 25 39 56 41\n", + " 47 59 57 1 47 57 1 41 46 47 43 44 1 43 52 43 51 63 1 58 53 1 58 46\n", + " 43 1 54 43 53 54 50 43 8 0 0 13 50 50 10 0 35 43 1 49 52 53 61 5\n", + " 58 6 1 61 43 1 49 52 53 61 5 58 8 0 0 18 47 56 57 58 1 15 47 58\n", + " 47 64 43 52 10 0 24 43 58 1 59 57 1 49 47 50 50 1 46 47 51 6 1 39\n", + " 52 42 1 61 43 5 50 50 1 46 39 60 43 1 41 53 56 52 1 39 58 1 53 59\n", + " 56 1 53 61 52 1 54 56 47 41 43 8 0 21 57 5 58 1 39 1 60 43 56 42\n", + " 47 41 58 12 0 0 13 50 50 10 0 26 53 1 51 53 56 43 1 58 39 50 49 47\n", + " 52 45 1 53 52 5 58 11 1 50 43 58 1 47 58 1 40 43 1 42 53 52 43 10\n", + " 1 39 61 39 63 6 1 39 61 39 63 2 0 0 31 43 41 53 52 42 1 15 47 58\n", + " 47 64 43 52 10 0 27 52 43 1 61 53 56 42 6 1 45 53 53 42 1 41 47 58\n", + " 47 64 43 52 57 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 35\n", + " 43 1 39 56 43 1 39 41 41 53 59 52 58 43 42 1 54 53 53 56 1 41 47 58\n", + " 47 64 43 52 57 6 1 58 46 43 1 54 39 58 56 47 41 47 39 52 57 1 45 53\n", + " 53 42 8 0 35 46 39 58 1 39 59 58 46 53 56 47 58 63 1 57 59 56 44 43\n", + " 47 58 57 1 53 52 1 61 53 59 50 42 1 56 43 50 47 43 60 43 1 59 57 10\n", + " 1 47 44 1 58 46 43 63 0 61 53 59 50 42 1 63 47 43 50 42 1 59 57 1\n", + " 40 59 58 1 58 46 43 1 57 59 54 43 56 44 50 59 47 58 63 6 1 61 46 47\n", + " 50 43 1 47 58 1 61 43 56 43 0 61 46 53 50 43 57 53 51 43 6 1 61 43\n", + " 1 51 47 45 46 58 1 45 59 43 57 57 1 58 46 43 63 1 56 43 50 47 43 60\n", + " 43 42 1 59 57 1 46 59 51 39 52 43 50 63 11 0 40 59 58 1 58 46 43 63\n", + " 1 58 46 47 52 49 1 61 43 1 39 56 43 1 58 53 53 1 42 43 39 56 10 1\n", + " 58 46 43 1 50 43 39 52 52 43 57 57 1 58 46 39 58 0 39 44 44 50 47 41\n", + " 58 57 1 59 57 6 1 58 46 43 1 53 40 48 43 41 58 1 53 44 1 53 59 56\n", + " 1 51 47 57 43 56 63 6 1 47 57 1 39 57 1 39 52 0 47 52 60 43 52 58\n", + " 53 56 63 1 58 53 1 54 39 56 58 47 41 59 50 39 56 47 57 43 1 58 46 43\n", + " 47 56 1 39 40 59 52 42 39 52 41 43 11 1 53 59 56 0 57 59 44 44 43 56\n", + " 39 52 41 43 1 47 57 1 39 1 45 39 47 52 1 58 53 1 58 46 43 51 1 24\n", + " 43 58 1 59 57 1 56 43 60 43 52 45 43 1 58 46 47 57 1 61 47 58 46 0\n", + " 53 59 56 1 54 47 49 43 57 6 1 43 56 43 1 61 43 1 40 43 41 53 51 43\n", + " 1 56 39 49 43 57 10 1 44 53 56 1 58 46 43 1 45 53 42 57 1 49 52 53\n", + " 61 1 21 0 57 54 43 39 49 1 58 46 47 57 1 47 52 1 46 59 52 45 43 56\n", + " 1 44 53 56 1 40 56 43 39 42 6 1 52 53 58 1 47 52 1 58 46 47 56 57\n", + " 58 1 44 53 56 1 56 43 60 43 52 45 43 8 0 0]\n" + ] + } + ], + "source": [ + "# # let's now encode the entire text dataset and store it into a torch.Tensor\n", + "# import torch # we use PyTorch: https://pytorch.org\n", + "# data = torch.tensor(encode(text), dtype=torch.long)\n", + "# print(data.shape, data.dtype)\n", + "# print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this\n", + "\n", + "# Assuming `encode` is a function defined elsewhere\n", + "data = jnp.array(encode(text), dtype=jnp.int64)\n", + "print(data.shape, data.dtype)\n", + "print(data[:1000])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pBlRYpvhF9xj" + }, + "source": [ + "#### split into train and test" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "id": "f_WIXqxz0lU5" + }, + "outputs": [], + "source": [ + "# Let's now split up the data into train and validation sets\n", + "n = int(0.9*len(data)) # first 90% will be train, rest val\n", + "train_data = data[:n]\n", + "val_data = data[n:]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bCAY6ej2DQUm", + "outputId": "8f7f6770-08d4-4855-e919-e5986f7d1c7d" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(111540,)" + ] + }, + "metadata": {}, + "execution_count": 35 + } + ], + "source": [ + "val_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TD5Bj8Y6IAD4", + "outputId": "f0218cbd-db0d-48c5-eed0-46e5994ed986" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Array([18, 47, 56, 57, 58, 1, 15, 47, 58], dtype=int32)" + ] + }, + "metadata": {}, + "execution_count": 36 + } + ], + "source": [ + "block_size = 8\n", + "train_data[:block_size+1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nHbUA5zZGBqn" + }, + "source": [ + "#### build feature input x and target output y" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9HXDe8vGJCEn", + "outputId": "8157bb28-b5f0-4af6-902a-4fecee9940e3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0\n", + "when input is [18] the target: 47\n", + "1\n", + "when input is [18 47] the target: 56\n", + "2\n", + "when input is [18 47 56] the target: 57\n", + "3\n", + "when input is [18 47 56 57] the target: 58\n", + "4\n", + "when input is [18 47 56 57 58] the target: 1\n", + "5\n", + "when input is [18 47 56 57 58 1] the target: 15\n", + "6\n", + "when input is [18 47 56 57 58 1 15] the target: 47\n", + "7\n", + "when input is [18 47 56 57 58 1 15 47] the target: 58\n" + ] + } + ], + "source": [ + "x = train_data[:block_size]\n", + "y = train_data[1:block_size+1]\n", + "for t in range(block_size):\n", + " print(t)\n", + " context = x[:t+1]\n", + " target = y[t]\n", + " print(f\"when input is {context} the target: {target}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Q3k1Czf7LuA9", + "outputId": "5dd0892a-0f14-4555-86e6-ad3560e042db" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "inputs:\n", + "(4, 8)\n", + "[[ 1 51 39 52 58 50 43 0]\n", + " [25 17 27 10 0 27 6 1]\n", + " [47 51 8 0 14 59 58 1]\n", + " [ 1 57 59 41 46 1 50 43]]\n", + "targets:\n", + "(4, 8)\n", + "[[51 39 52 58 50 43 0 53]\n", + " [17 27 10 0 27 6 1 50]\n", + " [51 8 0 14 59 58 1 46]\n", + " [57 59 41 46 1 50 43 52]]\n", + "----\n", + "when input is [1] the target: 51\n", + "when input is [1, 51] the target: 39\n", + "when input is [1, 51, 39] the target: 52\n", + "when input is [1, 51, 39, 52] the target: 58\n", + "when input is [1, 51, 39, 52, 58] the target: 50\n", + "when input is [1, 51, 39, 52, 58, 50] the target: 43\n", + "when input is [1, 51, 39, 52, 58, 50, 43] the target: 0\n", + "when input is [1, 51, 39, 52, 58, 50, 43, 0] the target: 53\n", + "when input is [25] the target: 17\n", + "when input is [25, 17] the target: 27\n", + "when input is [25, 17, 27] the target: 10\n", + "when input is [25, 17, 27, 10] the target: 0\n", + "when input is [25, 17, 27, 10, 0] the target: 27\n", + "when input is [25, 17, 27, 10, 0, 27] the target: 6\n", + "when input is [25, 17, 27, 10, 0, 27, 6] the target: 1\n", + "when input is [25, 17, 27, 10, 0, 27, 6, 1] the target: 50\n", + "when input is [47] the target: 51\n", + "when input is [47, 51] the target: 8\n", + "when input is [47, 51, 8] the target: 0\n", + "when input is [47, 51, 8, 0] the target: 14\n", + "when input is [47, 51, 8, 0, 14] the target: 59\n", + "when input is [47, 51, 8, 0, 14, 59] the target: 58\n", + "when input is [47, 51, 8, 0, 14, 59, 58] the target: 1\n", + "when input is [47, 51, 8, 0, 14, 59, 58, 1] the target: 46\n", + "when input is [1] the target: 57\n", + "when input is [1, 57] the target: 59\n", + "when input is [1, 57, 59] the target: 41\n", + "when input is [1, 57, 59, 41] the target: 46\n", + "when input is [1, 57, 59, 41, 46] the target: 1\n", + "when input is [1, 57, 59, 41, 46, 1] the target: 50\n", + "when input is [1, 57, 59, 41, 46, 1, 50] the target: 43\n", + "when input is [1, 57, 59, 41, 46, 1, 50, 43] the target: 52\n" + ] + } + ], + "source": [ + "prng = jax.random.PRNGKey(1337)\n", + "batch_size = 4 # how many independent sequences will we process in parallel?\n", + "block_size = 8 # what is the maximum context length for predictions?\n", + "ix = random.randint(random.PRNGKey(0), (batch_size,), 0, len(data) - block_size)\n", + "\n", + "def get_batch(split, subkey):\n", + " # generate a small batch of data of inputs x and targets y\n", + " data = train_data if split == 'train' else val_data\n", + " t1 = time.time()\n", + " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n", + " t2 = time.time()\n", + "\n", + "\n", + " # x = jnp.stack([data[i:i+block_size] for i in ix])\n", + " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", + " def slice_data(i):\n", + " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n", + "\n", + " x = jax.vmap(slice_data)(ix)\n", + " y = jax.vmap(slice_data)(ix+1)\n", + " x, y = device_put(x), device_put(y)\n", + " # print('TIME rand idx', t2-t1)\n", + " # print('TIME rand idx fetch', time.time()-t2)\n", + " return x, y\n", + "\n", + "xb, yb = get_batch('train', prng)\n", + "print('inputs:')\n", + "print(xb.shape)\n", + "print(xb)\n", + "print('targets:')\n", + "print(yb.shape)\n", + "print(yb)\n", + "\n", + "print('----')\n", + "\n", + "for b in range(batch_size): # batch dimension\n", + " for t in range(block_size): # time dimension\n", + " context = xb[b, :t+1]\n", + " target = yb[b,t]\n", + " print(f\"when input is {context.tolist()} the target: {target}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qpyyAeIzQjlO", + "outputId": "d903b396-ef7c-425f-db46-8312d9bd2fd6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[[ 1 51 39 52 58 50 43 0]\n", + " [25 17 27 10 0 27 6 1]\n", + " [47 51 8 0 14 59 58 1]\n", + " [ 1 57 59 41 46 1 50 43]]\n" + ] + } + ], + "source": [ + "print(xb) # our input to the transformer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yZF__SLaGJs9" + }, + "source": [ + "### Baseline model: Bigram" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nql_1ER53oCf", + "outputId": "fa018823-a3e6-4a13-bc4c-76bdfbb34b7b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "65 {'params': {'Embed_0': {'embedding': Array([[-0.07564811, -0.24239814, -0.01499225, ..., -0.06269587,\n", + " 0.07931701, 0.1008973 ],\n", + " [-0.00872193, 0.02849229, -0.08602703, ..., 0.12625487,\n", + " -0.05664039, -0.12900828],\n", + " [-0.01677619, -0.01286294, -0.00534049, ..., -0.056512 ,\n", + " -0.11744383, -0.09810068],\n", + " ...,\n", + " [ 0.15148896, 0.08473317, -0.10937848, ..., 0.03070055,\n", + " -0.00960146, -0.15743323],\n", + " [ 0.05090765, 0.06334479, -0.07453259, ..., -0.06299953,\n", + " -0.09558795, 0.02108589],\n", + " [ 0.08294208, 0.0413576 , -0.10926365, ..., 0.05151561,\n", + " -0.18575938, 0.19301337]], dtype=float32)}}}\n", + "(65, 65)\n", + "(32, 65)\n", + "4.1973753\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "from flax import linen as nn\n", + "\n", + "class BigramLanguageModel(nn.Module):\n", + " vocab_size: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, idx, targets=None):\n", + " # Token embedding table\n", + " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n", + " logits = embedding_table(idx) # (B,T,C)\n", + "\n", + " if targets is None:\n", + " loss = None\n", + " else:\n", + " B, T, C = logits.shape\n", + " logits = logits.reshape(B*T, C)\n", + " targets = targets.reshape(B*T)\n", + " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n", + "\n", + " return logits, loss\n", + "\n", + " def generate(self, params, key, idx, max_new_tokens):\n", + " for _ in range(max_new_tokens):\n", + " logits, _ = self.apply(params, idx)\n", + " logits = logits[:, -1, :] # (B, C)\n", + " # probs = jax.nn.softmax(logits, axis=-1) # (B, C)\n", + " key, subkey = random.split(key)\n", + " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", + " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", + " return idx\n", + "\n", + "# Example usage\n", + "model = BigramLanguageModel(vocab_size)\n", + "\n", + "# Initialize model parameters and optimizer\n", + "key = random.PRNGKey(1337)\n", + "params = model.init(key, jnp.ones((1, 1), jnp.int32))\n", + "print(vocab_size, params)\n", + "print(params['params']['Embed_0']['embedding'].shape)\n", + "\n", + "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))\n", + "\n", + "logits, loss = flax_apply_jitted(params, xb, yb)\n", + "print(logits.shape)\n", + "print(loss)\n" + ] + }, + { + "cell_type": "code", + "source": [ + "a = jnp.ones([10, 20, 30])\n", + "t1 = time.time()\n", + "a = a[:, -1, :]\n", + "print('TIME total', time.time()-t1)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EgnDhLeh0E0b", + "outputId": "476d1815-554d-4261-ec43-3e0ced5599ce" + }, + "execution_count": 41, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "TIME total 0.008388519287109375\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "@jax.jit\n", + "def slice_fn(x):\n", + " return x[:, -1, :]\n" + ], + "metadata": { + "id": "_8WBhAL9w53o" + }, + "execution_count": 89, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "a = jnp.ones([10, 99, 30])" + ], + "metadata": { + "id": "nCNwo1H_r4oJ" + }, + "execution_count": 97, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%time slice_fn(a).block_until_ready()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kIBJwk1ErvCD", + "outputId": "501552a7-2536-4e76-ce2c-764d14783f98" + }, + "execution_count": 98, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 1.11 ms, sys: 0 ns, total: 1.11 ms\n", + "Wall time: 674 µs\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)" + ] + }, + "metadata": {}, + "execution_count": 98 + } + ] + }, + { + "cell_type": "code", + "source": [ + "b = jnp.ones([10, 100, 30])" + ], + "metadata": { + "id": "1R4kOQmNsC-w" + }, + "execution_count": 99, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%time slice_fn(b).block_until_ready()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "werFHdHyr8Kw", + "outputId": "56808098-4ae0-4ec9-eebe-44b53d51fdd2" + }, + "execution_count": 100, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 48.4 ms, sys: 1.02 ms, total: 49.4 ms\n", + "Wall time: 50.6 ms\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)" + ] + }, + "metadata": {}, + "execution_count": 100 + } + ] + }, + { + "cell_type": "code", + "source": [ + "t1 = time.time()\n", + "for i in range(100):\n", + " a = jnp.ones([10, 100, 30])\n", + " a = slice_fn(a)\n", + "print(f'total:{time.time()-t1}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7_3OtSg5rMnx", + "outputId": "43295512-a90c-4b7e-8a6b-cd79f420f1ab" + }, + "execution_count": 68, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total:0.10582447052001953\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "def generate(params, flax_apply_jitted, key, idx, max_new_tokens):\n", + " for _ in range(max_new_tokens):\n", + " t1 = time.time()\n", + " logits, _ = flax_apply_jitted(params, idx, None)\n", + " t2 = time.time()\n", + " print(logits.shape)\n", + " logits = slice_fn(logits) # (B, C)\n", + " print(logits.shape)\n", + " t3 = time.time()\n", + " key, subkey = random.split(key)\n", + " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", + " t4 = time.time()\n", + " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", + " print(f'apply {t2-t1}, fetch {t3-t2}, rand: {t4-t3} sample:{time.time()-t4}, total:{time.time()-t1}')\n", + " return idx\n" + ], + "metadata": { + "id": "TAW6J0sLqwcv" + }, + "execution_count": 50, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "t1=time.time()\n", + "print(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))\n", + "print('TIME total', time.time()-t1)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7gVsfkbsqLBp", + "outputId": "c72409f4-a016-4e1b-f9e7-95776480c5d4" + }, + "execution_count": 51, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(1, 1, 65)\n", + "(1, 65)\n", + "apply 0.00012230873107910156, fetch 0.041161537170410156, rand: 0.0023279190063476562 sample:0.0003173351287841797, total:0.043931007385253906\n", + "(1, 2, 65)\n", + "(1, 65)\n", + "apply 0.00048065185546875, fetch 0.04120588302612305, rand: 0.0021431446075439453 sample:0.0002779960632324219, total:0.04410958290100098\n", + "(1, 3, 65)\n", + "(1, 65)\n", + "apply 0.00047898292541503906, fetch 0.03672361373901367, rand: 0.0021305084228515625 sample:0.00027441978454589844, total:0.039609670639038086\n", + "(1, 4, 65)\n", + "(1, 65)\n", + "apply 0.0004787445068359375, fetch 0.03933238983154297, rand: 0.002025127410888672 sample:0.0002868175506591797, total:0.04212522506713867\n", + "(1, 5, 65)\n", + "(1, 65)\n", + "apply 0.0004420280456542969, fetch 0.03999972343444824, rand: 0.0023453235626220703 sample:0.0002808570861816406, total:0.04306983947753906\n", + "(1, 6, 65)\n", + "(1, 65)\n", + "apply 0.000579833984375, fetch 0.03841257095336914, rand: 0.002195596694946289 sample:0.00031304359436035156, total:0.0415036678314209\n", + "(1, 7, 65)\n", + "(1, 65)\n", + "apply 0.0005831718444824219, fetch 0.04033041000366211, rand: 0.002285003662109375 sample:0.0003218650817871094, total:0.04352235794067383\n", + "(1, 8, 65)\n", + "(1, 65)\n", + "apply 0.0006184577941894531, fetch 0.04154014587402344, rand: 0.0023641586303710938 sample:0.00031638145446777344, total:0.044841766357421875\n", + "(1, 9, 65)\n", + "(1, 65)\n", + "apply 0.0004944801330566406, fetch 0.037868499755859375, rand: 0.002101898193359375 sample:0.0003063678741455078, total:0.04077315330505371\n", + "(1, 10, 65)\n", + "(1, 65)\n", + "apply 0.0004963874816894531, fetch 0.04486680030822754, rand: 0.002717256546020508 sample:0.00038695335388183594, total:0.04846978187561035\n", + "(1, 11, 65)\n", + "(1, 65)\n", + "apply 0.0004949569702148438, fetch 0.03588724136352539, rand: 0.0021865367889404297 sample:0.0002856254577636719, total:0.03885650634765625\n", + "(1, 12, 65)\n", + "(1, 65)\n", + "apply 0.0005121231079101562, fetch 0.034277915954589844, rand: 0.0022237300872802734 sample:0.0003218650817871094, total:0.0373380184173584\n", + "(1, 13, 65)\n", + "(1, 65)\n", + "apply 0.0005197525024414062, fetch 0.03408074378967285, rand: 0.002284526824951172 sample:0.00032019615173339844, total:0.03720736503601074\n", + "(1, 14, 65)\n", + "(1, 65)\n", + "apply 0.0004794597625732422, fetch 0.036359548568725586, rand: 0.002101421356201172 sample:0.00030493736267089844, total:0.039247751235961914\n", + "(1, 15, 65)\n", + "(1, 65)\n", + "apply 0.0004851818084716797, fetch 0.03603172302246094, rand: 0.0021791458129882812 sample:0.0002892017364501953, total:0.03898739814758301\n", + "(1, 16, 65)\n", + "(1, 65)\n", + "apply 0.0005307197570800781, fetch 0.04622006416320801, rand: 0.0027723312377929688 sample:0.00032711029052734375, total:0.049852609634399414\n", + "(1, 17, 65)\n", + "(1, 65)\n", + "apply 0.0001220703125, fetch 0.03553509712219238, rand: 0.002171754837036133 sample:0.0002913475036621094, total:0.03812217712402344\n", + "(1, 18, 65)\n", + "(1, 65)\n", + "apply 0.0004968643188476562, fetch 0.03546881675720215, rand: 0.0022530555725097656 sample:0.0002722740173339844, total:0.03849315643310547\n", + "(1, 19, 65)\n", + "(1, 65)\n", + "apply 0.0004630088806152344, fetch 0.03522849082946777, rand: 0.0021698474884033203 sample:0.0003001689910888672, total:0.03816366195678711\n", + "(1, 20, 65)\n", + "(1, 65)\n", + "apply 0.00048041343688964844, fetch 0.03506278991699219, rand: 0.002357006072998047 sample:0.00028228759765625, total:0.03818464279174805\n", + "(1, 21, 65)\n", + "(1, 65)\n", + "apply 0.00013685226440429688, fetch 0.042746782302856445, rand: 0.0050699710845947266 sample:0.0003261566162109375, total:0.048282623291015625\n", + "(1, 22, 65)\n", + "(1, 65)\n", + "apply 0.0006587505340576172, fetch 0.03631782531738281, rand: 0.002202749252319336 sample:0.0002942085266113281, total:0.039475440979003906\n", + "(1, 23, 65)\n", + "(1, 65)\n", + "apply 0.0005395412445068359, fetch 0.03643512725830078, rand: 0.0021920204162597656 sample:0.00031828880310058594, total:0.03948688507080078\n", + "(1, 24, 65)\n", + "(1, 65)\n", + "apply 0.0004425048828125, fetch 0.03418898582458496, rand: 0.002078533172607422 sample:0.00026345252990722656, total:0.03697514533996582\n", + "(1, 25, 65)\n", + "(1, 65)\n", + "apply 0.000499725341796875, fetch 0.038149356842041016, rand: 0.0022847652435302734 sample:0.0002911090850830078, total:0.04122757911682129\n", + "(1, 26, 65)\n", + "(1, 65)\n", + "apply 0.0004944801330566406, fetch 0.036529541015625, rand: 0.0022573471069335938 sample:0.00032138824462890625, total:0.03960537910461426\n", + "(1, 27, 65)\n", + "(1, 65)\n", + "apply 0.0004668235778808594, fetch 0.040696144104003906, rand: 0.002485990524291992 sample:0.0003218650817871094, total:0.04397273063659668\n", + "(1, 28, 65)\n", + "(1, 65)\n", + "apply 0.00047135353088378906, fetch 0.16798734664916992, rand: 0.002162456512451172 sample:0.00037169456481933594, total:0.17099547386169434\n", + "(1, 29, 65)\n", + "(1, 65)\n", + "apply 0.00010633468627929688, fetch 0.040407657623291016, rand: 0.0026171207427978516 sample:0.00030517578125, total:0.04343867301940918\n", + "(1, 30, 65)\n", + "(1, 65)\n", + "apply 0.0005483627319335938, fetch 0.034818172454833984, rand: 0.002196073532104492 sample:0.00030040740966796875, total:0.037865400314331055\n", + "(1, 31, 65)\n", + "(1, 65)\n", + "apply 0.0005075931549072266, fetch 0.03462409973144531, rand: 0.0022614002227783203 sample:0.0003376007080078125, total:0.03774523735046387\n", + "(1, 32, 65)\n", + "(1, 65)\n", + "apply 0.0004947185516357422, fetch 0.034261226654052734, rand: 0.0022127628326416016 sample:0.0003063678741455078, total:0.0372772216796875\n", + "(1, 33, 65)\n", + "(1, 65)\n", + "apply 0.0005221366882324219, fetch 0.03338050842285156, rand: 0.002115964889526367 sample:0.0003349781036376953, total:0.03635525703430176\n", + "(1, 34, 65)\n", + "(1, 65)\n", + "apply 0.0004925727844238281, fetch 0.03609728813171387, rand: 0.002292633056640625 sample:0.0003094673156738281, total:0.03919363021850586\n", + "(1, 35, 65)\n", + "(1, 65)\n", + "apply 0.0011334419250488281, fetch 0.03726768493652344, rand: 0.0024628639221191406 sample:0.0003046989440917969, total:0.04117107391357422\n", + "(1, 36, 65)\n", + "(1, 65)\n", + "apply 0.0004928112030029297, fetch 0.034056901931762695, rand: 0.0021305084228515625 sample:0.0002810955047607422, total:0.036963462829589844\n", + "(1, 37, 65)\n", + "(1, 65)\n", + "apply 0.0005412101745605469, fetch 0.03381204605102539, rand: 0.002193927764892578 sample:0.0002751350402832031, total:0.03682446479797363\n", + "(1, 38, 65)\n", + "(1, 65)\n", + "apply 0.0001766681671142578, fetch 0.032378435134887695, rand: 0.002092123031616211 sample:0.00026988983154296875, total:0.034919023513793945\n", + "(1, 39, 65)\n", + "(1, 65)\n", + "apply 0.0004858970642089844, fetch 0.043538570404052734, rand: 0.0022830963134765625 sample:0.0003173351287841797, total:0.04662728309631348\n", + "(1, 40, 65)\n", + "(1, 65)\n", + "apply 0.0005092620849609375, fetch 0.03389787673950195, rand: 0.0022470951080322266 sample:0.00028967857360839844, total:0.03694581985473633\n", + "(1, 41, 65)\n", + "(1, 65)\n", + "apply 0.00047850608825683594, fetch 0.043410301208496094, rand: 0.002513408660888672 sample:0.0003464221954345703, total:0.04675102233886719\n", + "(1, 42, 65)\n", + "(1, 65)\n", + "apply 0.0004718303680419922, fetch 0.03394603729248047, rand: 0.0022313594818115234 sample:0.0003161430358886719, total:0.03696751594543457\n", + "(1, 43, 65)\n", + "(1, 65)\n", + "apply 0.00045943260192871094, fetch 0.03443288803100586, rand: 0.0031540393829345703 sample:0.00032782554626464844, total:0.038376808166503906\n", + "(1, 44, 65)\n", + "(1, 65)\n", + "apply 0.0002892017364501953, fetch 0.033853769302368164, rand: 0.0022063255310058594 sample:0.00028252601623535156, total:0.036633968353271484\n", + "(1, 45, 65)\n", + "(1, 65)\n", + "apply 0.0004913806915283203, fetch 0.03909873962402344, rand: 0.002314329147338867 sample:0.00029087066650390625, total:0.04219770431518555\n", + "(1, 46, 65)\n", + "(1, 65)\n", + "apply 0.0005013942718505859, fetch 0.03659510612487793, rand: 0.0022547245025634766 sample:0.000274658203125, total:0.039627790451049805\n", + "(1, 47, 65)\n", + "(1, 65)\n", + "apply 0.0005040168762207031, fetch 0.03713226318359375, rand: 0.0025844573974609375 sample:0.00031304359436035156, total:0.040535926818847656\n", + "(1, 48, 65)\n", + "(1, 65)\n", + "apply 0.0004892349243164062, fetch 0.03428483009338379, rand: 0.0023317337036132812 sample:0.0002830028533935547, total:0.037390947341918945\n", + "(1, 49, 65)\n", + "(1, 65)\n", + "apply 0.0004928112030029297, fetch 0.032887935638427734, rand: 0.002256155014038086 sample:0.0003209114074707031, total:0.03596019744873047\n", + "(1, 50, 65)\n", + "(1, 65)\n", + "apply 0.0005295276641845703, fetch 0.035512685775756836, rand: 0.0021750926971435547 sample:0.0003120899200439453, total:0.03853154182434082\n", + "(1, 51, 65)\n", + "(1, 65)\n", + "apply 0.0005235671997070312, fetch 0.03630781173706055, rand: 0.0022492408752441406 sample:0.0003135204315185547, total:0.039395809173583984\n", + "(1, 52, 65)\n", + "(1, 65)\n", + "apply 0.0004930496215820312, fetch 0.03442525863647461, rand: 0.0022301673889160156 sample:0.00031685829162597656, total:0.03746771812438965\n", + "(1, 53, 65)\n", + "(1, 65)\n", + "apply 0.0005328655242919922, fetch 0.03630828857421875, rand: 0.0026094913482666016 sample:0.00031113624572753906, total:0.0397639274597168\n", + "(1, 54, 65)\n", + "(1, 65)\n", + "apply 0.0005147457122802734, fetch 0.03552508354187012, rand: 0.002171039581298828 sample:0.0002918243408203125, total:0.03850531578063965\n", + "(1, 55, 65)\n", + "(1, 65)\n", + "apply 0.0004754066467285156, fetch 0.03363323211669922, rand: 0.002084016799926758 sample:0.00029921531677246094, total:0.03649425506591797\n", + "(1, 56, 65)\n", + "(1, 65)\n", + "apply 0.00046324729919433594, fetch 0.03212785720825195, rand: 0.0021419525146484375 sample:0.0003018379211425781, total:0.03503680229187012\n", + "(1, 57, 65)\n", + "(1, 65)\n", + "apply 0.00047707557678222656, fetch 0.03502345085144043, rand: 0.002373933792114258 sample:0.0002951622009277344, total:0.038172006607055664\n", + "(1, 58, 65)\n", + "(1, 65)\n", + "apply 0.0004749298095703125, fetch 0.03354239463806152, rand: 0.0021560192108154297 sample:0.0002906322479248047, total:0.036466121673583984\n", + "(1, 59, 65)\n", + "(1, 65)\n", + "apply 0.0004456043243408203, fetch 0.03715181350708008, rand: 0.002325773239135742 sample:0.0002875328063964844, total:0.04021286964416504\n", + "(1, 60, 65)\n", + "(1, 65)\n", + "apply 0.00047206878662109375, fetch 0.03511500358581543, rand: 0.0023941993713378906 sample:0.0002887248992919922, total:0.03827261924743652\n", + "(1, 61, 65)\n", + "(1, 65)\n", + "apply 0.0004887580871582031, fetch 0.035497426986694336, rand: 0.0022628307342529297 sample:0.0003247261047363281, total:0.038576364517211914\n", + "(1, 62, 65)\n", + "(1, 65)\n", + "apply 0.0005118846893310547, fetch 0.03496408462524414, rand: 0.0023431777954101562 sample:0.00031828880310058594, total:0.03813958168029785\n", + "(1, 63, 65)\n", + "(1, 65)\n", + "apply 0.0005247592926025391, fetch 0.03505682945251465, rand: 0.002232789993286133 sample:0.00028514862060546875, total:0.038101911544799805\n", + "(1, 64, 65)\n", + "(1, 65)\n", + "apply 0.0004918575286865234, fetch 0.03458690643310547, rand: 0.0022902488708496094 sample:0.0002923011779785156, total:0.03766345977783203\n", + "(1, 65, 65)\n", + "(1, 65)\n", + "apply 0.00046563148498535156, fetch 0.049041748046875, rand: 0.002659320831298828 sample:0.000286102294921875, total:0.05245494842529297\n", + "(1, 66, 65)\n", + "(1, 65)\n", + "apply 0.0005247592926025391, fetch 0.03327059745788574, rand: 0.002138853073120117 sample:0.00031185150146484375, total:0.03624868392944336\n", + "(1, 67, 65)\n", + "(1, 65)\n", + "apply 0.0004410743713378906, fetch 0.04506540298461914, rand: 0.0026183128356933594 sample:0.0003247261047363281, total:0.048451900482177734\n", + "(1, 68, 65)\n", + "(1, 65)\n", + "apply 0.00018596649169921875, fetch 0.04279971122741699, rand: 0.0023903846740722656 sample:0.0004248619079589844, total:0.04580330848693848\n", + "(1, 69, 65)\n", + "(1, 65)\n", + "apply 0.0006306171417236328, fetch 0.03531026840209961, rand: 0.0023009777069091797 sample:0.0003414154052734375, total:0.03858590126037598\n", + "(1, 70, 65)\n", + "(1, 65)\n", + "apply 0.00046539306640625, fetch 0.03606081008911133, rand: 0.0024008750915527344 sample:0.00031495094299316406, total:0.03924441337585449\n", + "(1, 71, 65)\n", + "(1, 65)\n", + "apply 0.00047135353088378906, fetch 0.03373908996582031, rand: 0.0021162033081054688 sample:0.0002856254577636719, total:0.03661489486694336\n", + "(1, 72, 65)\n", + "(1, 65)\n", + "apply 0.00047850608825683594, fetch 0.03369331359863281, rand: 0.0022661685943603516 sample:0.0003075599670410156, total:0.03674769401550293\n", + "(1, 73, 65)\n", + "(1, 65)\n", + "apply 0.0005681514739990234, fetch 0.0340423583984375, rand: 0.0022792816162109375 sample:0.0003139972686767578, total:0.03720569610595703\n", + "(1, 74, 65)\n", + "(1, 65)\n", + "apply 0.0004639625549316406, fetch 0.03503298759460449, rand: 0.0023622512817382812 sample:0.0003147125244140625, total:0.038176774978637695\n", + "(1, 75, 65)\n", + "(1, 65)\n", + "apply 0.00051116943359375, fetch 0.03286004066467285, rand: 0.0022122859954833984 sample:0.0003190040588378906, total:0.0359044075012207\n", + "(1, 76, 65)\n", + "(1, 65)\n", + "apply 0.00047779083251953125, fetch 0.03750944137573242, rand: 0.0025398731231689453 sample:0.00032210350036621094, total:0.04085183143615723\n", + "(1, 77, 65)\n", + "(1, 65)\n", + "apply 0.00011229515075683594, fetch 0.03283333778381348, rand: 0.0018007755279541016 sample:0.0003113746643066406, total:0.03505969047546387\n", + "(1, 78, 65)\n", + "(1, 65)\n", + "apply 0.0007066726684570312, fetch 0.032887935638427734, rand: 0.002187967300415039 sample:0.00028252601623535156, total:0.03606748580932617\n", + "(1, 79, 65)\n", + "(1, 65)\n", + "apply 0.0005395412445068359, fetch 0.03475666046142578, rand: 0.002254486083984375 sample:0.0003006458282470703, total:0.03785371780395508\n", + "(1, 80, 65)\n", + "(1, 65)\n", + "apply 0.0005538463592529297, fetch 0.03729367256164551, rand: 0.002325773239135742 sample:0.0002956390380859375, total:0.04047131538391113\n", + "(1, 81, 65)\n", + "(1, 65)\n", + "apply 0.0004775524139404297, fetch 0.03231501579284668, rand: 0.002199411392211914 sample:0.00030612945556640625, total:0.03530073165893555\n", + "(1, 82, 65)\n", + "(1, 65)\n", + "apply 0.0004909038543701172, fetch 0.036426544189453125, rand: 0.0023496150970458984 sample:0.00031375885009765625, total:0.03958296775817871\n", + "(1, 83, 65)\n", + "(1, 65)\n", + "apply 0.00045609474182128906, fetch 0.035466670989990234, rand: 0.002231597900390625 sample:0.00028061866760253906, total:0.0384371280670166\n", + "(1, 84, 65)\n", + "(1, 65)\n", + "apply 0.0005011558532714844, fetch 0.03577089309692383, rand: 0.002390146255493164 sample:0.0002968311309814453, total:0.038961172103881836\n", + "(1, 85, 65)\n", + "(1, 65)\n", + "apply 0.0005679130554199219, fetch 0.033548593521118164, rand: 0.002294301986694336 sample:0.0002868175506591797, total:0.036699533462524414\n", + "(1, 86, 65)\n", + "(1, 65)\n", + "apply 0.00046539306640625, fetch 0.03391146659851074, rand: 0.0022737979888916016 sample:0.00032806396484375, total:0.03698086738586426\n", + "(1, 87, 65)\n", + "(1, 65)\n", + "apply 0.0005002021789550781, fetch 0.03307485580444336, rand: 0.002150297164916992 sample:0.00028586387634277344, total:0.03601360321044922\n", + "(1, 88, 65)\n", + "(1, 65)\n", + "apply 0.00046062469482421875, fetch 0.0414433479309082, rand: 0.0026967525482177734 sample:0.00029277801513671875, total:0.044895172119140625\n", + "(1, 89, 65)\n", + "(1, 65)\n", + "apply 0.0004894733428955078, fetch 0.033232927322387695, rand: 0.0023353099822998047 sample:0.0002894401550292969, total:0.03634905815124512\n", + "(1, 90, 65)\n", + "(1, 65)\n", + "apply 0.0004942417144775391, fetch 0.033071279525756836, rand: 0.0023810863494873047 sample:0.0004107952117919922, total:0.036359310150146484\n", + "(1, 91, 65)\n", + "(1, 65)\n", + "apply 0.0005204677581787109, fetch 0.03531360626220703, rand: 0.0029680728912353516 sample:0.0003561973571777344, total:0.03916025161743164\n", + "(1, 92, 65)\n", + "(1, 65)\n", + "apply 0.00016117095947265625, fetch 0.041900634765625, rand: 0.0024335384368896484 sample:0.00030159950256347656, total:0.0447993278503418\n", + "(1, 93, 65)\n", + "(1, 65)\n", + "apply 0.0005536079406738281, fetch 0.032729148864746094, rand: 0.0022106170654296875 sample:0.0002655982971191406, total:0.035761117935180664\n", + "(1, 94, 65)\n", + "(1, 65)\n", + "apply 0.0004940032958984375, fetch 0.034377098083496094, rand: 0.0024232864379882812 sample:0.0003139972686767578, total:0.03761029243469238\n", + "(1, 95, 65)\n", + "(1, 65)\n", + "apply 0.00018143653869628906, fetch 0.03214859962463379, rand: 0.0021219253540039062 sample:0.00030994415283203125, total:0.03476428985595703\n", + "(1, 96, 65)\n", + "(1, 65)\n", + "apply 0.0004601478576660156, fetch 0.03141331672668457, rand: 0.0022804737091064453 sample:0.00030231475830078125, total:0.03445887565612793\n", + "(1, 97, 65)\n", + "(1, 65)\n", + "apply 0.0005335807800292969, fetch 0.0312657356262207, rand: 0.002066373825073242 sample:0.0003190040588378906, total:0.034186363220214844\n", + "(1, 98, 65)\n", + "(1, 65)\n", + "apply 0.0004448890686035156, fetch 0.03332257270812988, rand: 0.0021436214447021484 sample:0.00033736228942871094, total:0.036251068115234375\n", + "(1, 99, 65)\n", + "(1, 65)\n", + "apply 0.0004706382751464844, fetch 0.032656192779541016, rand: 0.0021228790283203125 sample:0.0002694129943847656, total:0.03552079200744629\n", + "(1, 100, 65)\n", + "(1, 65)\n", + "apply 0.0005025863647460938, fetch 0.03476595878601074, rand: 0.0023534297943115234 sample:0.0002772808074951172, total:0.03790140151977539\n", + "\n", + "yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x\n", + "pctpxMvdJMlTZrmCZhPRjYRJUfTgldWbqlwXxBlCHIWu'FYEBTwJ\n", + "TIME total 4.07533597946167\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "%%timeit\n", + "\n", + "print(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))\n", + "print('TIME total', time.time()-t1)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 398 + }, + "id": "ZXhmaWHOp6Mr", + "outputId": "ceccd6e6-7635-4a94-ab87-ea34ce69be66" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "name 'decode' is not defined", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'timeit'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\\nprint(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))\\nprint('TIME total', time.time()-t1)\\n\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_shell.py\u001b[0m in \u001b[0;36mrun_cell_magic\u001b[0;34m(self, magic_name, line, cell)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcell\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0mcell\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m' '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 334\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmagic_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 335\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell_magic\u001b[0;34m(self, magic_name, line, cell)\u001b[0m\n\u001b[1;32m 2471\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2472\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmagic_arg_s\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2473\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2474\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2475\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n\u001b[1;32m 1178\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1179\u001b[0m \u001b[0mnumber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1180\u001b[0;31m \u001b[0mtime_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1181\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtime_number\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m0.2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, number)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mtiming\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgcold\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36minner\u001b[0;34m(_it, _timer)\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'decode' is not defined" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eTyJ8qAaDdiF", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5eca4221-8c9a-4b32-9ffd-1f386aa02a3f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Tracedwith Tracedwith\n", + "Epoch 0,Loss: 4.176700115203857, 0.03257417678833008, 0.5034492015838623\n", + "Epoch 1,Loss: 4.168147563934326, 0.01342916488647461, 0.0006744861602783203\n", + "Epoch 2,Loss: 4.17624568939209, 0.010905742645263672, 0.0005877017974853516\n", + "Epoch 3,Loss: 4.169358253479004, 0.010186195373535156, 0.0005640983581542969\n", + "Epoch 4,Loss: 4.175686359405518, 0.010203123092651367, 0.0005354881286621094\n", + "Epoch 5,Loss: 4.170393943786621, 0.009913444519042969, 0.0005452632904052734\n", + "Epoch 6,Loss: 4.169347763061523, 0.013239383697509766, 0.0005314350128173828\n", + "Epoch 7,Loss: 4.162715911865234, 0.01039743423461914, 0.0005896091461181641\n", + "Epoch 8,Loss: 4.157233238220215, 0.010396957397460938, 0.0005331039428710938\n", + "Epoch 9,Loss: 4.163628101348877, 0.00969386100769043, 0.0005295276641845703\n", + "Epoch 10,Loss: 4.16217565536499, 0.010032415390014648, 0.0005445480346679688\n", + "Epoch 11,Loss: 4.16450309753418, 0.00960087776184082, 0.0033788681030273438\n", + "Epoch 12,Loss: 4.15981388092041, 0.015790939331054688, 0.002191305160522461\n", + "Epoch 13,Loss: 4.157349586486816, 0.010314226150512695, 0.009167194366455078\n", + "Epoch 14,Loss: 4.163050174713135, 0.013430595397949219, 0.0026466846466064453\n", + "Epoch 15,Loss: 4.155461311340332, 0.020540714263916016, 0.0013380050659179688\n", + "Epoch 16,Loss: 4.154058456420898, 0.01090860366821289, 0.0005877017974853516\n", + "Epoch 17,Loss: 4.14748477935791, 0.00999140739440918, 0.0006780624389648438\n", + "Epoch 18,Loss: 4.1531476974487305, 0.010620355606079102, 0.0006525516510009766\n", + "Epoch 19,Loss: 4.15305757522583, 0.011104106903076172, 0.0006277561187744141\n", + "Epoch 20,Loss: 4.148075580596924, 0.01051187515258789, 0.0006189346313476562\n", + "Epoch 21,Loss: 4.1545729637146, 0.010356903076171875, 0.0005707740783691406\n", + "Epoch 22,Loss: 4.145484924316406, 0.010561943054199219, 0.0006885528564453125\n", + "Epoch 23,Loss: 4.146124839782715, 0.010853290557861328, 0.0006654262542724609\n", + "Epoch 24,Loss: 4.152686595916748, 0.010770797729492188, 0.00067138671875\n", + "Epoch 25,Loss: 4.148232460021973, 0.01275944709777832, 0.0006635189056396484\n", + "Epoch 26,Loss: 4.143317222595215, 0.01102590560913086, 0.0006601810455322266\n", + "Epoch 27,Loss: 4.135959625244141, 0.010260581970214844, 0.0006642341613769531\n", + "Epoch 28,Loss: 4.133705139160156, 0.018632888793945312, 0.0005812644958496094\n", + "Epoch 29,Loss: 4.144227981567383, 0.010734796524047852, 0.0006594657897949219\n", + "Epoch 30,Loss: 4.137174606323242, 0.010676383972167969, 0.00067901611328125\n", + "Epoch 31,Loss: 4.137599945068359, 0.016976118087768555, 0.002281665802001953\n", + "Epoch 32,Loss: 4.133066177368164, 0.010472536087036133, 0.000614166259765625\n", + "Epoch 33,Loss: 4.140651702880859, 0.01074361801147461, 0.000682830810546875\n", + "Epoch 34,Loss: 4.132898330688477, 0.019823312759399414, 0.0007305145263671875\n", + "Epoch 35,Loss: 4.129890441894531, 0.01900649070739746, 0.0006961822509765625\n", + "Epoch 36,Loss: 4.128242492675781, 0.017644405364990234, 0.0007212162017822266\n", + "Epoch 37,Loss: 4.1404337882995605, 0.018088340759277344, 0.0006604194641113281\n", + "Epoch 38,Loss: 4.124664306640625, 0.01122593879699707, 0.0007469654083251953\n", + "Epoch 39,Loss: 4.121993541717529, 0.010849952697753906, 0.0006933212280273438\n", + "Epoch 40,Loss: 4.131916046142578, 0.011729717254638672, 0.00067901611328125\n", + "Epoch 41,Loss: 4.121825218200684, 0.01046895980834961, 0.00066375732421875\n", + "Epoch 42,Loss: 4.124566078186035, 0.016879796981811523, 0.000637054443359375\n", + "Epoch 43,Loss: 4.122013568878174, 0.023413896560668945, 0.0007090568542480469\n", + "Epoch 44,Loss: 4.124298572540283, 0.021088600158691406, 0.0006775856018066406\n", + "Epoch 45,Loss: 4.116214275360107, 0.03228282928466797, 0.0007734298706054688\n", + "Epoch 46,Loss: 4.118075847625732, 0.012856006622314453, 0.0008099079132080078\n", + "Epoch 47,Loss: 4.114686965942383, 0.011289834976196289, 0.0007231235504150391\n", + "Epoch 48,Loss: 4.117861270904541, 0.011073589324951172, 0.0008485317230224609\n", + "Epoch 49,Loss: 4.111809730529785, 0.02568793296813965, 0.0007822513580322266\n", + "Epoch 50,Loss: 4.1100616455078125, 0.016759395599365234, 0.0007135868072509766\n", + "Epoch 51,Loss: 4.1063408851623535, 0.010864496231079102, 0.00066375732421875\n", + "Epoch 52,Loss: 4.106322765350342, 0.01942586898803711, 0.005199432373046875\n", + "Epoch 53,Loss: 4.105854511260986, 0.020357847213745117, 0.0006930828094482422\n", + "Epoch 54,Loss: 4.116119861602783, 0.011683225631713867, 0.0007033348083496094\n", + "Epoch 55,Loss: 4.107294082641602, 0.013722896575927734, 0.0006699562072753906\n", + "Epoch 56,Loss: 4.112220287322998, 0.011293649673461914, 0.0007617473602294922\n", + "Epoch 57,Loss: 4.106285095214844, 0.02633976936340332, 0.003446817398071289\n", + "Epoch 58,Loss: 4.101660251617432, 0.021680593490600586, 0.0007572174072265625\n", + "Epoch 59,Loss: 4.100406646728516, 0.011004447937011719, 0.0007081031799316406\n", + "Epoch 60,Loss: 4.098540306091309, 0.010811090469360352, 0.0006601810455322266\n", + "Epoch 61,Loss: 4.097071170806885, 0.013948678970336914, 0.00069427490234375\n", + "Epoch 62,Loss: 4.094281196594238, 0.019835710525512695, 0.0006990432739257812\n", + "Epoch 63,Loss: 4.098484039306641, 0.011063575744628906, 0.0006513595581054688\n", + "Epoch 64,Loss: 4.09868049621582, 0.012256860733032227, 0.0006780624389648438\n", + "Epoch 65,Loss: 4.098651885986328, 0.010885477066040039, 0.0005729198455810547\n", + "Epoch 66,Loss: 4.093440055847168, 0.010955572128295898, 0.0006401538848876953\n", + "Epoch 67,Loss: 4.098696708679199, 0.022384166717529297, 0.0007491111755371094\n", + "Epoch 68,Loss: 4.089413166046143, 0.0189666748046875, 0.000705718994140625\n", + "Epoch 69,Loss: 4.089560508728027, 0.014131784439086914, 0.0006935596466064453\n", + "Epoch 70,Loss: 4.0868353843688965, 0.01195526123046875, 0.0006866455078125\n", + "Epoch 71,Loss: 4.088634014129639, 0.011432886123657227, 0.0006272792816162109\n", + "Epoch 72,Loss: 4.084425926208496, 0.010274410247802734, 0.0006382465362548828\n", + "Epoch 73,Loss: 4.085216522216797, 0.010661602020263672, 0.0006399154663085938\n", + "Epoch 74,Loss: 4.077157020568848, 0.01113271713256836, 0.0006384849548339844\n", + "Epoch 75,Loss: 4.083675384521484, 0.011570930480957031, 0.0006246566772460938\n", + "Epoch 76,Loss: 4.080029487609863, 0.010760307312011719, 0.0007207393646240234\n", + "Epoch 77,Loss: 4.075240135192871, 0.010936975479125977, 0.0005242824554443359\n", + "Epoch 78,Loss: 4.077165603637695, 0.010789632797241211, 0.0006308555603027344\n", + "Epoch 79,Loss: 4.082056522369385, 0.011003971099853516, 0.0006132125854492188\n", + "Epoch 80,Loss: 4.084698677062988, 0.011073112487792969, 0.0006220340728759766\n", + "Epoch 81,Loss: 4.074877738952637, 0.010831832885742188, 0.0006229877471923828\n", + "Epoch 82,Loss: 4.065998077392578, 0.01100921630859375, 0.0006418228149414062\n", + "Epoch 83,Loss: 4.072014808654785, 0.013705968856811523, 0.0005524158477783203\n", + "Epoch 84,Loss: 4.065586090087891, 0.012798547744750977, 0.0006158351898193359\n", + "Epoch 85,Loss: 4.070856094360352, 0.009680509567260742, 0.0005240440368652344\n", + "Epoch 86,Loss: 4.06818151473999, 0.010037660598754883, 0.0005896091461181641\n", + "Epoch 87,Loss: 4.072195053100586, 0.009814023971557617, 0.0005784034729003906\n", + "Epoch 88,Loss: 4.065300941467285, 0.013661861419677734, 0.0005695819854736328\n", + "Epoch 89,Loss: 4.059263706207275, 0.009290456771850586, 0.0005545616149902344\n", + "Epoch 90,Loss: 4.06481409072876, 0.020060062408447266, 0.009862422943115234\n", + "Epoch 91,Loss: 4.067580699920654, 0.023707866668701172, 0.0005640983581542969\n", + "Epoch 92,Loss: 4.062394142150879, 0.029192209243774414, 0.0005948543548583984\n", + "Epoch 93,Loss: 4.053855895996094, 0.024830102920532227, 0.0018415451049804688\n", + "Epoch 94,Loss: 4.056528091430664, 0.024169921875, 0.0018668174743652344\n", + "Epoch 95,Loss: 4.060898780822754, 0.026465415954589844, 0.0005826950073242188\n", + "Epoch 96,Loss: 4.057040214538574, 0.02184319496154785, 0.0005927085876464844\n", + "Epoch 97,Loss: 4.053262710571289, 0.036957502365112305, 0.00471806526184082\n", + "Epoch 98,Loss: 4.054243564605713, 0.015822172164916992, 0.0004324913024902344\n", + "Epoch 99,Loss: 4.056582450866699, 0.00728297233581543, 0.000400543212890625\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "from flax import linen as nn\n", + "from optax import adam\n", + "\n", + "# Define the optimizer\n", + "learning_rate = 1e-3 # Adjust as needed\n", + "tx = adam(learning_rate)\n", + "\n", + "# Initialize model parameters and optimizer state\n", + "prng = random.PRNGKey(1337)\n", + "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n", + "opt_state = tx.init(params)\n", + "\n", + "# Loss function (assuming you have a batch of data: xb, yb)\n", + "def loss_fn(params, xb, yb):\n", + " print(xb, yb)\n", + " logits, loss = model.apply(params, xb, yb)\n", + " return loss\n", + "\n", + "# Update function for a single training step\n", + "@jax.jit\n", + "def update_step(params, opt_state, xb, yb):\n", + " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n", + " updates, opt_state = tx.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, opt_state, loss\n", + "\n", + "# Training loop (example)\n", + "batch_size = 32\n", + "for steps in range(100):\n", + " t1 = time.time()\n", + " prng, subkey = random.split(prng)\n", + " xb, yb = get_batch('train', subkey)\n", + " t2 = time.time()\n", + " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n", + " print(f\"Epoch {steps},Loss: {loss}, {t2-t1}, {time.time()-t2}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EcVIDWAZEtjN", + "outputId": "6c640818-70c4-488e-9243-f717b7c4d1a7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " yof le' nongo myorrthavemandse;\n", + "SETHOMA of of rerdse ou wilavemerdss s meng l l hinn conotlint w'ds rr, re w'd, r, w'dst w'sthavese.\n", + "\n", + "KIOMPHOMonth gou o l yof thou thar sen lle\n", + "GLOMA corrd s w' re;\n", + "\n", + "Yorix's w's\n", + "Tdsilit erdss gof thariave rononond, ther o six's\n", + "Yore;\n", + "\n", + "\n", + "By t s\n", + "KIASo's yorrr, co s at wix' wiatheco oriaromavemen yofer, shine t.\n", + "Whe\n", + "Byore y se owrr gortharorint.\n", + "\n", + "Whiarenof ange t\n", + "Mo marilbifecou fecour wix's t\n", + "POMASoul ou y yotle.\n", + "POUCET:\n", + "Mof y s y meveve\n", + "SEYorthix' yore ser, ss w' tlard'seveve sinof arenour rof rilave you s wrthard,Z' t winou mard'se.\n", + "MASoiavecoreathinonomaverd, f\n" + ] + } + ], + "source": [ + "print(decode(model.generate(params, key, jnp.ones((1, 1), jnp.int32), 600)[0].tolist()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KwvFppxNv8Yg", + "outputId": "f763a487-4e9f-4ec3-e6c7-bcb7637e4a16" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Th le' nongo myorrthavemandse;\n", + "SETHOMA of of rerdse ou wilavemerdss s meng l l hinn conotlint w'ds rr, re w'd, r, w'dst w'sthavese.\n", + "\n", + "KIOMPHOMonth gou o l yof thou thar sen lle\n", + "GLOMA corrd s w' re;\n", + "\n", + "Yorix's w's\n", + "Tdsilit erdss gof thariave rononond, ther o six's\n", + "Yore;\n", + "\n", + "\n", + "By t s\n", + "KIASo's yorrr, co s at wix' wiatheco oriaromavemen yofer, shine t.\n", + "Whe\n", + "Byore y se owrr gortharorint.\n", + "\n", + "Whiarenof ange t\n", + "Mo marilbifecou fecour wix's t\n", + "POMASoul ou y yotle.\n", + "POUCET:\n", + "Mof y s y meveve\n", + "SEYorthix' yore ser, ss w' tlard'seveve sinof arenour rof rilave you s wrthard,Z' t winou mard'se.\n", + "MASoiavecoreathinonomaverd, f\n" + ] + } + ], + "source": [ + "print(decode(model.generate(params, key, jnp.zeros((1, 1), jnp.int32), 600)[0].tolist()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YddmXCuI3lGU", + "outputId": "ff56be4d-92a0-480e-b67d-1242234cfbef" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 336148, 278170, 604967, 333936, 240920, 353408, 986691,\n", + " 855404, 1063794, 383313, 110829, 521517, 196487, 814541,\n", + " 350061, 368494, 978407, 308935, 132008, 539608, 749456,\n", + " 375898, 515910, 961126, 283819, 830869, 218387, 546800,\n", + " 248480, 867346, 969181, 108388], dtype=int32)" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prng = jax.random.PRNGKey(1337)\n", + "prng, subkey = random.split(prng)\n", + "random.randint(prng, (batch_size,), 0, len(data) - block_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jxzTNhqF1eGI", + "outputId": "fc1a3b91-6146-464a-df78-489e27dfabab" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 213304, 267921, 260855, 747830, 768672, 544391, 187257,\n", + " 212117, 250143, 228211, 1090033, 982622, 177977, 700660,\n", + " 152160, 584462, 415188, 1065438, 161551, 98046, 305424,\n", + " 601585, 86080, 209041, 816187, 820975, 158124, 954360,\n", + " 227984, 282807, 174036, 229730], dtype=int32)" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prng, subkey = random.split(prng)\n", + "random.randint(prng, (batch_size,), 0, len(data) - block_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PN28-wZ94GI3", + "outputId": "612880a3-1246-4bac-fbee-74d4db687dcb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 39216 10190 43768 52429 59508 23294 90087 57102 1663 19305\n", + " 39405 11300 16166 110476 99338 98527 32447 17017 25 48318\n", + " 88658 19627 26638 97044 76053 46584 90512 13184 64779 20642\n", + " 66664 62260]\n" + ] + }, + { + "data": { + "text/plain": [ + "(Array([[ 1, 56, 39, ..., 23, 52, 53],\n", + " [46, 43, 52, ..., 39, 52, 1],\n", + " [58, 1, 58, ..., 1, 39, 58],\n", + " ...,\n", + " [52, 42, 1, ..., 50, 50, 1],\n", + " [ 0, 14, 21, ..., 58, 53, 1],\n", + " [57, 1, 40, ..., 58, 7, 54]], dtype=int32),\n", + " Array([[56, 39, 58, ..., 52, 53, 61],\n", + " [43, 52, 1, ..., 52, 1, 57],\n", + " [ 1, 58, 46, ..., 39, 58, 1],\n", + " ...,\n", + " [42, 1, 58, ..., 50, 1, 40],\n", + " [14, 21, 13, ..., 53, 1, 24],\n", + " [ 1, 40, 43, ..., 7, 54, 50]], dtype=int32))" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prng = jax.random.PRNGKey(1337)\n", + "\n", + "def get_batch(split, prng):\n", + " # generate a small batch of data of inputs x and targets y\n", + " data = train_data if split == 'train' else val_data\n", + " prng, subkey = random.split(prng)\n", + " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n", + " print(ix)\n", + " x = jnp.stack([data[i:i+block_size] for i in ix])\n", + " y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", + " x, y = device_put(x), device_put(y)\n", + " return x, y\n", + "get_batch('eval', prng)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yNc7jn8wy5Kx" + }, + "source": [ + "#### put together" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1sBa-IFGy90b", + "outputId": "beef1d3c-6e3a-4cb0-b6c2-2c69140a50ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0,Loss: 4.176700115203857\n", + "Epoch 1,Loss: 4.168147563934326\n", + "Epoch 2,Loss: 4.17624568939209\n", + "Epoch 3,Loss: 4.169358253479004\n", + "Epoch 4,Loss: 4.175686359405518\n", + "Epoch 5,Loss: 4.170393943786621\n", + "Epoch 6,Loss: 4.169347763061523\n", + "Epoch 7,Loss: 4.162715911865234\n", + "Epoch 8,Loss: 4.157233238220215\n", + "Epoch 9,Loss: 4.163628101348877\n", + "Epoch 10,Loss: 4.16217565536499\n", + "Epoch 11,Loss: 4.16450309753418\n", + "Epoch 12,Loss: 4.15981388092041\n", + "Epoch 13,Loss: 4.157349586486816\n", + "Epoch 14,Loss: 4.163050174713135\n", + "Epoch 15,Loss: 4.155461311340332\n", + "Epoch 16,Loss: 4.154058456420898\n", + "Epoch 17,Loss: 4.14748477935791\n", + "Epoch 18,Loss: 4.1531476974487305\n", + "Epoch 19,Loss: 4.15305757522583\n", + "Epoch 20,Loss: 4.148075580596924\n", + "Epoch 21,Loss: 4.1545729637146\n", + "Epoch 22,Loss: 4.145484924316406\n", + "Epoch 23,Loss: 4.146124839782715\n", + "Epoch 24,Loss: 4.152686595916748\n", + "Epoch 25,Loss: 4.148232460021973\n", + "Epoch 26,Loss: 4.143317222595215\n", + "Epoch 27,Loss: 4.135959625244141\n", + "Epoch 28,Loss: 4.133705139160156\n", + "Epoch 29,Loss: 4.144227981567383\n", + "Epoch 30,Loss: 4.137174606323242\n", + "Epoch 31,Loss: 4.137599945068359\n", + "Epoch 32,Loss: 4.133066177368164\n", + "Epoch 33,Loss: 4.140651702880859\n", + "Epoch 34,Loss: 4.132898330688477\n", + "Epoch 35,Loss: 4.129890441894531\n", + "Epoch 36,Loss: 4.128242492675781\n", + "Epoch 37,Loss: 4.1404337882995605\n", + "Epoch 38,Loss: 4.124664306640625\n", + "Epoch 39,Loss: 4.121993541717529\n", + "Epoch 40,Loss: 4.131916046142578\n", + "Epoch 41,Loss: 4.121825218200684\n", + "Epoch 42,Loss: 4.124566078186035\n", + "Epoch 43,Loss: 4.122013568878174\n", + "Epoch 44,Loss: 4.124298572540283\n", + "Epoch 45,Loss: 4.116214275360107\n", + "Epoch 46,Loss: 4.118075847625732\n", + "Epoch 47,Loss: 4.114686965942383\n", + "Epoch 48,Loss: 4.117861270904541\n", + "Epoch 49,Loss: 4.111809730529785\n", + "Epoch 50,Loss: 4.1100616455078125\n", + "Epoch 51,Loss: 4.1063408851623535\n", + "Epoch 52,Loss: 4.106322765350342\n", + "Epoch 53,Loss: 4.105854511260986\n", + "Epoch 54,Loss: 4.116119861602783\n", + "Epoch 55,Loss: 4.107294082641602\n", + "Epoch 56,Loss: 4.112220287322998\n", + "Epoch 57,Loss: 4.106285095214844\n", + "Epoch 58,Loss: 4.101660251617432\n", + "Epoch 59,Loss: 4.100406646728516\n", + "Epoch 60,Loss: 4.098540306091309\n", + "Epoch 61,Loss: 4.097071170806885\n", + "Epoch 62,Loss: 4.094281196594238\n", + "Epoch 63,Loss: 4.098484039306641\n", + "Epoch 64,Loss: 4.09868049621582\n", + "Epoch 65,Loss: 4.098651885986328\n", + "Epoch 66,Loss: 4.093440055847168\n", + "Epoch 67,Loss: 4.098696708679199\n", + "Epoch 68,Loss: 4.089413166046143\n", + "Epoch 69,Loss: 4.089560508728027\n", + "Epoch 70,Loss: 4.0868353843688965\n", + "Epoch 71,Loss: 4.088634014129639\n", + "Epoch 72,Loss: 4.084425926208496\n", + "Epoch 73,Loss: 4.085216522216797\n", + "Epoch 74,Loss: 4.077157020568848\n", + "Epoch 75,Loss: 4.083675384521484\n", + "Epoch 76,Loss: 4.080029487609863\n", + "Epoch 77,Loss: 4.075240135192871\n", + "Epoch 78,Loss: 4.077165603637695\n", + "Epoch 79,Loss: 4.082056522369385\n", + "Epoch 80,Loss: 4.084698677062988\n", + "Epoch 81,Loss: 4.074877738952637\n", + "Epoch 82,Loss: 4.065998077392578\n", + "Epoch 83,Loss: 4.072014808654785\n", + "Epoch 84,Loss: 4.065586090087891\n", + "Epoch 85,Loss: 4.070856094360352\n", + "Epoch 86,Loss: 4.06818151473999\n", + "Epoch 87,Loss: 4.072195053100586\n", + "Epoch 88,Loss: 4.065300941467285\n", + "Epoch 89,Loss: 4.059263706207275\n", + "Epoch 90,Loss: 4.06481409072876\n", + "Epoch 91,Loss: 4.067580699920654\n", + "Epoch 92,Loss: 4.062394142150879\n", + "Epoch 93,Loss: 4.053855895996094\n", + "Epoch 94,Loss: 4.056528091430664\n", + "Epoch 95,Loss: 4.060898780822754\n", + "Epoch 96,Loss: 4.057040214538574\n", + "Epoch 97,Loss: 4.053262710571289\n", + "Epoch 98,Loss: 4.054243564605713\n", + "Epoch 99,Loss: 4.056582450866699\n", + "Epoch 100,Loss: 4.046238899230957\n", + "Epoch 101,Loss: 4.0497331619262695\n", + "Epoch 102,Loss: 4.053548812866211\n", + "Epoch 103,Loss: 4.050479412078857\n", + "Epoch 104,Loss: 4.0534257888793945\n", + "Epoch 105,Loss: 4.04215145111084\n", + "Epoch 106,Loss: 4.047131538391113\n", + "Epoch 107,Loss: 4.044473648071289\n", + "Epoch 108,Loss: 4.040208339691162\n", + "Epoch 109,Loss: 4.046354293823242\n", + "Epoch 110,Loss: 4.03936767578125\n", + "Epoch 111,Loss: 4.042560577392578\n", + "Epoch 112,Loss: 4.035763740539551\n", + "Epoch 113,Loss: 4.035247325897217\n", + "Epoch 114,Loss: 4.0357985496521\n", + "Epoch 115,Loss: 4.029809474945068\n", + "Epoch 116,Loss: 4.034292221069336\n", + "Epoch 117,Loss: 4.03066349029541\n", + "Epoch 118,Loss: 4.036454200744629\n", + "Epoch 119,Loss: 4.026322364807129\n", + "Epoch 120,Loss: 4.024256706237793\n", + "Epoch 121,Loss: 4.022697448730469\n", + "Epoch 122,Loss: 4.022899627685547\n", + "Epoch 123,Loss: 4.025988578796387\n", + "Epoch 124,Loss: 4.022078514099121\n", + "Epoch 125,Loss: 4.02480411529541\n", + "Epoch 126,Loss: 4.015079498291016\n", + "Epoch 127,Loss: 4.017082214355469\n", + "Epoch 128,Loss: 4.021419048309326\n", + "Epoch 129,Loss: 4.012678146362305\n", + "Epoch 130,Loss: 4.012955188751221\n", + "Epoch 131,Loss: 4.008285045623779\n", + "Epoch 132,Loss: 4.011624813079834\n", + "Epoch 133,Loss: 4.008131504058838\n", + "Epoch 134,Loss: 4.015929222106934\n", + "Epoch 135,Loss: 4.0120649337768555\n", + "Epoch 136,Loss: 4.005334377288818\n", + "Epoch 137,Loss: 4.013724327087402\n", + "Epoch 138,Loss: 4.004214286804199\n", + "Epoch 139,Loss: 4.007635116577148\n", + "Epoch 140,Loss: 3.9990665912628174\n", + "Epoch 141,Loss: 4.004575252532959\n", + "Epoch 142,Loss: 3.9970438480377197\n", + "Epoch 143,Loss: 4.000253677368164\n", + "Epoch 144,Loss: 3.993255138397217\n", + "Epoch 145,Loss: 4.004382610321045\n", + "Epoch 146,Loss: 3.9942245483398438\n", + "Epoch 147,Loss: 3.996398687362671\n", + "Epoch 148,Loss: 3.9898180961608887\n", + "Epoch 149,Loss: 3.9965932369232178\n", + "Epoch 150,Loss: 3.988018751144409\n", + "Epoch 151,Loss: 3.9907279014587402\n", + "Epoch 152,Loss: 3.984743118286133\n", + "Epoch 153,Loss: 3.98683500289917\n", + "Epoch 154,Loss: 3.980862855911255\n", + "Epoch 155,Loss: 3.9790799617767334\n", + "Epoch 156,Loss: 3.9874353408813477\n", + "Epoch 157,Loss: 3.981985092163086\n", + "Epoch 158,Loss: 3.971966028213501\n", + "Epoch 159,Loss: 3.9836814403533936\n", + "Epoch 160,Loss: 3.9836783409118652\n", + "Epoch 161,Loss: 3.9711475372314453\n", + "Epoch 162,Loss: 3.9807727336883545\n", + "Epoch 163,Loss: 3.9743809700012207\n", + "Epoch 164,Loss: 3.9751386642456055\n", + "Epoch 165,Loss: 3.9707915782928467\n", + "Epoch 166,Loss: 3.966444492340088\n", + "Epoch 167,Loss: 3.972529649734497\n", + "Epoch 168,Loss: 3.970104217529297\n", + "Epoch 169,Loss: 3.9673550128936768\n", + "Epoch 170,Loss: 3.974461078643799\n", + "Epoch 171,Loss: 3.966738224029541\n", + "Epoch 172,Loss: 3.9683542251586914\n", + "Epoch 173,Loss: 3.9616072177886963\n", + "Epoch 174,Loss: 3.9639317989349365\n", + "Epoch 175,Loss: 3.9576151371002197\n", + "Epoch 176,Loss: 3.9661707878112793\n", + "Epoch 177,Loss: 3.9580233097076416\n", + "Epoch 178,Loss: 3.9642345905303955\n", + "Epoch 179,Loss: 3.968759775161743\n", + "Epoch 180,Loss: 3.9539124965667725\n", + "Epoch 181,Loss: 3.9508087635040283\n", + "Epoch 182,Loss: 3.964644432067871\n", + "Epoch 183,Loss: 3.951427459716797\n", + "Epoch 184,Loss: 3.951119899749756\n", + "Epoch 185,Loss: 3.953275203704834\n", + "Epoch 186,Loss: 3.949398994445801\n", + "Epoch 187,Loss: 3.9452455043792725\n", + "Epoch 188,Loss: 3.9493489265441895\n", + "Epoch 189,Loss: 3.9497108459472656\n", + "Epoch 190,Loss: 3.9341869354248047\n", + "Epoch 191,Loss: 3.94256591796875\n", + "Epoch 192,Loss: 3.939656972885132\n", + "Epoch 193,Loss: 3.9505527019500732\n", + "Epoch 194,Loss: 3.9368791580200195\n", + "Epoch 195,Loss: 3.93843936920166\n", + "Epoch 196,Loss: 3.9386792182922363\n", + "Epoch 197,Loss: 3.9371814727783203\n", + "Epoch 198,Loss: 3.9317574501037598\n", + "Epoch 199,Loss: 3.9382877349853516\n", + "Epoch 200,Loss: 3.931473970413208\n", + "Epoch 201,Loss: 3.9273834228515625\n", + "Epoch 202,Loss: 3.9355263710021973\n", + "Epoch 203,Loss: 3.9317989349365234\n", + "Epoch 204,Loss: 3.9241435527801514\n", + "Epoch 205,Loss: 3.927114486694336\n", + "Epoch 206,Loss: 3.9274179935455322\n", + "Epoch 207,Loss: 3.918044090270996\n", + "Epoch 208,Loss: 3.922666549682617\n", + "Epoch 209,Loss: 3.9170031547546387\n", + "Epoch 210,Loss: 3.9196343421936035\n", + "Epoch 211,Loss: 3.920548915863037\n", + "Epoch 212,Loss: 3.9187779426574707\n", + "Epoch 213,Loss: 3.914111614227295\n", + "Epoch 214,Loss: 3.915700674057007\n", + "Epoch 215,Loss: 3.912682056427002\n", + "Epoch 216,Loss: 3.913231372833252\n", + "Epoch 217,Loss: 3.909395694732666\n", + "Epoch 218,Loss: 3.9103047847747803\n", + "Epoch 219,Loss: 3.9027228355407715\n", + "Epoch 220,Loss: 3.9013731479644775\n", + "Epoch 221,Loss: 3.907486915588379\n", + "Epoch 222,Loss: 3.9011597633361816\n", + "Epoch 223,Loss: 3.911036252975464\n", + "Epoch 224,Loss: 3.919271469116211\n", + "Epoch 225,Loss: 3.8986048698425293\n", + "Epoch 226,Loss: 3.913900852203369\n", + "Epoch 227,Loss: 3.9005932807922363\n", + "Epoch 228,Loss: 3.894559144973755\n", + "Epoch 229,Loss: 3.9020142555236816\n", + "Epoch 230,Loss: 3.9012680053710938\n", + "Epoch 231,Loss: 3.89057993888855\n", + "Epoch 232,Loss: 3.901439666748047\n", + "Epoch 233,Loss: 3.8912413120269775\n", + "Epoch 234,Loss: 3.893927574157715\n", + "Epoch 235,Loss: 3.8888893127441406\n", + "Epoch 236,Loss: 3.891270637512207\n", + "Epoch 237,Loss: 3.8963348865509033\n", + "Epoch 238,Loss: 3.8835411071777344\n", + "Epoch 239,Loss: 3.8910601139068604\n", + "Epoch 240,Loss: 3.891777992248535\n", + "Epoch 241,Loss: 3.88920259475708\n", + "Epoch 242,Loss: 3.8773062229156494\n", + "Epoch 243,Loss: 3.8833348751068115\n", + "Epoch 244,Loss: 3.882423162460327\n", + "Epoch 245,Loss: 3.879281520843506\n", + "Epoch 246,Loss: 3.8650636672973633\n", + "Epoch 247,Loss: 3.88250732421875\n", + "Epoch 248,Loss: 3.86867094039917\n", + "Epoch 249,Loss: 3.87156343460083\n", + "Epoch 250,Loss: 3.8801052570343018\n", + "Epoch 251,Loss: 3.876453399658203\n", + "Epoch 252,Loss: 3.8785383701324463\n", + "Epoch 253,Loss: 3.8688297271728516\n", + "Epoch 254,Loss: 3.8692402839660645\n", + "Epoch 255,Loss: 3.865692615509033\n", + "Epoch 256,Loss: 3.8676300048828125\n", + "Epoch 257,Loss: 3.86896014213562\n", + "Epoch 258,Loss: 3.8605639934539795\n", + "Epoch 259,Loss: 3.8659508228302\n", + "Epoch 260,Loss: 3.870161533355713\n", + "Epoch 261,Loss: 3.862337350845337\n", + "Epoch 262,Loss: 3.8631129264831543\n", + "Epoch 263,Loss: 3.855680465698242\n", + "Epoch 264,Loss: 3.859565258026123\n", + "Epoch 265,Loss: 3.8612101078033447\n", + "Epoch 266,Loss: 3.8605751991271973\n", + "Epoch 267,Loss: 3.863940715789795\n", + "Epoch 268,Loss: 3.8460769653320312\n", + "Epoch 269,Loss: 3.8585381507873535\n", + "Epoch 270,Loss: 3.850302219390869\n", + "Epoch 271,Loss: 3.857239246368408\n", + "Epoch 272,Loss: 3.853991985321045\n", + "Epoch 273,Loss: 3.852595806121826\n", + "Epoch 274,Loss: 3.8553481101989746\n", + "Epoch 275,Loss: 3.845716714859009\n", + "Epoch 276,Loss: 3.846571922302246\n", + "Epoch 277,Loss: 3.851979970932007\n", + "Epoch 278,Loss: 3.8430185317993164\n", + "Epoch 279,Loss: 3.8463687896728516\n", + "Epoch 280,Loss: 3.839120864868164\n", + "Epoch 281,Loss: 3.8422348499298096\n", + "Epoch 282,Loss: 3.8394317626953125\n", + "Epoch 283,Loss: 3.8440678119659424\n", + "Epoch 284,Loss: 3.8503313064575195\n", + "Epoch 285,Loss: 3.838545799255371\n", + "Epoch 286,Loss: 3.829591751098633\n", + "Epoch 287,Loss: 3.8370213508605957\n", + "Epoch 288,Loss: 3.8436245918273926\n", + "Epoch 289,Loss: 3.8240108489990234\n", + "Epoch 290,Loss: 3.834958553314209\n", + "Epoch 291,Loss: 3.833307981491089\n", + "Epoch 292,Loss: 3.8331363201141357\n", + "Epoch 293,Loss: 3.838420867919922\n", + "Epoch 294,Loss: 3.8406732082366943\n", + "Epoch 295,Loss: 3.839576244354248\n", + "Epoch 296,Loss: 3.8232221603393555\n", + "Epoch 297,Loss: 3.8206682205200195\n", + "Epoch 298,Loss: 3.8219099044799805\n", + "Epoch 299,Loss: 3.834548234939575\n", + "Epoch 300,Loss: 3.8192687034606934\n", + "Epoch 301,Loss: 3.821470260620117\n", + "Epoch 302,Loss: 3.8147056102752686\n", + "Epoch 303,Loss: 3.8182125091552734\n", + "Epoch 304,Loss: 3.8224501609802246\n", + "Epoch 305,Loss: 3.8188157081604004\n", + "Epoch 306,Loss: 3.819611072540283\n", + "Epoch 307,Loss: 3.81486439704895\n", + "Epoch 308,Loss: 3.812788486480713\n", + "Epoch 309,Loss: 3.799957036972046\n", + "Epoch 310,Loss: 3.8292648792266846\n", + "Epoch 311,Loss: 3.8186426162719727\n", + "Epoch 312,Loss: 3.8149666786193848\n", + "Epoch 313,Loss: 3.8050060272216797\n", + "Epoch 314,Loss: 3.801513671875\n", + "Epoch 315,Loss: 3.7967958450317383\n", + "Epoch 316,Loss: 3.7983767986297607\n", + "Epoch 317,Loss: 3.794619560241699\n", + "Epoch 318,Loss: 3.812180519104004\n", + "Epoch 319,Loss: 3.810817241668701\n", + "Epoch 320,Loss: 3.790243148803711\n", + "Epoch 321,Loss: 3.804206132888794\n", + "Epoch 322,Loss: 3.7979207038879395\n", + "Epoch 323,Loss: 3.790651559829712\n", + "Epoch 324,Loss: 3.7874720096588135\n", + "Epoch 325,Loss: 3.791238784790039\n", + "Epoch 326,Loss: 3.794536828994751\n", + "Epoch 327,Loss: 3.784796714782715\n", + "Epoch 328,Loss: 3.790224552154541\n", + "Epoch 329,Loss: 3.782738208770752\n", + "Epoch 330,Loss: 3.7941744327545166\n", + "Epoch 331,Loss: 3.786827564239502\n", + "Epoch 332,Loss: 3.7888121604919434\n", + "Epoch 333,Loss: 3.78334641456604\n", + "Epoch 334,Loss: 3.7930338382720947\n", + "Epoch 335,Loss: 3.787445545196533\n", + "Epoch 336,Loss: 3.7902698516845703\n", + "Epoch 337,Loss: 3.7786812782287598\n", + "Epoch 338,Loss: 3.7811098098754883\n", + "Epoch 339,Loss: 3.76914644241333\n", + "Epoch 340,Loss: 3.789398670196533\n", + "Epoch 341,Loss: 3.779127359390259\n", + "Epoch 342,Loss: 3.770552635192871\n", + "Epoch 343,Loss: 3.766423463821411\n", + "Epoch 344,Loss: 3.7771596908569336\n", + "Epoch 345,Loss: 3.763667106628418\n", + "Epoch 346,Loss: 3.787656784057617\n", + "Epoch 347,Loss: 3.7644731998443604\n", + "Epoch 348,Loss: 3.760593891143799\n", + "Epoch 349,Loss: 3.7653212547302246\n", + "Epoch 350,Loss: 3.7683029174804688\n", + "Epoch 351,Loss: 3.769117832183838\n", + "Epoch 352,Loss: 3.760528087615967\n", + "Epoch 353,Loss: 3.771134614944458\n", + "Epoch 354,Loss: 3.766356945037842\n", + "Epoch 355,Loss: 3.7620558738708496\n", + "Epoch 356,Loss: 3.763491153717041\n", + "Epoch 357,Loss: 3.7664551734924316\n", + "Epoch 358,Loss: 3.7605996131896973\n", + "Epoch 359,Loss: 3.763265609741211\n", + "Epoch 360,Loss: 3.745481252670288\n", + "Epoch 361,Loss: 3.7506918907165527\n", + "Epoch 362,Loss: 3.7509055137634277\n", + "Epoch 363,Loss: 3.758342742919922\n", + "Epoch 364,Loss: 3.7418174743652344\n", + "Epoch 365,Loss: 3.7498393058776855\n", + "Epoch 366,Loss: 3.754758358001709\n", + "Epoch 367,Loss: 3.7491402626037598\n", + "Epoch 368,Loss: 3.7493896484375\n", + "Epoch 369,Loss: 3.7415528297424316\n", + "Epoch 370,Loss: 3.747483968734741\n", + "Epoch 371,Loss: 3.756594657897949\n", + "Epoch 372,Loss: 3.7347874641418457\n", + "Epoch 373,Loss: 3.73516845703125\n", + "Epoch 374,Loss: 3.7368345260620117\n", + "Epoch 375,Loss: 3.752837896347046\n", + "Epoch 376,Loss: 3.7481985092163086\n", + "Epoch 377,Loss: 3.743605852127075\n", + "Epoch 378,Loss: 3.7425527572631836\n", + "Epoch 379,Loss: 3.743187189102173\n", + "Epoch 380,Loss: 3.7335641384124756\n", + "Epoch 381,Loss: 3.7418200969696045\n", + "Epoch 382,Loss: 3.7498106956481934\n", + "Epoch 383,Loss: 3.730151414871216\n", + "Epoch 384,Loss: 3.735663414001465\n", + "Epoch 385,Loss: 3.7256903648376465\n", + "Epoch 386,Loss: 3.7350902557373047\n", + "Epoch 387,Loss: 3.725283622741699\n", + "Epoch 388,Loss: 3.717357635498047\n", + "Epoch 389,Loss: 3.729820728302002\n", + "Epoch 390,Loss: 3.728365182876587\n", + "Epoch 391,Loss: 3.72127628326416\n", + "Epoch 392,Loss: 3.730597496032715\n", + "Epoch 393,Loss: 3.7206125259399414\n", + "Epoch 394,Loss: 3.7320313453674316\n", + "Epoch 395,Loss: 3.726435422897339\n", + "Epoch 396,Loss: 3.7120227813720703\n", + "Epoch 397,Loss: 3.721881866455078\n", + "Epoch 398,Loss: 3.7123489379882812\n", + "Epoch 399,Loss: 3.7059667110443115\n", + "Epoch 400,Loss: 3.715552806854248\n", + "Epoch 401,Loss: 3.7071962356567383\n", + "Epoch 402,Loss: 3.717360019683838\n", + "Epoch 403,Loss: 3.722916603088379\n", + "Epoch 404,Loss: 3.7143983840942383\n", + "Epoch 405,Loss: 3.7218456268310547\n", + "Epoch 406,Loss: 3.70701265335083\n", + "Epoch 407,Loss: 3.707385540008545\n", + "Epoch 408,Loss: 3.7069830894470215\n", + "Epoch 409,Loss: 3.7150468826293945\n", + "Epoch 410,Loss: 3.709181785583496\n", + "Epoch 411,Loss: 3.7069830894470215\n", + "Epoch 412,Loss: 3.7000954151153564\n", + "Epoch 413,Loss: 3.6996006965637207\n", + "Epoch 414,Loss: 3.709993839263916\n", + "Epoch 415,Loss: 3.6976137161254883\n", + "Epoch 416,Loss: 3.692563772201538\n", + "Epoch 417,Loss: 3.6903860569000244\n", + "Epoch 418,Loss: 3.6960387229919434\n", + "Epoch 419,Loss: 3.684077262878418\n", + "Epoch 420,Loss: 3.6853389739990234\n", + "Epoch 421,Loss: 3.6937026977539062\n", + "Epoch 422,Loss: 3.692162036895752\n", + "Epoch 423,Loss: 3.697707414627075\n", + "Epoch 424,Loss: 3.683466911315918\n", + "Epoch 425,Loss: 3.703868865966797\n", + "Epoch 426,Loss: 3.700802803039551\n", + "Epoch 427,Loss: 3.6789040565490723\n", + "Epoch 428,Loss: 3.6846561431884766\n", + "Epoch 429,Loss: 3.696505308151245\n", + "Epoch 430,Loss: 3.6857666969299316\n", + "Epoch 431,Loss: 3.700955390930176\n", + "Epoch 432,Loss: 3.6856637001037598\n", + "Epoch 433,Loss: 3.687821388244629\n", + "Epoch 434,Loss: 3.666987657546997\n", + "Epoch 435,Loss: 3.6818313598632812\n", + "Epoch 436,Loss: 3.6712565422058105\n", + "Epoch 437,Loss: 3.6741445064544678\n", + "Epoch 438,Loss: 3.6812634468078613\n", + "Epoch 439,Loss: 3.682091236114502\n", + "Epoch 440,Loss: 3.682633876800537\n", + "Epoch 441,Loss: 3.674262523651123\n", + "Epoch 442,Loss: 3.6762471199035645\n", + "Epoch 443,Loss: 3.6800146102905273\n", + "Epoch 444,Loss: 3.6821203231811523\n", + "Epoch 445,Loss: 3.679013252258301\n", + "Epoch 446,Loss: 3.663221836090088\n", + "Epoch 447,Loss: 3.674647331237793\n", + "Epoch 448,Loss: 3.674830436706543\n", + "Epoch 449,Loss: 3.6704533100128174\n", + "Epoch 450,Loss: 3.653430938720703\n", + "Epoch 451,Loss: 3.676170587539673\n", + "Epoch 452,Loss: 3.665440559387207\n", + "Epoch 453,Loss: 3.662698745727539\n", + "Epoch 454,Loss: 3.664393186569214\n", + "Epoch 455,Loss: 3.663473129272461\n", + "Epoch 456,Loss: 3.652484178543091\n", + "Epoch 457,Loss: 3.657601833343506\n", + "Epoch 458,Loss: 3.659010410308838\n", + "Epoch 459,Loss: 3.669285297393799\n", + "Epoch 460,Loss: 3.65285587310791\n", + "Epoch 461,Loss: 3.661116600036621\n", + "Epoch 462,Loss: 3.651196002960205\n", + "Epoch 463,Loss: 3.6498472690582275\n", + "Epoch 464,Loss: 3.6548192501068115\n", + "Epoch 465,Loss: 3.6427948474884033\n", + "Epoch 466,Loss: 3.640430450439453\n", + "Epoch 467,Loss: 3.643303632736206\n", + "Epoch 468,Loss: 3.632857322692871\n", + "Epoch 469,Loss: 3.6472060680389404\n", + "Epoch 470,Loss: 3.6393682956695557\n", + "Epoch 471,Loss: 3.643303155899048\n", + "Epoch 472,Loss: 3.635702610015869\n", + "Epoch 473,Loss: 3.6435651779174805\n", + "Epoch 474,Loss: 3.6412034034729004\n", + "Epoch 475,Loss: 3.6328794956207275\n", + "Epoch 476,Loss: 3.6483945846557617\n", + "Epoch 477,Loss: 3.6373813152313232\n", + "Epoch 478,Loss: 3.6451635360717773\n", + "Epoch 479,Loss: 3.6504099369049072\n", + "Epoch 480,Loss: 3.637803077697754\n", + "Epoch 481,Loss: 3.6518633365631104\n", + "Epoch 482,Loss: 3.6186020374298096\n", + "Epoch 483,Loss: 3.6532328128814697\n", + "Epoch 484,Loss: 3.627650499343872\n", + "Epoch 485,Loss: 3.648397922515869\n", + "Epoch 486,Loss: 3.6308512687683105\n", + "Epoch 487,Loss: 3.622593879699707\n", + "Epoch 488,Loss: 3.624168872833252\n", + "Epoch 489,Loss: 3.625032901763916\n", + "Epoch 490,Loss: 3.614368438720703\n", + "Epoch 491,Loss: 3.629826545715332\n", + "Epoch 492,Loss: 3.6387391090393066\n", + "Epoch 493,Loss: 3.6167736053466797\n", + "Epoch 494,Loss: 3.6361451148986816\n", + "Epoch 495,Loss: 3.6265835762023926\n", + "Epoch 496,Loss: 3.6187691688537598\n", + "Epoch 497,Loss: 3.6252095699310303\n", + "Epoch 498,Loss: 3.621762275695801\n", + "Epoch 499,Loss: 3.6142048835754395\n", + "TIME jax 5.798891305923462\n" + ] + } + ], + "source": [ + "prng = jax.random.PRNGKey(1337)\n", + "\n", + "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", + "with open('input.txt', 'r', encoding='utf-8') as f:\n", + " text = f.read()\n", + "\n", + "# hyperparameters\n", + "batch_size = 16 # how many independent sequences will we process in parallel?\n", + "block_size = 32 # what is the maximum context length for predictions?\n", + "max_iters = 500\n", + "eval_interval = 100\n", + "learning_rate = 1e-3\n", + "eval_iters = 10\n", + "n_embd = 64\n", + "n_head = 4\n", + "n_layer = 4\n", + "dropout = 0.0\n", + "\n", + "\n", + "# here are all the unique characters that occur in this text\n", + "chars = sorted(list(set(text)))\n", + "vocab_size = len(chars)\n", + "# create a mapping from characters to integers\n", + "stoi = { ch:i for i,ch in enumerate(chars) }\n", + "itos = { i:ch for i,ch in enumerate(chars) }\n", + "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", + "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", + "\n", + "# Train and test splits\n", + "data = jnp.array(encode(text), dtype=jnp.int32)\n", + "data = device_put(data)\n", + "n = int(0.9*len(data)) # first 90% will be train, rest val\n", + "train_data = data[:n]\n", + "val_data = data[n:]\n", + "\n", + "def get_batch(split, subkey):\n", + " # generate a small batch of data of inputs x and targets y\n", + " data = train_data if split == 'train' else val_data\n", + " t1 = time.time()\n", + " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n", + " t2 = time.time()\n", + "\n", + "\n", + " # x = jnp.stack([data[i:i+block_size] for i in ix])\n", + " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", + " def slice_data(i):\n", + " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n", + "\n", + " x = jax.vmap(slice_data)(ix)\n", + " y = jax.vmap(slice_data)(ix+1)\n", + " x, y = device_put(x), device_put(y)\n", + " # print('TIME rand idx', t2-t1)\n", + " # print('TIME rand idx fetch', time.time()-t2)\n", + " return x, y\n", + "\n", + "def estimate_loss(params, prng):\n", + " out = {}\n", + " for split in ['train', 'val']:\n", + " losses = jnp.zeros(eval_iters)\n", + " for k in range(eval_iters):\n", + " prng, subkey = random.split(prng)\n", + " X, Y = get_batch(split, subkey)\n", + " logits, loss = model.apply(params, X, Y)\n", + " losses = losses.at[k].set(loss)\n", + " out[split] = losses.mean()\n", + " return out\n", + "\n", + "class BigramLanguageModel(nn.Module):\n", + " vocab_size: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, idx, targets=None):\n", + " # Token embedding table\n", + " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n", + " logits = embedding_table(idx) # (B,T,C)\n", + "\n", + " if targets is None:\n", + " loss = None\n", + " else:\n", + " B, T, C = logits.shape\n", + " logits = logits.reshape(B*T, C)\n", + " targets = targets.reshape(B*T)\n", + " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n", + "\n", + " return logits, loss\n", + "\n", + " def generate(self, params, key, idx, max_new_tokens):\n", + " for _ in range(max_new_tokens):\n", + " logits, _ = self.apply(params, idx)\n", + " logits = logits[:, -1, :] # (B, C)\n", + " # probs = jax.nn.softmax(logits, axis=-1) # (B, C)\n", + " key, subkey = random.split(key)\n", + " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", + " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", + " return idx\n", + "\n", + "# Define the optimizer\n", + "tx = adam(learning_rate)\n", + "\n", + "# Initialize model parameters and optimizer state\n", + "model = BigramLanguageModel(vocab_size)\n", + "\n", + "# Initialize model parameters and optimizer\n", + "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n", + "opt_state = tx.init(params)\n", + "\n", + "# Loss function (assuming you have a batch of data: xb, yb)\n", + "def loss_fn(params, xb, yb):\n", + " logits, loss = model.apply(params, xb, yb)\n", + " return loss\n", + "\n", + "# Update function for a single training step\n", + "@jax.jit\n", + "def update_step(params, opt_state, xb, yb):\n", + " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n", + " updates, opt_state = tx.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, opt_state, loss\n", + "\n", + "# Training loop (example)\n", + "batch_size = 32\n", + "t = time.time()\n", + "for steps in range(max_iters):\n", + " # every once in a while evaluate the loss on train and val sets\n", + " t1 = time.time()\n", + " prng, subkey = random.split(prng)\n", + " xb, yb = get_batch('train', subkey)\n", + " t2 = time.time()\n", + " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n", + " print(f\"Epoch {steps},Loss: {loss}\")\n", + "print('TIME jax', time.time()-t)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U96Hrwpz_QUx" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XinV8nmAnmKN" + }, + "source": [ + "## The mathematical trick in self-attention" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tukiH-NbRBhA", + "outputId": "a7fa4719-1653-46e4-f357-b1be1b395a20" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1. 0. 0.]\n", + " [1. 1. 0.]\n", + " [1. 1. 1.]]\n", + "[[1. 0. 0. ]\n", + " [0.5 0.5 0. ]\n", + " [0.33333334 0.33333334 0.33333334]]\n", + "a=\n", + "[[1. 0. 0. ]\n", + " [0.5 0.5 0. ]\n", + " [0.33333334 0.33333334 0.33333334]]\n", + "--\n", + "b=\n", + "[[2 3]\n", + " [9 9]\n", + " [4 6]]\n", + "--\n", + "c=\n", + "[[2. 3. ]\n", + " [5.5 6. ]\n", + " [5. 6. ]]\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "\n", + "# set the random key and seed\n", + "key = random.PRNGKey(42)\n", + "\n", + "# create a lower triangular matrix\n", + "a = jnp.tril(jnp.ones((3, 3)))\n", + "print(a)\n", + "a = a / jnp.sum(a, axis=1, keepdims=True)\n", + "print(a)\n", + "# create a random matrix\n", + "b = random.randint(key, (3, 2), 0, 10, dtype=jnp.int32)\n", + "\n", + "# perform the matrix multiplication\n", + "c = jnp.matmul(a, b)\n", + "\n", + "# print the matrices and the result\n", + "print(\"a=\")\n", + "print(a)\n", + "print(\"--\")\n", + "print(\"b=\")\n", + "print(b)\n", + "print(\"--\")\n", + "print(\"c=\")\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Hs_E24uRE8kr", + "outputId": "7006d8dd-9efc-404f-9a9c-bcce67451126" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4, 8, 2)\n" + ] + } + ], + "source": [ + "# Set the random key and seed\n", + "key = random.PRNGKey(1337)\n", + "\n", + "# Define the batch size, sequence length, and number of channels\n", + "B, T, C = 4, 8, 2\n", + "\n", + "# Generate a random tensor with the specified shape\n", + "x = random.normal(key, (B, T, C))\n", + "\n", + "# Print the shape of the tensor\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "86NuXX0fn7ps", + "outputId": "1d2200f6-2fe5-4a54-df32-09a103ec4dfc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[ 1.3654243 -1.3698599 ]\n", + " [ 1.8501079 -1.18029 ]\n", + " [ 1.2796469 -1.1314313 ]\n", + " [ 1.2234701 -0.95123225]\n", + " [ 1.0816249 -0.8464665 ]\n", + " [ 0.9756752 -0.40811488]\n", + " [ 1.089724 -0.345884 ]\n", + " [ 1.0498141 -0.29275966]]\n", + "\n", + " [[-0.02413951 1.4920624 ]\n", + " [-0.37343726 1.7603257 ]\n", + " [-0.48404503 1.4237347 ]\n", + " [-0.18112068 1.2025388 ]\n", + " [-0.28402337 0.9092719 ]\n", + " [-0.21978617 0.8809384 ]\n", + " [-0.3614355 1.045876 ]\n", + " [-0.27126977 0.8439261 ]]\n", + "\n", + " [[ 0.67737067 0.45489657]\n", + " [-0.02651259 0.6725235 ]\n", + " [ 0.3897322 0.32223004]\n", + " [-0.05313486 0.49777415]\n", + " [-0.09553531 0.17514856]\n", + " [ 0.09583326 0.24674284]\n", + " [ 0.17541157 0.14904368]\n", + " [ 0.36874425 -0.14085238]]\n", + "\n", + " [[-1.1729926 -1.0436211 ]\n", + " [-0.76886547 -0.83241093]\n", + " [-0.1617876 -0.5863479 ]\n", + " [-0.23202893 -0.7893201 ]\n", + " [-0.09417699 -0.38362506]\n", + " [-0.03415637 -0.40888947]\n", + " [-0.10969827 -0.27693108]\n", + " [ 0.08133804 -0.42021018]]]\n" + ] + } + ], + "source": [ + "# We want x[b,t] = mean_{i<=t} x[b,i]\n", + "# Initialize xbow tensor\n", + "xbow = jnp.zeros((B, T, C))\n", + "\n", + "# Loop over batch and time dimensions\n", + "for b in range(B):\n", + " for t in range(T):\n", + " xprev = x[b, :t + 1] # (t, C)\n", + " mean_xprev = jnp.mean(xprev, axis=0)\n", + " xbow = xbow.at[b, t].set(mean_xprev)\n", + "\n", + "# Print the result\n", + "print(xbow)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yhdOAd6-wXkZ", + "outputId": "734cc65c-397a-4625-c5a9-3fe1ec68bd72" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "# version 2: using matrix multiply for a weighted aggregation\n", + "\n", + "# Create a lower triangular matrix with normalized rows\n", + "wei = jnp.tril(jnp.ones((T, T)))\n", + "wei = wei / jnp.sum(wei, axis=1, keepdims=True)\n", + "\n", + "# Perform the weighted aggregation using matrix multiplication\n", + "xbow2 = jnp.matmul(wei, x) # (T, T) @ (B, T, C) ----> (B, T, C)\n", + "\n", + "# Check if the results are close\n", + "all_close = jnp.allclose(xbow, xbow2)\n", + "print(all_close)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lUBprV0th_aR", + "outputId": "2d90ecd7-c776-4676-93f8-1bb5a17df1fa" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[1. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. ],\n", + " [0.5 , 0.5 , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. ],\n", + " [0.33333334, 0.33333334, 0.33333334, 0. , 0. ,\n", + " 0. , 0. , 0. ],\n", + " [0.25 , 0.25 , 0.25 , 0.25 , 0. ,\n", + " 0. , 0. , 0. ],\n", + " [0.2 , 0.2 , 0.2 , 0.2 , 0.2 ,\n", + " 0. , 0. , 0. ],\n", + " [0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n", + " 0.16666667, 0. , 0. ],\n", + " [0.14285715, 0.14285715, 0.14285715, 0.14285715, 0.14285715,\n", + " 0.14285715, 0.14285715, 0. ],\n", + " [0.125 , 0.125 , 0.125 , 0.125 , 0.125 ,\n", + " 0.125 , 0.125 , 0.125 ]], dtype=float32)" + ] + }, + "execution_count": 204, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wei = jnp.tril(jnp.ones((T, T)))\n", + "wei = wei / jnp.sum(wei, axis=1, keepdims=True)\n", + "wei" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wOURrfG-ysoL", + "outputId": "76b13814-fecb-4019-c6a4-931b9a335a98" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "# version 3: use Softmax\n", + "from jax.nn import softmax\n", + "\n", + "# Create a lower triangular mask\n", + "tril = jnp.tril(jnp.ones((T, T)))\n", + "\n", + "# Create a mask filled with negative infinity where tril is 0\n", + "wei = jnp.zeros((T, T))\n", + "wei = jnp.where(tril == 0, -jnp.inf, wei)\n", + "\n", + "# Apply softmax along the last dimension\n", + "wei = jnn.softmax(wei, axis=-1)\n", + "\n", + "# Perform the weighted aggregation using matrix multiplication\n", + "xbow3 = jnp.matmul(wei, x)\n", + "\n", + "# Check if the results are close\n", + "all_close = jnp.allclose(xbow, xbow3)\n", + "print(all_close)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EDarxEWIRMKq", + "outputId": "b38c45a4-7867-4d6a-e429-7c997cf38c8e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4, 8, 32) (32, 16)\n", + "(4, 8, 16)\n", + "(4, 8, 16)\n", + "(4, 8, 8) (8, 8)\n", + "(4, 8, 16)\n" + ] + } + ], + "source": [ + "# Set the random key and seed\n", + "rng = random.PRNGKey(1337)\n", + "\n", + "# Create a random input tensor\n", + "B,T,C = 4,8,32 # batch, time, channels\n", + "x = random.normal(rng, (B, T, C))\n", + "\n", + "# Define the head size for self-attention\n", + "head_size = 16\n", + "\n", + "# Define the linear layers for key, query, and value\n", + "key = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())\n", + "query = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())\n", + "value = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())\n", + "\n", + "# Compute the key, query, and value projections\n", + "k_variables = key.init(rng, random.normal(rng, (B, T, C)))\n", + "print(x.shape, k_variables['params']['kernel'].shape)\n", + "\n", + "k = key.apply(k_variables, x) # (B, T, 16)\n", + "print(k.shape)\n", + "\n", + "q_variables = query.init(rng, x)\n", + "q = query.apply(q_variables, x) # (B, T, 16)\n", + "print(q.shape)\n", + "\n", + "# Compute the attention weights\n", + "wei = jnp.matmul(q, k.transpose((0, 2, 1))) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n", + "# Create a lower triangular mask\n", + "tril = jnp.tril(jnp.ones((T, T)))\n", + "\n", + "# Apply the mask and then softmax along the last dimension\n", + "wei = jnp.where(tril == 0, -jnp.inf, wei)\n", + "wei = nn.softmax(wei, axis=-1)\n", + "print(wei.shape, wei[0].shape)\n", + "\n", + "# Compute the output using the attention weights and the value projection\n", + "v_variables = value.init(rng, x)\n", + "v = value.apply(v_variables, x)\n", + "out = jnp.matmul(wei, v)\n", + "\n", + "# Print the shape of the output\n", + "print(out.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vT1hdtzXCjgL", + "outputId": "6cac0dca-3635-4200-d7a5-e7126b2e54dd" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", + " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", + " [9.7633256e-03, 9.9023670e-01, 0.0000000e+00, 0.0000000e+00,\n", + " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", + " [3.8466230e-01, 3.2870498e-01, 2.8663272e-01, 0.0000000e+00,\n", + " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", + " [1.0245054e-01, 8.3960545e-01, 7.0171729e-03, 5.0926875e-02,\n", + " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", + " [2.0860654e-01, 5.2596384e-01, 5.7231013e-02, 1.3179399e-01,\n", + " 7.6404594e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", + " [4.4360008e-02, 2.5677064e-01, 1.2710307e-02, 9.1593184e-02,\n", + " 4.3377329e-02, 5.5118859e-01, 0.0000000e+00, 0.0000000e+00],\n", + " [1.7478675e-02, 8.2577473e-01, 1.8446613e-04, 7.8453608e-03,\n", + " 9.8771520e-04, 1.8481736e-03, 1.4588076e-01, 0.0000000e+00],\n", + " [8.8755831e-02, 4.8783898e-01, 1.2247147e-02, 6.4910784e-02,\n", + " 2.6166208e-02, 3.7589636e-02, 2.3768106e-01, 4.4810358e-02]], dtype=float32),\n", + " (4, 8, 8))" + ] + }, + "execution_count": 234, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wei[0], wei.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M5CvobiQ0pLr" + }, + "source": [ + "Notes:\n", + "- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n", + "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n", + "- Each example across batch dimension is of course processed completely independently and never \"talk\" to each other\n", + "- In an \"encoder\" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n", + "- \"self-attention\" just means that the keys and values are produced from the same source as queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)\n", + "- \"Scaled\" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4SNbLq5z3oBw", + "outputId": "4cb77edf-da2b-46a6-ae68-5d06e943dc4b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4, 8, 8)\n" + ] + } + ], + "source": [ + "\n", + "# Create random key\n", + "key = random.PRNGKey(1337)\n", + "\n", + "# Create random query and key tensors\n", + "q = random.normal(key, (B, T, head_size))\n", + "k = random.normal(key, (B, T, head_size))\n", + "\n", + "# Compute the dot product between q and k and scale by head_size^-0.5\n", + "wei = jnp.matmul(q, k.transpose((0, 2, 1))) * head_size**-0.5\n", + "\n", + "# Print the shape of the output\n", + "print(wei.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Nl6I9n9IRTSo", + "outputId": "fc004765-f9fa-4582-f97a-e43bfa63a568" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.91944367, dtype=float32)" + ] + }, + "execution_count": 247, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "k.var()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T1tQx7oeRvtc", + "outputId": "145aa77c-a7e9-47f1-c379-cafe5076de71" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.91944367, dtype=float32)" + ] + }, + "execution_count": 248, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.var()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MLb_odHU3iKM", + "outputId": "bddd798c-9177-4829-e152-24e4466414ca" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(2.3990507, dtype=float32)" + ] + }, + "execution_count": 249, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wei.var()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JB82yzt44REI", + "outputId": "5ee64a9a-4a7b-4f8d-b0aa-96fece70b224" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.1924978 , 0.14260589, 0.23511736, 0.14260589, 0.287173 ], dtype=float32)" + ] + }, + "execution_count": 251, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Mpt8569BB9_f", + "outputId": "23bfc4ef-ab85-4e0e-d561-71a239f8abbe" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.03260834, 0.00295816, 0.16151018, 0.00295816, 0.79996514], dtype=float32)" + ] + }, + "execution_count": 252, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5])*8) # gets too peaky, converges to one-hot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2Num7sX9CKOH", + "outputId": "b2509df4-cca4-4558-97c8-020ca5e4a50e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(32, 100)\n" + ] + } + ], + "source": [ + "class LayerNorm1d: # (used to be BatchNorm1d)\n", + " def __init__(self, dim, eps=1e-5, momentum=0.1):\n", + " self.eps = eps\n", + " self.gamma = jnp.ones((dim,))\n", + " self.beta = jnp.zeros((dim,))\n", + "\n", + " def __call__(self, x):\n", + " # Calculate the forward pass\n", + " xmean = x.mean(1, keepdims=True) # batch mean\n", + " xvar = x.var(1, keepdims=True) # batch variance\n", + " xhat = (x - xmean) / jnp.sqrt(xvar + self.eps) # normalize to unit variance\n", + " self.out = self.gamma * xhat + self.beta\n", + " return self.out\n", + "\n", + " def parameters(self):\n", + " return [self.gamma, self.beta]\n", + "\n", + "# Set the random key and seed\n", + "rng = random.PRNGKey(1337)\n", + "\n", + "# Create an instance of LayerNorm1d\n", + "module = LayerNorm1d(100)\n", + "\n", + "# Create a random input tensor\n", + "x = random.normal(rng, (32, 100)) # batch size 32 of 100-dimensional vectors\n", + "\n", + "# Apply the LayerNorm1d module to the input tensor\n", + "x = module(x)\n", + "\n", + "# Print the shape of the output tensor\n", + "print(x.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "633T2cmnW1uk", + "outputId": "b635b977-7e18-44fc-dc0b-7008ab9d27b2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array(-0.03567989, dtype=float32), Array(0.9839538, dtype=float32))" + ] + }, + "execution_count": 256, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LN9cK9BoXCYb", + "outputId": "398d3168-1251-4d6b-87a8-d3a2dfa49dc3" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array(-7.152557e-09, dtype=float32), Array(0.99999595, dtype=float32))" + ] + }, + "execution_count": 257, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dRJH6wM_XFfU" + }, + "outputs": [], + "source": [ + "# French to English translation example:\n", + "\n", + "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n", + "# les réseaux de neurones sont géniaux! neural networks are awesome!\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZcvKeBXoZFOY" + }, + "source": [ + "### Full finished code, for reference\n", + "\n", + "You may want to refer directly to the git repo instead though." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 360 + }, + "id": "hoelkOrFY8bN", + "outputId": "ab3e2d01-28e4-4abc-c4f4-9e00b0f50b6c" + }, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "module 'flax.linen' has no attribute 'Embedding'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBigramLanguageModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;31m# create a JAX optimizer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;31m# each token directly reads off the logits for the next token from a lookup table\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoken_embedding_table\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_embd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposition_embedding_table\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mblock_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_embd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mBlock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_embd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_head\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_head\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_layer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: module 'flax.linen' has no attribute 'Embedding'" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit, vmap, random\n", + "\n", + "# hyperparameters\n", + "batch_size = 16 # how many independent sequences will we process in parallel?\n", + "block_size = 32 # what is the maximum context length for predictions?\n", + "max_iters = 5000\n", + "eval_interval = 100\n", + "learning_rate = 1e-3\n", + "eval_iters = 200\n", + "n_embd = 64\n", + "n_head = 4\n", + "n_layer = 4\n", + "dropout = 0.0\n", + "# ------------\n", + "\n", + "\n", + "# random number generator\n", + "rng = random.PRNGKey(1337)\n", + "\n", + "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", + "with open('input.txt', 'r', encoding='utf-8') as f:\n", + " text = f.read()\n", + "\n", + "# here are all the unique characters that occur in this text\n", + "chars = sorted(list(set(text)))\n", + "vocab_size = len(chars)\n", + "# create a mapping from characters to integers\n", + "stoi = { ch:i for i,ch in enumerate(chars) }\n", + "itos = { i:ch for i,ch in enumerate(chars) }\n", + "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", + "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", + "\n", + "# Train and test splits\n", + "data = jnp.array(encode(text), dtype=jnp.int32)\n", + "n = int(0.9*len(data)) # first 90% will be train, rest val\n", + "train_data = data[:n]\n", + "val_data = data[n:]\n", + "\n", + "# data loading\n", + "def get_batch(split):\n", + " # generate a small batch of data of inputs x and targets y\n", + " data = train_data if split == 'train' else val_data\n", + " ix = random.randint(rng, (batch_size,), 0, len(data) - block_size)\n", + " x = jnp.stack([data[i:i+block_size] for i in ix])\n", + " y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", + " # x, y = x.device_put(device), y.device_put(device)\n", + " return x, y\n", + "\n", + "@jax.jit\n", + "def estimate_loss():\n", + " out = {}\n", + " for split in ['train', 'val']:\n", + " losses = jnp.zeros(eval_iters)\n", + " for k in range(eval_iters):\n", + " X, Y = get_batch(split)\n", + " logits, loss = model.apply(params, X, Y)\n", + " losses = jax.ops.index_update(losses, k, loss)\n", + " out[split] = losses.mean()\n", + " return out\n", + "\n", + "class Head(nn.Module):\n", + " \"\"\" one head of self-attention \"\"\"\n", + "\n", + " def __init__(self, head_size):\n", + " super().__init__()\n", + " self.key = nn.Linear(n_embd, head_size, bias=False)\n", + " self.query = nn.Linear(n_embd, head_size, bias=False)\n", + " self.value = nn.Linear(n_embd, head_size, bias=False)\n", + " self.tril = jnp.tril(jnp.ones((block_size, block_size)))\n", + "\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x):\n", + " B,T,C = x.shape\n", + " k = self.key(x) # (B,T,C)\n", + " q = self.query(x) # (B,T,C)\n", + " # compute attention scores (\"affinities\")\n", + " wei = jnp.matmul(q, k.transpose((0,2,1))) * C**-0.5 # (B, T, T)\n", + " wei = jax.ops.index_update(wei, self.tril[:T, :T] == 0, -1e9) # (B, T, T)\n", + " wei = jax.nn.softmax(wei, axis=-1) # (B, T, T)\n", + " wei = self.dropout(wei)\n", + " # perform the weighted aggregation of the values\n", + " v = self.value(x) # (B,T,C)\n", + " out = jnp.matmul(wei, v) # (B, T, C)\n", + " return out\n", + "\n", + "class MultiHeadAttention(nn.Module):\n", + " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", + "\n", + " def __init__(self, num_heads, head_size):\n", + " super().__init__()\n", + " self.heads = [Head(head_size) for _ in range(num_heads)]\n", + " self.proj = nn.Linear(n_embd, n_embd)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x):\n", + " out = jnp.concatenate([h(x) for h in self.heads], axis=-1)\n", + " out = self.dropout(self.proj(out))\n", + " return out\n", + "\n", + "class FeedFoward(nn.Module):\n", + " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n", + "\n", + " def __init__(self, n_embd):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(n_embd, 4 * n_embd),\n", + " nn.ReLU(),\n", + " nn.Linear(4 * n_embd, n_embd),\n", + " nn.Dropout(dropout),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + "class Block(nn.Module):\n", + " \"\"\" Transformer block: communication followed by computation \"\"\"\n", + "\n", + " def __init__(self, n_embd, n_head):\n", + " # n_embd: embedding dimension, n_head: the number of heads we'd like\n", + " super().__init__()\n", + " head_size = n_embd // n_head\n", + " self.sa = MultiHeadAttention(n_head, head_size)\n", + " self.ffwd = FeedFoward(n_embd)\n", + " self.ln1 = nn.LayerNorm(n_embd)\n", + " self.ln2 = nn.LayerNorm(n_embd)\n", + "\n", + " def forward(self, x):\n", + " x = x + self.sa(self.ln1(x))\n", + " x = x + self.ffwd(self.ln2(x))\n", + " return x\n", + "\n", + "# super simple bigram model\n", + "class BigramLanguageModel(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " # each token directly reads off the logits for the next token from a lookup table\n", + " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", + " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", + " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n", + " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", + " self.lm_head = nn.Linear(n_embd, vocab_size)\n", + "\n", + " def forward(self, idx, targets=None):\n", + " B, T = idx.shape\n", + "\n", + " # idx and targets are both (B,T) tensor of integers\n", + " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", + " pos_emb = self.position_embedding_table(jnp.arange(T)) # (T,C)\n", + " x = tok_emb + pos_emb # (B,T,C)\n", + " x = self.blocks(x) # (B,T,C)\n", + " x = self.ln_f(x) # (B,T,C)\n", + " logits = self.lm_head(x) # (B,T,vocab_size)\n", + "\n", + " if targets is None:\n", + " loss = None\n", + " else:\n", + " B, T, C = logits.shape\n", + " logits = logits.reshape(B*T, C)\n", + " targets = targets.reshape(B*T)\n", + " loss = jnp.mean(jax.nn.log_softmax(logits, axis=-1)[jnp.arange(B*T), targets]) * -1\n", + "\n", + " return logits, loss\n", + "\n", + " def generate(self, idx, max_new_tokens):\n", + " # idx is (B, T) array of indices in the current context\n", + " for _ in range(max_new_tokens):\n", + " # crop idx to the last block_size tokens\n", + " idx_cond = idx[:, -block_size:]\n", + " # get the predictions\n", + " logits, loss = self(idx_cond)\n", + " # focus only on the last time step\n", + " logits = logits[:, -1, :] # becomes (B, C)\n", + " # apply softmax to get probabilities\n", + " probs = jax.nn.softmax(logits, axis=-1) # (B, C)\n", + " # sample from the distribution\n", + " idx_next = jax.random.categorical(key, probs, axis=-1) # (B, 1)\n", + " # append sampled index to the running sequence\n", + " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", + " return idx\n", + "\n", + "# Define the optimizer\n", + "tx = adam(learning_rate)\n", + "\n", + "# Initialize model parameters and optimizer state\n", + "rng = random.PRNGKey(1337)\n", + "params = model.init(rng, jnp.ones((B, T, C), jnp.int32))\n", + "opt_state = tx.init(params)\n", + "\n", + "# Loss function (assuming you have a batch of data: xb, yb)\n", + "def loss_fn(params, xb, yb):\n", + " print(xb, yb)\n", + " logits, loss = model.apply(params, xb, yb)\n", + " return loss\n", + "\n", + "# Update function for a single training step\n", + "@jax.jit\n", + "def update_step(params, opt_state, xb, yb):\n", + " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n", + " updates, opt_state = tx.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, opt_state, loss\n", + "\n", + "# Training loop (example)\n", + "batch_size = 32\n", + "for iter in range(max_iters):\n", + " # every once in a while evaluate the loss on train and val sets\n", + " if iter % eval_interval == 0 or iter == max_iters - 1:\n", + " losses = estimate_loss()\n", + " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", + "\n", + " xb, yb = get_batch('train')\n", + " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n", + " print(f\"Epoch {steps},Loss: {loss}\")\n", + "\n", + "model = BigramLanguageModel(vocab_size)\n", + "\n", + "# create a JAX optimizer\n", + "optimizer = jax.experimental.optimizers.adamw(learning_rate=learning_rate).create(model.parameters())\n", + "\n", + "for iter in range(max_iters):\n", + "\n", + " # every once in a while evaluate the loss on train and val sets\n", + " if iter % eval_interval == 0 or iter == max_iters - 1:\n", + " losses = estimate_loss()\n", + " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", + "\n", + " # sample a batch of data\n", + " xb, yb = get_batch('train')\n", + "\n", + " # evaluate the loss\n", + " logits, loss = model(xb, yb)\n", + " gradients = grad(lambda params, x, y: model(x, y)[1])(optimizer.target, xb, yb)\n", + " optimizer = optimizer.apply_gradient(gradients)\n", + "\n", + "# generate from the model\n", + "context = jnp.zeros((1, 1), dtype=jnp.int32)\n", + "print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fjjvMifYZf7x" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true, + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file