From e704446a3e43e5bb197c83123fbff03994071ed2 Mon Sep 17 00:00:00 2001 From: Jin Xu Date: Wed, 3 Apr 2024 14:39:24 -0700 Subject: [PATCH] Delete colabs/llms directory --- colabs/llms /jax_gpt_dev_bigram.ipynb | 1678 ------------------------- 1 file changed, 1678 deletions(-) delete mode 100644 colabs/llms /jax_gpt_dev_bigram.ipynb diff --git a/colabs/llms /jax_gpt_dev_bigram.ipynb b/colabs/llms /jax_gpt_dev_bigram.ipynb deleted file mode 100644 index 7f6eeb4..0000000 --- a/colabs/llms /jax_gpt_dev_bigram.ipynb +++ /dev/null @@ -1,1678 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "source": [ - "\n", - "# Changed Andrej Karpathy's minGPT from pytorch to Jax code (see [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT)." - ], - "metadata": { - "id": "vv5qRLYk-aNq" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NIbX_6V1ELk2" - }, - "source": [ - "## import" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "a6NK_zR1EN1I" - }, - "outputs": [], - "source": [ - "import time\n", - "import jax\n", - "from jax import device_put\n", - "import jax.numpy as jnp\n", - "from jax import lax\n", - "import jax.random as random\n", - "import jax.nn as jnn\n", - "from jax.nn.initializers import normal\n", - "\n", - "import flax.linen as nn\n", - "\n", - "import optax\n", - "from optax import adam" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wJpXpmjEYC_T" - }, - "source": [ - "## Building a GPT" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kqorRX1vF3KP" - }, - "source": [ - "### load data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "h5hjCcLDr2WC", - "outputId": "18e575dc-0dcc-4519-d97d-fd83c49a9e38" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--2024-04-02 04:10:38-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1115394 (1.1M) [text/plain]\n", - "Saving to: ‘input.txt’\n", - "\n", - "\rinput.txt 0%[ ] 0 --.-KB/s \rinput.txt 100%[===================>] 1.06M --.-KB/s in 0.04s \n", - "\n", - "2024-04-02 04:10:39 (25.5 MB/s) - ‘input.txt’ saved [1115394/1115394]\n", - "\n" - ] - } - ], - "source": [ - "# We always start with a dataset to train on. 
Let's download the tiny shakespeare dataset\n", - "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "O6medjfRsLD9" - }, - "outputs": [], - "source": [ - "# read it in to inspect it\n", - "with open('input.txt', 'r', encoding='utf-8') as f:\n", - " text = f.read()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "6xWI_VyAsN8F", - "outputId": "c10c8eb5-547b-4357-b817-3820cc3b1df8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "length of dataset in characters: 1115394\n" - ] - } - ], - "source": [ - "print(\"length of dataset in characters: \", len(text))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2c5V0FvqseE0", - "outputId": "6d410839-095a-4cb7-87f2-e4fd75200c03" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "First Citizen:\n", - "Before we proceed any further, hear me speak.\n", - "\n", - "All:\n", - "Speak, speak.\n", - "\n", - "First Citizen:\n", - "You are all resolved rather to die than to famish?\n", - "\n", - "All:\n", - "Resolved. resolved.\n", - "\n", - "First Citizen:\n", - "First, you know Caius Marcius is chief enemy to the people.\n", - "\n", - "All:\n", - "We know't, we know't.\n", - "\n", - "First Citizen:\n", - "Let us kill him, and we'll have corn at our own price.\n", - "Is't a verdict?\n", - "\n", - "All:\n", - "No more talking on't; let it be done: away, away!\n", - "\n", - "Second Citizen:\n", - "One word, good citizens.\n", - "\n", - "First Citizen:\n", - "We are accounted poor citizens, the patricians good.\n", - "What authority surfeits on would relieve us: if they\n", - "would yield us but the superfluity, while it were\n", - "wholesome, we might guess they relieved us humanely;\n", - "but they think we are too dear: the leanness that\n", - "afflicts us, the object of our misery, is as an\n", - "inventory to particularise their abundance; our\n", - "sufferance is a gain to them Let us revenge this with\n", - "our pikes, ere we become rakes: for the gods know I\n", - "speak this in hunger for bread, not in thirst for revenge.\n", - "\n", - "\n" - ] - } - ], - "source": [ - "# let's look at the first 1000 characters\n", - "print(text[:1000])" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0e-Rbyr8sfM8", - "outputId": "a34b8a27-8433-4613-8bf4-c4dee3596867" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n", - "65\n" - ] - } - ], - "source": [ - "# here are all the unique characters that occur in this text\n", - "chars = sorted(list(set(text)))\n", - "vocab_size = len(chars)\n", - "print(''.join(chars))\n", - "print(vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qFaINwGqD1Bm", - "outputId": "b7da73de-d04b-4019-e3f5-76e5980bc5a4" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "True" - ] - }, - "metadata": {}, - "execution_count": 8 - } - ], - "source": [ - "'!' 
in chars" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Yw1LKNCgwjj1", - "outputId": "c8d96823-504d-4465-b091-72e67be69666" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n", - "hii there\n" - ] - } - ], - "source": [ - "# create a mapping from characters to integers\n", - "stoi = { ch:i for i,ch in enumerate(chars) }\n", - "itos = { i:ch for i,ch in enumerate(chars) }\n", - "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", - "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", - "\n", - "print(encode(\"hii there\"))\n", - "print(decode(encode(\"hii there\")))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "YJb0OXPwzvqg", - "outputId": "0ab5bbf0-e783-4fe6-e022-cd12e910570c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - ":8: UserWarning: Explicitly requested dtype requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n", - " data = jnp.array(encode(text), dtype=jnp.int64)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "(1115394,) int32\n", - "[18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43\n", - " 1 54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43\n", - " 39 56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49\n", - " 6 1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10\n", - " 0 37 53 59 1 39 56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39\n", - " 58 46 43 56 1 58 53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47\n", - " 57 46 12 0 0 13 50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53\n", - " 50 60 43 42 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47\n", - " 56 57 58 6 1 63 53 59 1 49 52 53 61 1 15 39 47 59 57 1 25 39 56 41\n", - " 47 59 57 1 47 57 1 41 46 47 43 44 1 43 52 43 51 63 1 58 53 1 58 46\n", - " 43 1 54 43 53 54 50 43 8 0 0 13 50 50 10 0 35 43 1 49 52 53 61 5\n", - " 58 6 1 61 43 1 49 52 53 61 5 58 8 0 0 18 47 56 57 58 1 15 47 58\n", - " 47 64 43 52 10 0 24 43 58 1 59 57 1 49 47 50 50 1 46 47 51 6 1 39\n", - " 52 42 1 61 43 5 50 50 1 46 39 60 43 1 41 53 56 52 1 39 58 1 53 59\n", - " 56 1 53 61 52 1 54 56 47 41 43 8 0 21 57 5 58 1 39 1 60 43 56 42\n", - " 47 41 58 12 0 0 13 50 50 10 0 26 53 1 51 53 56 43 1 58 39 50 49 47\n", - " 52 45 1 53 52 5 58 11 1 50 43 58 1 47 58 1 40 43 1 42 53 52 43 10\n", - " 1 39 61 39 63 6 1 39 61 39 63 2 0 0 31 43 41 53 52 42 1 15 47 58\n", - " 47 64 43 52 10 0 27 52 43 1 61 53 56 42 6 1 45 53 53 42 1 41 47 58\n", - " 47 64 43 52 57 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 35\n", - " 43 1 39 56 43 1 39 41 41 53 59 52 58 43 42 1 54 53 53 56 1 41 47 58\n", - " 47 64 43 52 57 6 1 58 46 43 1 54 39 58 56 47 41 47 39 52 57 1 45 53\n", - " 53 42 8 0 35 46 39 58 1 39 59 58 46 53 56 47 58 63 1 57 59 56 44 43\n", - " 47 58 57 1 53 52 1 61 53 59 50 42 1 56 43 50 47 43 60 43 1 59 57 10\n", - " 1 47 44 1 58 46 43 63 0 61 53 59 50 42 1 63 47 43 50 42 1 59 57 1\n", - " 40 59 58 1 58 46 43 1 57 59 54 43 56 44 50 59 47 58 63 6 1 61 46 47\n", - " 50 43 1 47 58 1 61 
43 56 43 0 61 46 53 50 43 57 53 51 43 6 1 61 43\n", - " 1 51 47 45 46 58 1 45 59 43 57 57 1 58 46 43 63 1 56 43 50 47 43 60\n", - " 43 42 1 59 57 1 46 59 51 39 52 43 50 63 11 0 40 59 58 1 58 46 43 63\n", - " 1 58 46 47 52 49 1 61 43 1 39 56 43 1 58 53 53 1 42 43 39 56 10 1\n", - " 58 46 43 1 50 43 39 52 52 43 57 57 1 58 46 39 58 0 39 44 44 50 47 41\n", - " 58 57 1 59 57 6 1 58 46 43 1 53 40 48 43 41 58 1 53 44 1 53 59 56\n", - " 1 51 47 57 43 56 63 6 1 47 57 1 39 57 1 39 52 0 47 52 60 43 52 58\n", - " 53 56 63 1 58 53 1 54 39 56 58 47 41 59 50 39 56 47 57 43 1 58 46 43\n", - " 47 56 1 39 40 59 52 42 39 52 41 43 11 1 53 59 56 0 57 59 44 44 43 56\n", - " 39 52 41 43 1 47 57 1 39 1 45 39 47 52 1 58 53 1 58 46 43 51 1 24\n", - " 43 58 1 59 57 1 56 43 60 43 52 45 43 1 58 46 47 57 1 61 47 58 46 0\n", - " 53 59 56 1 54 47 49 43 57 6 1 43 56 43 1 61 43 1 40 43 41 53 51 43\n", - " 1 56 39 49 43 57 10 1 44 53 56 1 58 46 43 1 45 53 42 57 1 49 52 53\n", - " 61 1 21 0 57 54 43 39 49 1 58 46 47 57 1 47 52 1 46 59 52 45 43 56\n", - " 1 44 53 56 1 40 56 43 39 42 6 1 52 53 58 1 47 52 1 58 46 47 56 57\n", - " 58 1 44 53 56 1 56 43 60 43 52 45 43 8 0 0]\n" - ] - } - ], - "source": [ - "# # let's now encode the entire text dataset and store it into a torch.Tensor\n", - "# import torch # we use PyTorch: https://pytorch.org\n", - "# data = torch.tensor(encode(text), dtype=torch.long)\n", - "# print(data.shape, data.dtype)\n", - "# print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this\n", - "\n", - "# Assuming `encode` is a function defined elsewhere\n", - "data = jnp.array(encode(text), dtype=jnp.int64)\n", - "print(data.shape, data.dtype)\n", - "print(data[:1000])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pBlRYpvhF9xj" - }, - "source": [ - "#### split into train and test" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "f_WIXqxz0lU5" - }, - "outputs": [], - "source": [ - "# Let's now split up the data into train and validation sets\n", - "n = int(0.9*len(data)) # first 90% will be train, rest val\n", - "train_data = data[:n]\n", - "val_data = data[n:]" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "bCAY6ej2DQUm", - "outputId": "6114a515-ebab-45be-fa5c-5f58c8560d29" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(111540,)" - ] - }, - "metadata": {}, - "execution_count": 12 - } - ], - "source": [ - "val_data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TD5Bj8Y6IAD4", - "outputId": "5087c2d6-c461-4dbe-989a-ff6fd49a9187" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Array([18, 47, 56, 57, 58, 1, 15, 47, 58], dtype=int32)" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "block_size = 8\n", - "train_data[:block_size+1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nHbUA5zZGBqn" - }, - "source": [ - "#### build feature input x and target output y" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9HXDe8vGJCEn", - "outputId": "af9aab79-9594-47ac-8c33-62a92d7138c5" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0\n", - "when input is [18] the target: 
47\n", - "1\n", - "when input is [18 47] the target: 56\n", - "2\n", - "when input is [18 47 56] the target: 57\n", - "3\n", - "when input is [18 47 56 57] the target: 58\n", - "4\n", - "when input is [18 47 56 57 58] the target: 1\n", - "5\n", - "when input is [18 47 56 57 58 1] the target: 15\n", - "6\n", - "when input is [18 47 56 57 58 1 15] the target: 47\n", - "7\n", - "when input is [18 47 56 57 58 1 15 47] the target: 58\n" - ] - } - ], - "source": [ - "x = train_data[:block_size]\n", - "y = train_data[1:block_size+1]\n", - "for t in range(block_size):\n", - " print(t)\n", - " context = x[:t+1]\n", - " target = y[t]\n", - " print(f\"when input is {context} the target: {target}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Q3k1Czf7LuA9", - "outputId": "35e955e0-4d8e-48a2-a2cf-f7f0c341a925" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "inputs:\n", - "(4, 8)\n", - "[[ 1 51 39 52 58 50 43 0]\n", - " [25 17 27 10 0 27 6 1]\n", - " [47 51 8 0 14 59 58 1]\n", - " [ 1 57 59 41 46 1 50 43]]\n", - "targets:\n", - "(4, 8)\n", - "[[51 39 52 58 50 43 0 53]\n", - " [17 27 10 0 27 6 1 50]\n", - " [51 8 0 14 59 58 1 46]\n", - " [57 59 41 46 1 50 43 52]]\n", - "----\n", - "when input is [1] the target: 51\n", - "when input is [1, 51] the target: 39\n", - "when input is [1, 51, 39] the target: 52\n", - "when input is [1, 51, 39, 52] the target: 58\n", - "when input is [1, 51, 39, 52, 58] the target: 50\n", - "when input is [1, 51, 39, 52, 58, 50] the target: 43\n", - "when input is [1, 51, 39, 52, 58, 50, 43] the target: 0\n", - "when input is [1, 51, 39, 52, 58, 50, 43, 0] the target: 53\n", - "when input is [25] the target: 17\n", - "when input is [25, 17] the target: 27\n", - "when input is [25, 17, 27] the target: 10\n", - "when input is [25, 17, 27, 10] the target: 0\n", - "when input is [25, 17, 27, 10, 0] the target: 27\n", - "when input is [25, 17, 27, 10, 0, 27] the target: 6\n", - "when input is [25, 17, 27, 10, 0, 27, 6] the target: 1\n", - "when input is [25, 17, 27, 10, 0, 27, 6, 1] the target: 50\n", - "when input is [47] the target: 51\n", - "when input is [47, 51] the target: 8\n", - "when input is [47, 51, 8] the target: 0\n", - "when input is [47, 51, 8, 0] the target: 14\n", - "when input is [47, 51, 8, 0, 14] the target: 59\n", - "when input is [47, 51, 8, 0, 14, 59] the target: 58\n", - "when input is [47, 51, 8, 0, 14, 59, 58] the target: 1\n", - "when input is [47, 51, 8, 0, 14, 59, 58, 1] the target: 46\n", - "when input is [1] the target: 57\n", - "when input is [1, 57] the target: 59\n", - "when input is [1, 57, 59] the target: 41\n", - "when input is [1, 57, 59, 41] the target: 46\n", - "when input is [1, 57, 59, 41, 46] the target: 1\n", - "when input is [1, 57, 59, 41, 46, 1] the target: 50\n", - "when input is [1, 57, 59, 41, 46, 1, 50] the target: 43\n", - "when input is [1, 57, 59, 41, 46, 1, 50, 43] the target: 52\n" - ] - } - ], - "source": [ - "prng = jax.random.PRNGKey(1337)\n", - "batch_size = 4 # how many independent sequences will we process in parallel?\n", - "block_size = 8 # what is the maximum context length for predictions?\n", - "ix = random.randint(random.PRNGKey(0), (batch_size,), 0, len(data) - block_size)\n", - "\n", - "def get_batch(split, subkey):\n", - " # generate a small batch of data of inputs x and targets y\n", - " data = train_data if split == 'train' else val_data\n", - " ix = random.randint(subkey, 
(batch_size,), 0, len(data) - block_size)\n", - " # x = jnp.stack([data[i:i+block_size] for i in ix])\n", - " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", - " # optimize the above code ^^.\n", - " # speed up by using dynamic_slice and vmap\n", - " def slice_data(i):\n", - " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n", - "\n", - " x = jax.vmap(slice_data)(ix)\n", - " y = jax.vmap(slice_data)(ix+1)\n", - " x, y = device_put(x), device_put(y)\n", - " return x, y\n", - "\n", - "xb, yb = get_batch('train', prng)\n", - "print('inputs:')\n", - "print(xb.shape)\n", - "print(xb)\n", - "print('targets:')\n", - "print(yb.shape)\n", - "print(yb)\n", - "\n", - "print('----')\n", - "\n", - "for b in range(batch_size): # batch dimension\n", - " for t in range(block_size): # time dimension\n", - " context = xb[b, :t+1]\n", - " target = yb[b,t]\n", - " print(f\"when input is {context.tolist()} the target: {target}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qpyyAeIzQjlO", - "outputId": "ecc878aa-2650-4cad-b32e-1d18c946d0a1" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[[ 1 51 39 52 58 50 43 0]\n", - " [25 17 27 10 0 27 6 1]\n", - " [47 51 8 0 14 59 58 1]\n", - " [ 1 57 59 41 46 1 50 43]]\n" - ] - } - ], - "source": [ - "print(xb) # our input to the transformer" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yZF__SLaGJs9" - }, - "source": [ - "### Baseline model: Bigram" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nql_1ER53oCf", - "outputId": "cab32afd-03fd-4d3b-fd94-74bcbe2a4af9" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "(32, 65)\n", - "4.1973763\n" - ] - } - ], - "source": [ - "class BigramLanguageModel(nn.Module):\n", - " vocab_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, idx, targets=None):\n", - " # Token embedding table\n", - " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n", - " logits = embedding_table(idx) # (B,T,C)\n", - "\n", - " if targets is None:\n", - " loss = None\n", - " else:\n", - " B, T, C = logits.shape\n", - " logits = logits.reshape(B*T, C)\n", - " targets = targets.reshape(B*T)\n", - " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n", - "\n", - " return logits, loss\n", - "\n", - "# Example usage\n", - "model = BigramLanguageModel(vocab_size)\n", - "\n", - "# Initialize model parameters and optimizer\n", - "key = random.PRNGKey(1337)\n", - "params = model.init(key, jnp.ones((1, 1), jnp.int32))\n", - "\n", - "# jax jit the model apply to speed up\n", - "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))\n", - "\n", - "logits, loss = flax_apply_jitted(params, xb, yb)\n", - "print(loss)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "#### training" - ], - "metadata": { - "id": "wXW3MFQqA7AD" - } - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "id": "eTyJ8qAaDdiF", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0ddc5e2f-273b-4bdb-c01e-1011e8aab9e5" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Tracedwith Tracedwith\n", - "Epoch 0,Loss: 4.176701545715332, sample batch: 0.006464958190917969, forward pass and backward grad: 
0.3173556327819824\n", - "Epoch 1,Loss: 4.168148040771484, sample batch: 0.006571531295776367, forward pass and backward grad: 0.007528066635131836\n", - "Epoch 2,Loss: 4.176245212554932, sample batch: 0.006763458251953125, forward pass and backward grad: 0.007700681686401367\n", - "Epoch 3,Loss: 4.169358730316162, sample batch: 0.007287740707397461, forward pass and backward grad: 0.009604454040527344\n", - "Epoch 4,Loss: 4.175686836242676, sample batch: 0.009287357330322266, forward pass and backward grad: 0.010252714157104492\n", - "Epoch 5,Loss: 4.170393943786621, sample batch: 0.0073015689849853516, forward pass and backward grad: 0.008276939392089844\n", - "Epoch 6,Loss: 4.169349193572998, sample batch: 0.006832599639892578, forward pass and backward grad: 0.007781982421875\n", - "Epoch 7,Loss: 4.162716388702393, sample batch: 0.007829904556274414, forward pass and backward grad: 0.008838653564453125\n", - "Epoch 8,Loss: 4.157233238220215, sample batch: 0.007964611053466797, forward pass and backward grad: 0.008940458297729492\n", - "Epoch 9,Loss: 4.163629055023193, sample batch: 0.0069615840911865234, forward pass and backward grad: 0.007949113845825195\n", - "Epoch 10,Loss: 4.16217565536499, sample batch: 0.007250785827636719, forward pass and backward grad: 0.008380413055419922\n", - "Epoch 11,Loss: 4.1645026206970215, sample batch: 0.007141828536987305, forward pass and backward grad: 0.008124113082885742\n", - "Epoch 12,Loss: 4.1598124504089355, sample batch: 0.006840944290161133, forward pass and backward grad: 0.007810831069946289\n", - "Epoch 13,Loss: 4.157349586486816, sample batch: 0.006644248962402344, forward pass and backward grad: 0.0076143741607666016\n", - "Epoch 14,Loss: 4.163051128387451, sample batch: 0.007123470306396484, forward pass and backward grad: 0.008063554763793945\n", - "Epoch 15,Loss: 4.15546178817749, sample batch: 0.005999565124511719, forward pass and backward grad: 0.006939888000488281\n", - "Epoch 16,Loss: 4.154058456420898, sample batch: 0.005856513977050781, forward pass and backward grad: 0.0068018436431884766\n", - "Epoch 17,Loss: 4.14748477935791, sample batch: 0.010221004486083984, forward pass and backward grad: 0.011186599731445312\n", - "Epoch 18,Loss: 4.153148174285889, sample batch: 0.006284952163696289, forward pass and backward grad: 0.0072596073150634766\n", - "Epoch 19,Loss: 4.15305757522583, sample batch: 0.005934238433837891, forward pass and backward grad: 0.0069620609283447266\n", - "Epoch 20,Loss: 4.148075580596924, sample batch: 0.005886554718017578, forward pass and backward grad: 0.006838560104370117\n", - "Epoch 21,Loss: 4.1545729637146, sample batch: 0.006558895111083984, forward pass and backward grad: 0.007547140121459961\n", - "Epoch 22,Loss: 4.1454854011535645, sample batch: 0.0068204402923583984, forward pass and backward grad: 0.007781505584716797\n", - "Epoch 23,Loss: 4.146124839782715, sample batch: 0.006976604461669922, forward pass and backward grad: 0.007932424545288086\n", - "Epoch 24,Loss: 4.152686595916748, sample batch: 0.0068511962890625, forward pass and backward grad: 0.007791042327880859\n", - "Epoch 25,Loss: 4.148232460021973, sample batch: 0.012384414672851562, forward pass and backward grad: 0.013611793518066406\n", - "Epoch 26,Loss: 4.143317222595215, sample batch: 0.00710606575012207, forward pass and backward grad: 0.00807332992553711\n", - "Epoch 27,Loss: 4.135959625244141, sample batch: 0.0061304569244384766, forward pass and backward grad: 0.007124662399291992\n", - "Epoch 28,Loss: 
4.133705139160156, sample batch: 0.00642704963684082, forward pass and backward grad: 0.007401466369628906\n", - "Epoch 29,Loss: 4.144227504730225, sample batch: 0.008374929428100586, forward pass and backward grad: 0.009456157684326172\n", - "Epoch 30,Loss: 4.137173652648926, sample batch: 0.007355928421020508, forward pass and backward grad: 0.008334636688232422\n", - "Epoch 31,Loss: 4.137599945068359, sample batch: 0.007583141326904297, forward pass and backward grad: 0.008576154708862305\n", - "Epoch 32,Loss: 4.1330671310424805, sample batch: 0.006922483444213867, forward pass and backward grad: 0.007891654968261719\n", - "Epoch 33,Loss: 4.140651702880859, sample batch: 0.00725555419921875, forward pass and backward grad: 0.008275270462036133\n", - "Epoch 34,Loss: 4.132898807525635, sample batch: 0.0066339969635009766, forward pass and backward grad: 0.00758814811706543\n", - "Epoch 35,Loss: 4.1298909187316895, sample batch: 0.00668644905090332, forward pass and backward grad: 0.007630586624145508\n", - "Epoch 36,Loss: 4.1282429695129395, sample batch: 0.006491661071777344, forward pass and backward grad: 0.007435798645019531\n", - "Epoch 37,Loss: 4.1404337882995605, sample batch: 0.006692171096801758, forward pass and backward grad: 0.0076482295989990234\n", - "Epoch 38,Loss: 4.124664783477783, sample batch: 0.007221221923828125, forward pass and backward grad: 0.008726119995117188\n", - "Epoch 39,Loss: 4.121993064880371, sample batch: 0.010175466537475586, forward pass and backward grad: 0.011160850524902344\n", - "Epoch 40,Loss: 4.131916522979736, sample batch: 0.0070574283599853516, forward pass and backward grad: 0.008001327514648438\n", - "Epoch 41,Loss: 4.121824741363525, sample batch: 0.006624698638916016, forward pass and backward grad: 0.007588624954223633\n", - "Epoch 42,Loss: 4.124565124511719, sample batch: 0.007647991180419922, forward pass and backward grad: 0.008681535720825195\n", - "Epoch 43,Loss: 4.122013568878174, sample batch: 0.006825923919677734, forward pass and backward grad: 0.0077838897705078125\n", - "Epoch 44,Loss: 4.124299049377441, sample batch: 0.00981760025024414, forward pass and backward grad: 0.010839700698852539\n", - "Epoch 45,Loss: 4.116213798522949, sample batch: 0.00638580322265625, forward pass and backward grad: 0.008318901062011719\n", - "Epoch 46,Loss: 4.118075847625732, sample batch: 0.00923776626586914, forward pass and backward grad: 0.010228395462036133\n", - "Epoch 47,Loss: 4.114686489105225, sample batch: 0.007014036178588867, forward pass and backward grad: 0.007985591888427734\n", - "Epoch 48,Loss: 4.117861270904541, sample batch: 0.007818460464477539, forward pass and backward grad: 0.008879899978637695\n", - "Epoch 49,Loss: 4.111810207366943, sample batch: 0.015448570251464844, forward pass and backward grad: 0.016459941864013672\n", - "Epoch 50,Loss: 4.1100616455078125, sample batch: 0.00766754150390625, forward pass and backward grad: 0.00865626335144043\n", - "Epoch 51,Loss: 4.106339931488037, sample batch: 0.0077228546142578125, forward pass and backward grad: 0.008730649948120117\n", - "Epoch 52,Loss: 4.106321811676025, sample batch: 0.00809621810913086, forward pass and backward grad: 0.009130001068115234\n", - "Epoch 53,Loss: 4.105854034423828, sample batch: 0.007456302642822266, forward pass and backward grad: 0.008474588394165039\n", - "Epoch 54,Loss: 4.116119384765625, sample batch: 0.0074770450592041016, forward pass and backward grad: 0.008490324020385742\n", - "Epoch 55,Loss: 4.107293605804443, sample batch: 
0.008399009704589844, forward pass and backward grad: 0.009408950805664062\n", - "Epoch 56,Loss: 4.112220764160156, sample batch: 0.007843494415283203, forward pass and backward grad: 0.008843421936035156\n", - "Epoch 57,Loss: 4.106284141540527, sample batch: 0.007303953170776367, forward pass and backward grad: 0.008324861526489258\n", - "Epoch 58,Loss: 4.101660251617432, sample batch: 0.009732961654663086, forward pass and backward grad: 0.010820388793945312\n", - "Epoch 59,Loss: 4.100407123565674, sample batch: 0.0070569515228271484, forward pass and backward grad: 0.008041620254516602\n", - "Epoch 60,Loss: 4.098540306091309, sample batch: 0.007241725921630859, forward pass and backward grad: 0.00823354721069336\n", - "Epoch 61,Loss: 4.097070217132568, sample batch: 0.008341073989868164, forward pass and backward grad: 0.009309768676757812\n", - "Epoch 62,Loss: 4.0942816734313965, sample batch: 0.0075092315673828125, forward pass and backward grad: 0.008522748947143555\n", - "Epoch 63,Loss: 4.098483562469482, sample batch: 0.007592439651489258, forward pass and backward grad: 0.008698225021362305\n", - "Epoch 64,Loss: 4.098681449890137, sample batch: 0.007315397262573242, forward pass and backward grad: 0.008310079574584961\n", - "Epoch 65,Loss: 4.09865140914917, sample batch: 0.007418632507324219, forward pass and backward grad: 0.008446455001831055\n", - "Epoch 66,Loss: 4.09343957901001, sample batch: 0.008169412612915039, forward pass and backward grad: 0.009173154830932617\n", - "Epoch 67,Loss: 4.098697185516357, sample batch: 0.0086212158203125, forward pass and backward grad: 0.009929418563842773\n", - "Epoch 68,Loss: 4.089414596557617, sample batch: 0.008454084396362305, forward pass and backward grad: 0.009839296340942383\n", - "Epoch 69,Loss: 4.089560031890869, sample batch: 0.007783412933349609, forward pass and backward grad: 0.008769750595092773\n", - "Epoch 70,Loss: 4.0868353843688965, sample batch: 0.007253885269165039, forward pass and backward grad: 0.008234262466430664\n", - "Epoch 71,Loss: 4.0886335372924805, sample batch: 0.007701396942138672, forward pass and backward grad: 0.00868678092956543\n", - "Epoch 72,Loss: 4.08442497253418, sample batch: 0.013611316680908203, forward pass and backward grad: 0.014733552932739258\n", - "Epoch 73,Loss: 4.0852155685424805, sample batch: 0.007182121276855469, forward pass and backward grad: 0.008143424987792969\n", - "Epoch 74,Loss: 4.077157974243164, sample batch: 0.006951570510864258, forward pass and backward grad: 0.007899045944213867\n", - "Epoch 75,Loss: 4.083674430847168, sample batch: 0.006950855255126953, forward pass and backward grad: 0.00792551040649414\n", - "Epoch 76,Loss: 4.080029487609863, sample batch: 0.007283687591552734, forward pass and backward grad: 0.008302688598632812\n", - "Epoch 77,Loss: 4.075239658355713, sample batch: 0.006971836090087891, forward pass and backward grad: 0.008055925369262695\n", - "Epoch 78,Loss: 4.077165603637695, sample batch: 0.008187294006347656, forward pass and backward grad: 0.00914311408996582\n", - "Epoch 79,Loss: 4.082056522369385, sample batch: 0.007593870162963867, forward pass and backward grad: 0.009018421173095703\n", - "Epoch 80,Loss: 4.08469820022583, sample batch: 0.009556293487548828, forward pass and backward grad: 0.010587453842163086\n", - "Epoch 81,Loss: 4.074877738952637, sample batch: 0.007441520690917969, forward pass and backward grad: 0.008409261703491211\n", - "Epoch 82,Loss: 4.065997123718262, sample batch: 0.008413314819335938, forward pass and backward 
grad: 0.009436368942260742\n", - "Epoch 83,Loss: 4.072015762329102, sample batch: 0.009505033493041992, forward pass and backward grad: 0.010768413543701172\n", - "Epoch 84,Loss: 4.065585136413574, sample batch: 0.008172035217285156, forward pass and backward grad: 0.009321451187133789\n", - "Epoch 85,Loss: 4.070855617523193, sample batch: 0.007485151290893555, forward pass and backward grad: 0.008470535278320312\n", - "Epoch 86,Loss: 4.06818151473999, sample batch: 0.007288455963134766, forward pass and backward grad: 0.008320093154907227\n", - "Epoch 87,Loss: 4.072195529937744, sample batch: 0.007587909698486328, forward pass and backward grad: 0.008961677551269531\n", - "Epoch 88,Loss: 4.065299987792969, sample batch: 0.0156404972076416, forward pass and backward grad: 0.017002105712890625\n", - "Epoch 89,Loss: 4.059263229370117, sample batch: 0.008633852005004883, forward pass and backward grad: 0.009637594223022461\n", - "Epoch 90,Loss: 4.064813613891602, sample batch: 0.006797313690185547, forward pass and backward grad: 0.007784843444824219\n", - "Epoch 91,Loss: 4.0675811767578125, sample batch: 0.010977745056152344, forward pass and backward grad: 0.012456655502319336\n", - "Epoch 92,Loss: 4.062394142150879, sample batch: 0.012706518173217773, forward pass and backward grad: 0.014220237731933594\n", - "Epoch 93,Loss: 4.053857326507568, sample batch: 0.013547658920288086, forward pass and backward grad: 0.01579761505126953\n", - "Epoch 94,Loss: 4.056527614593506, sample batch: 0.00801706314086914, forward pass and backward grad: 0.009006500244140625\n", - "Epoch 95,Loss: 4.060898780822754, sample batch: 0.00725102424621582, forward pass and backward grad: 0.008263349533081055\n", - "Epoch 96,Loss: 4.057039737701416, sample batch: 0.00706171989440918, forward pass and backward grad: 0.008035659790039062\n", - "Epoch 97,Loss: 4.053262710571289, sample batch: 0.007007598876953125, forward pass and backward grad: 0.008001327514648438\n", - "Epoch 98,Loss: 4.054243564605713, sample batch: 0.008085012435913086, forward pass and backward grad: 0.00907278060913086\n", - "Epoch 99,Loss: 4.056581974029541, sample batch: 0.0069582462310791016, forward pass and backward grad: 0.007887125015258789\n" - ] - } - ], - "source": [ - "# Define the optimizer\n", - "learning_rate = 1e-3 # Adjust as needed\n", - "tx = adam(learning_rate)\n", - "\n", - "# Initialize model parameters and optimizer state\n", - "prng = random.PRNGKey(1337)\n", - "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n", - "opt_state = tx.init(params)\n", - "\n", - "# Loss function (assuming you have a batch of data: xb, yb)\n", - "def loss_fn(params, xb, yb):\n", - " print(xb, yb)\n", - " logits, loss = model.apply(params, xb, yb)\n", - " return loss\n", - "\n", - "# Update function for a single training step\n", - "@jax.jit\n", - "def update_step(params, opt_state, xb, yb):\n", - " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n", - " updates, opt_state = tx.update(grads, opt_state, params)\n", - " new_params = optax.apply_updates(params, updates)\n", - " return new_params, opt_state, loss\n", - "\n", - "# Training loop (example)\n", - "batch_size = 32\n", - "for steps in range(100):\n", - " t1 = time.time()\n", - " prng, subkey = random.split(prng)\n", - " xb, yb = get_batch('train', subkey)\n", - " t2 = time.time()\n", - " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n", - " print(f\"Epoch {steps},Loss: {loss}, sample batch: {t2-t1}, forward pass and backward grad: {time.time()-t1}\")\n", 
- "\n", - "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(jax.lax.stop_gradient(params), xb, yb))" - ] - }, - { - "cell_type": "markdown", - "source": [ - "#### inference\n", - "Slower than torch code\n", - " 1. logits[:, -1, :]\n", - " 2. random.categorical " - ], - "metadata": { - "id": "DwB7mg9tBt0D" - } - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "EcVIDWAZEtjN", - "outputId": "42718af7-305d-4216-99a3-d51504a2722b" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x\n", - "pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJrbX;b!HR'Fr;rI?&Nui3;woGFdW pAZYho3YO!hHPv:F3uMAHbG:slLyWXd;woxmBMTexUpY ZEP\n", - "tTk?BlWOP&ZP.zNS YjFV,OxrO?!$wNDsXCd;iM:c!elaw'uOPGCJJDBsSf,E.XguCoK-rJP-kybvHsxxwu,:i3UJgZbBMO;s:coPALGSTE-hJWOStcI3$VaeVYfJsTPqaqT-ebJqAWy\n", - "Ev:WFmCykXrvetkGbw-3-N!'oW\n", - "nKqi:FgOyU3XdQwNr gVItNvRo,JbtDAvcfHSKDkh.caNKrf CMrJIGs?lbiNDbgJg'cHB:rRwAuGq&UDPhOdnmc:&jU,ZCuG?mF.An-r,EMDfCHfITHsvztXPL U3iSE-dAsTxeqf??i\n", - "OUQfArTnZ.Hgv\n", - "CPU times: user 1.03 s, sys: 9 ms, total: 1.04 s\n", - "Wall time: 1.05 s\n" - ] - } - ], - "source": [ - "def generate(params, flax_apply_jitted, key, idx, max_new_tokens):\n", - " for _ in range(max_new_tokens):\n", - " logits, _ = flax_apply_jitted(params, idx, None)\n", - " logits = logits[:, -1, :] # (B, C)\n", - " key, subkey = random.split(key)\n", - " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", - " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", - " return idx\n", - "\n", - "%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))" - ] - }, - { - "cell_type": "code", - "source": [ - "# Try speed up with stop gradient\n", - "t1=time.time()\n", - "print(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))\n", - "print('TIME total', time.time()-t1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vfX1ymJPDoEF", - "outputId": "7d908357-0706-446d-bee4-cb84791872f9" - }, - "execution_count": 38, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x\n", - "pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJ\n", - "TIME total 0.3211863040924072\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yNc7jn8wy5Kx" - }, - "source": [ - "#### put together" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1sBa-IFGy90b", - "outputId": "7b559e10-3b16-4ebd-e4f5-fbcdacaa5970" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "step 0: train loss 4.1747, val loss 4.1776\n", - "Epoch 0,Loss: 4.176701545715332\n", - "Epoch 1,Loss: 4.168148040771484\n", - "Epoch 2,Loss: 4.176245212554932\n", - "Epoch 3,Loss: 4.169358730316162\n", - "Epoch 4,Loss: 4.175686836242676\n", - "Epoch 5,Loss: 4.170393943786621\n", - "Epoch 6,Loss: 4.169349193572998\n", - "Epoch 7,Loss: 4.162716388702393\n", - "Epoch 8,Loss: 4.157233238220215\n", - "Epoch 9,Loss: 4.163629055023193\n", - "Epoch 10,Loss: 4.16217565536499\n", - "Epoch 11,Loss: 4.1645026206970215\n", - "Epoch 12,Loss: 4.1598124504089355\n", - "Epoch 13,Loss: 4.157349586486816\n", - "Epoch 
14,Loss: 4.163051128387451\n", - "Epoch 15,Loss: 4.15546178817749\n", - "Epoch 16,Loss: 4.154058456420898\n", - "Epoch 17,Loss: 4.14748477935791\n", - "Epoch 18,Loss: 4.153148174285889\n", - "Epoch 19,Loss: 4.15305757522583\n", - "Epoch 20,Loss: 4.148075580596924\n", - "Epoch 21,Loss: 4.1545729637146\n", - "Epoch 22,Loss: 4.1454854011535645\n", - "Epoch 23,Loss: 4.146124839782715\n", - "Epoch 24,Loss: 4.152686595916748\n", - "Epoch 25,Loss: 4.148232460021973\n", - "Epoch 26,Loss: 4.143317222595215\n", - "Epoch 27,Loss: 4.135959625244141\n", - "Epoch 28,Loss: 4.133705139160156\n", - "Epoch 29,Loss: 4.144227504730225\n", - "Epoch 30,Loss: 4.137173652648926\n", - "Epoch 31,Loss: 4.137599945068359\n", - "Epoch 32,Loss: 4.1330671310424805\n", - "Epoch 33,Loss: 4.140651702880859\n", - "Epoch 34,Loss: 4.132898807525635\n", - "Epoch 35,Loss: 4.1298909187316895\n", - "Epoch 36,Loss: 4.1282429695129395\n", - "Epoch 37,Loss: 4.1404337882995605\n", - "Epoch 38,Loss: 4.124664783477783\n", - "Epoch 39,Loss: 4.121993064880371\n", - "Epoch 40,Loss: 4.131916522979736\n", - "Epoch 41,Loss: 4.121824741363525\n", - "Epoch 42,Loss: 4.124565124511719\n", - "Epoch 43,Loss: 4.122013568878174\n", - "Epoch 44,Loss: 4.124299049377441\n", - "Epoch 45,Loss: 4.116213798522949\n", - "Epoch 46,Loss: 4.118075847625732\n", - "Epoch 47,Loss: 4.114686489105225\n", - "Epoch 48,Loss: 4.117861270904541\n", - "Epoch 49,Loss: 4.111810207366943\n", - "Epoch 50,Loss: 4.1100616455078125\n", - "Epoch 51,Loss: 4.106339931488037\n", - "Epoch 52,Loss: 4.106321811676025\n", - "Epoch 53,Loss: 4.105854034423828\n", - "Epoch 54,Loss: 4.116119384765625\n", - "Epoch 55,Loss: 4.107293605804443\n", - "Epoch 56,Loss: 4.112220764160156\n", - "Epoch 57,Loss: 4.106284141540527\n", - "Epoch 58,Loss: 4.101660251617432\n", - "Epoch 59,Loss: 4.100407123565674\n", - "Epoch 60,Loss: 4.098540306091309\n", - "Epoch 61,Loss: 4.097070217132568\n", - "Epoch 62,Loss: 4.0942816734313965\n", - "Epoch 63,Loss: 4.098483562469482\n", - "Epoch 64,Loss: 4.098681449890137\n", - "Epoch 65,Loss: 4.09865140914917\n", - "Epoch 66,Loss: 4.09343957901001\n", - "Epoch 67,Loss: 4.098697185516357\n", - "Epoch 68,Loss: 4.089414596557617\n", - "Epoch 69,Loss: 4.089560031890869\n", - "Epoch 70,Loss: 4.0868353843688965\n", - "Epoch 71,Loss: 4.0886335372924805\n", - "Epoch 72,Loss: 4.08442497253418\n", - "Epoch 73,Loss: 4.0852155685424805\n", - "Epoch 74,Loss: 4.077157974243164\n", - "Epoch 75,Loss: 4.083674430847168\n", - "Epoch 76,Loss: 4.080029487609863\n", - "Epoch 77,Loss: 4.075239658355713\n", - "Epoch 78,Loss: 4.077165603637695\n", - "Epoch 79,Loss: 4.082056522369385\n", - "Epoch 80,Loss: 4.08469820022583\n", - "Epoch 81,Loss: 4.074877738952637\n", - "Epoch 82,Loss: 4.065997123718262\n", - "Epoch 83,Loss: 4.072015762329102\n", - "Epoch 84,Loss: 4.065585136413574\n", - "Epoch 85,Loss: 4.070855617523193\n", - "Epoch 86,Loss: 4.06818151473999\n", - "Epoch 87,Loss: 4.072195529937744\n", - "Epoch 88,Loss: 4.065299987792969\n", - "Epoch 89,Loss: 4.059263229370117\n", - "Epoch 90,Loss: 4.064813613891602\n", - "Epoch 91,Loss: 4.0675811767578125\n", - "Epoch 92,Loss: 4.062394142150879\n", - "Epoch 93,Loss: 4.053857326507568\n", - "Epoch 94,Loss: 4.056527614593506\n", - "Epoch 95,Loss: 4.060898780822754\n", - "Epoch 96,Loss: 4.057039737701416\n", - "Epoch 97,Loss: 4.053262710571289\n", - "Epoch 98,Loss: 4.054243564605713\n", - "Epoch 99,Loss: 4.056581974029541\n", - "step 100: train loss 4.0528, val loss 4.0548\n", - "Epoch 100,Loss: 4.046238899230957\n", - "Epoch 101,Loss: 
4.049732685089111\n", - "Epoch 102,Loss: 4.0535478591918945\n", - "Epoch 103,Loss: 4.050480365753174\n", - "Epoch 104,Loss: 4.053426265716553\n", - "Epoch 105,Loss: 4.042150974273682\n", - "Epoch 106,Loss: 4.0471320152282715\n", - "Epoch 107,Loss: 4.044474124908447\n", - "Epoch 108,Loss: 4.040208339691162\n", - "Epoch 109,Loss: 4.046355247497559\n", - "Epoch 110,Loss: 4.039368629455566\n", - "Epoch 111,Loss: 4.04256010055542\n", - "Epoch 112,Loss: 4.035763740539551\n", - "Epoch 113,Loss: 4.035247325897217\n", - "Epoch 114,Loss: 4.035799026489258\n", - "Epoch 115,Loss: 4.029808521270752\n", - "Epoch 116,Loss: 4.034291744232178\n", - "Epoch 117,Loss: 4.03066349029541\n", - "Epoch 118,Loss: 4.036454200744629\n", - "Epoch 119,Loss: 4.026322364807129\n", - "Epoch 120,Loss: 4.024256229400635\n", - "Epoch 121,Loss: 4.022698402404785\n", - "Epoch 122,Loss: 4.022899150848389\n", - "Epoch 123,Loss: 4.025989532470703\n", - "Epoch 124,Loss: 4.022078990936279\n", - "Epoch 125,Loss: 4.024803638458252\n", - "Epoch 126,Loss: 4.015080451965332\n", - "Epoch 127,Loss: 4.0170817375183105\n", - "Epoch 128,Loss: 4.021418571472168\n", - "Epoch 129,Loss: 4.012678146362305\n", - "Epoch 130,Loss: 4.012955188751221\n", - "Epoch 131,Loss: 4.0082855224609375\n", - "Epoch 132,Loss: 4.011625289916992\n", - "Epoch 133,Loss: 4.0081305503845215\n", - "Epoch 134,Loss: 4.015928745269775\n", - "Epoch 135,Loss: 4.0120649337768555\n", - "Epoch 136,Loss: 4.00533390045166\n", - "Epoch 137,Loss: 4.013724327087402\n", - "Epoch 138,Loss: 4.004214763641357\n", - "Epoch 139,Loss: 4.007634162902832\n", - "Epoch 140,Loss: 3.9990668296813965\n", - "Epoch 141,Loss: 4.004575252532959\n", - "Epoch 142,Loss: 3.9970436096191406\n", - "Epoch 143,Loss: 4.000253677368164\n", - "Epoch 144,Loss: 3.993255615234375\n", - "Epoch 145,Loss: 4.004382610321045\n", - "Epoch 146,Loss: 3.994224786758423\n", - "Epoch 147,Loss: 3.99639892578125\n", - "Epoch 148,Loss: 3.9898176193237305\n", - "Epoch 149,Loss: 3.9965922832489014\n", - "Epoch 150,Loss: 3.9880197048187256\n", - "Epoch 151,Loss: 3.990727663040161\n", - "Epoch 152,Loss: 3.984743118286133\n", - "Epoch 153,Loss: 3.9868345260620117\n", - "Epoch 154,Loss: 3.9808623790740967\n", - "Epoch 155,Loss: 3.979079246520996\n", - "Epoch 156,Loss: 3.9874355792999268\n", - "Epoch 157,Loss: 3.9819841384887695\n", - "Epoch 158,Loss: 3.9719669818878174\n", - "Epoch 159,Loss: 3.9836819171905518\n", - "Epoch 160,Loss: 3.9836788177490234\n", - "Epoch 161,Loss: 3.9711480140686035\n", - "Epoch 162,Loss: 3.9807724952697754\n", - "Epoch 163,Loss: 3.974381446838379\n", - "Epoch 164,Loss: 3.975137948989868\n", - "Epoch 165,Loss: 3.970792293548584\n", - "Epoch 166,Loss: 3.9664454460144043\n", - "Epoch 167,Loss: 3.972529888153076\n", - "Epoch 168,Loss: 3.970104217529297\n", - "Epoch 169,Loss: 3.9673550128936768\n", - "Epoch 170,Loss: 3.974461793899536\n", - "Epoch 171,Loss: 3.966738700866699\n", - "Epoch 172,Loss: 3.968355178833008\n", - "Epoch 173,Loss: 3.961606740951538\n", - "Epoch 174,Loss: 3.9639313220977783\n", - "Epoch 175,Loss: 3.9576148986816406\n", - "Epoch 176,Loss: 3.966170072555542\n", - "Epoch 177,Loss: 3.9580230712890625\n", - "Epoch 178,Loss: 3.9642348289489746\n", - "Epoch 179,Loss: 3.9687607288360596\n", - "Epoch 180,Loss: 3.9539127349853516\n", - "Epoch 181,Loss: 3.9508092403411865\n", - "Epoch 182,Loss: 3.964644193649292\n", - "Epoch 183,Loss: 3.9514272212982178\n", - "Epoch 184,Loss: 3.9511210918426514\n", - "Epoch 185,Loss: 3.953274965286255\n", - "Epoch 186,Loss: 3.9493982791900635\n", - "Epoch 187,Loss: 
3.9452455043792725\n", - "Epoch 188,Loss: 3.9493496417999268\n", - "Epoch 189,Loss: 3.9497110843658447\n", - "Epoch 190,Loss: 3.934187650680542\n", - "Epoch 191,Loss: 3.942565679550171\n", - "Epoch 192,Loss: 3.939657211303711\n", - "Epoch 193,Loss: 3.950552225112915\n", - "Epoch 194,Loss: 3.9368786811828613\n", - "Epoch 195,Loss: 3.938438892364502\n", - "Epoch 196,Loss: 3.9386792182922363\n", - "Epoch 197,Loss: 3.937180519104004\n", - "Epoch 198,Loss: 3.9317572116851807\n", - "Epoch 199,Loss: 3.938286781311035\n", - "step 200: train loss 3.9314, val loss 3.9342\n", - "Epoch 200,Loss: 3.931473731994629\n", - "Epoch 201,Loss: 3.9273831844329834\n", - "Epoch 202,Loss: 3.9355266094207764\n", - "Epoch 203,Loss: 3.9317986965179443\n", - "Epoch 204,Loss: 3.9241442680358887\n", - "Epoch 205,Loss: 3.927114963531494\n", - "Epoch 206,Loss: 3.927417516708374\n", - "Epoch 207,Loss: 3.918044090270996\n", - "Epoch 208,Loss: 3.9226672649383545\n", - "Epoch 209,Loss: 3.9170031547546387\n", - "Epoch 210,Loss: 3.9196341037750244\n", - "Epoch 211,Loss: 3.9205477237701416\n", - "Epoch 212,Loss: 3.9187772274017334\n", - "Epoch 213,Loss: 3.914111852645874\n", - "Epoch 214,Loss: 3.915701389312744\n", - "Epoch 215,Loss: 3.912682056427002\n", - "Epoch 216,Loss: 3.9132320880889893\n", - "Epoch 217,Loss: 3.9093964099884033\n", - "Epoch 218,Loss: 3.910304546356201\n", - "Epoch 219,Loss: 3.902721643447876\n", - "Epoch 220,Loss: 3.9013726711273193\n", - "Epoch 221,Loss: 3.9074864387512207\n", - "Epoch 222,Loss: 3.9011595249176025\n", - "Epoch 223,Loss: 3.9110357761383057\n", - "Epoch 224,Loss: 3.919271469116211\n", - "Epoch 225,Loss: 3.8986053466796875\n", - "Epoch 226,Loss: 3.913900375366211\n", - "Epoch 227,Loss: 3.9005937576293945\n", - "Epoch 228,Loss: 3.894559383392334\n", - "Epoch 229,Loss: 3.9020140171051025\n", - "Epoch 230,Loss: 3.9012680053710938\n", - "Epoch 231,Loss: 3.890580177307129\n", - "Epoch 232,Loss: 3.901439666748047\n", - "Epoch 233,Loss: 3.8912415504455566\n", - "Epoch 234,Loss: 3.8939270973205566\n", - "Epoch 235,Loss: 3.8888893127441406\n", - "Epoch 236,Loss: 3.8912696838378906\n", - "Epoch 237,Loss: 3.8963351249694824\n", - "Epoch 238,Loss: 3.8835411071777344\n", - "Epoch 239,Loss: 3.8910605907440186\n", - "Epoch 240,Loss: 3.891777992248535\n", - "Epoch 241,Loss: 3.8892035484313965\n", - "Epoch 242,Loss: 3.8773059844970703\n", - "Epoch 243,Loss: 3.883334159851074\n", - "Epoch 244,Loss: 3.8824222087860107\n", - "Epoch 245,Loss: 3.8792808055877686\n", - "Epoch 246,Loss: 3.8650636672973633\n", - "Epoch 247,Loss: 3.882506847381592\n", - "Epoch 248,Loss: 3.8686697483062744\n", - "Epoch 249,Loss: 3.871563196182251\n", - "Epoch 250,Loss: 3.8801050186157227\n", - "Epoch 251,Loss: 3.876453161239624\n", - "Epoch 252,Loss: 3.878538131713867\n", - "Epoch 253,Loss: 3.8688294887542725\n", - "Epoch 254,Loss: 3.869239568710327\n", - "Epoch 255,Loss: 3.8656933307647705\n", - "Epoch 256,Loss: 3.8676295280456543\n", - "Epoch 257,Loss: 3.868959903717041\n", - "Epoch 258,Loss: 3.860564947128296\n", - "Epoch 259,Loss: 3.8659508228302\n", - "Epoch 260,Loss: 3.870162010192871\n", - "Epoch 261,Loss: 3.862337112426758\n", - "Epoch 262,Loss: 3.8631129264831543\n", - "Epoch 263,Loss: 3.8556807041168213\n", - "Epoch 264,Loss: 3.859565496444702\n", - "Epoch 265,Loss: 3.8612101078033447\n", - "Epoch 266,Loss: 3.860574245452881\n", - "Epoch 267,Loss: 3.863940954208374\n", - "Epoch 268,Loss: 3.846076726913452\n", - "Epoch 269,Loss: 3.858536958694458\n", - "Epoch 270,Loss: 3.8503029346466064\n", - "Epoch 271,Loss: 
3.857239007949829\n", - "Epoch 272,Loss: 3.853991746902466\n", - "Epoch 273,Loss: 3.8525960445404053\n", - "Epoch 274,Loss: 3.8553481101989746\n", - "Epoch 275,Loss: 3.845716953277588\n", - "Epoch 276,Loss: 3.846571207046509\n", - "Epoch 277,Loss: 3.8519811630249023\n", - "Epoch 278,Loss: 3.8430190086364746\n", - "Epoch 279,Loss: 3.8463683128356934\n", - "Epoch 280,Loss: 3.839120864868164\n", - "Epoch 281,Loss: 3.8422346115112305\n", - "Epoch 282,Loss: 3.8394320011138916\n", - "Epoch 283,Loss: 3.8440675735473633\n", - "Epoch 284,Loss: 3.8503305912017822\n", - "Epoch 285,Loss: 3.838545322418213\n", - "Epoch 286,Loss: 3.829591751098633\n", - "Epoch 287,Loss: 3.8370203971862793\n", - "Epoch 288,Loss: 3.8436241149902344\n", - "Epoch 289,Loss: 3.8240110874176025\n", - "Epoch 290,Loss: 3.83495831489563\n", - "Epoch 291,Loss: 3.8333072662353516\n", - "Epoch 292,Loss: 3.833137035369873\n", - "Epoch 293,Loss: 3.8384203910827637\n", - "Epoch 294,Loss: 3.840672731399536\n", - "Epoch 295,Loss: 3.83957576751709\n", - "Epoch 296,Loss: 3.8232216835021973\n", - "Epoch 297,Loss: 3.8206679821014404\n", - "Epoch 298,Loss: 3.8219101428985596\n", - "Epoch 299,Loss: 3.834547758102417\n", - "step 300: train loss 3.8211, val loss 3.8261\n", - "Epoch 300,Loss: 3.8192684650421143\n", - "Epoch 301,Loss: 3.8214690685272217\n", - "Epoch 302,Loss: 3.8147058486938477\n", - "Epoch 303,Loss: 3.8182125091552734\n", - "Epoch 304,Loss: 3.822451114654541\n", - "Epoch 305,Loss: 3.8188161849975586\n", - "Epoch 306,Loss: 3.8196117877960205\n", - "Epoch 307,Loss: 3.814863681793213\n", - "Epoch 308,Loss: 3.812788724899292\n", - "Epoch 309,Loss: 3.799957036972046\n", - "Epoch 310,Loss: 3.8292646408081055\n", - "Epoch 311,Loss: 3.8186419010162354\n", - "Epoch 312,Loss: 3.8149664402008057\n", - "Epoch 313,Loss: 3.805006742477417\n", - "Epoch 314,Loss: 3.801513195037842\n", - "Epoch 315,Loss: 3.79679536819458\n", - "Epoch 316,Loss: 3.79837703704834\n", - "Epoch 317,Loss: 3.7946202754974365\n", - "Epoch 318,Loss: 3.812180995941162\n", - "Epoch 319,Loss: 3.8108162879943848\n", - "Epoch 320,Loss: 3.790243625640869\n", - "Epoch 321,Loss: 3.8042056560516357\n", - "Epoch 322,Loss: 3.7979207038879395\n", - "Epoch 323,Loss: 3.790651321411133\n", - "Epoch 324,Loss: 3.7874715328216553\n", - "Epoch 325,Loss: 3.79123854637146\n", - "Epoch 326,Loss: 3.794536828994751\n", - "Epoch 327,Loss: 3.784797191619873\n", - "Epoch 328,Loss: 3.790224552154541\n", - "Epoch 329,Loss: 3.782738208770752\n", - "Epoch 330,Loss: 3.794175148010254\n", - "Epoch 331,Loss: 3.786827564239502\n", - "Epoch 332,Loss: 3.7888123989105225\n", - "Epoch 333,Loss: 3.783346176147461\n", - "Epoch 334,Loss: 3.793034076690674\n", - "Epoch 335,Loss: 3.7874460220336914\n", - "Epoch 336,Loss: 3.7902700901031494\n", - "Epoch 337,Loss: 3.778681755065918\n", - "Epoch 338,Loss: 3.7811100482940674\n", - "Epoch 339,Loss: 3.76914644241333\n", - "Epoch 340,Loss: 3.789398193359375\n", - "Epoch 341,Loss: 3.7791271209716797\n", - "Epoch 342,Loss: 3.770552635192871\n", - "Epoch 343,Loss: 3.766422986984253\n", - "Epoch 344,Loss: 3.7771592140197754\n", - "Epoch 345,Loss: 3.7636663913726807\n", - "Epoch 346,Loss: 3.787656545639038\n", - "Epoch 347,Loss: 3.764472723007202\n", - "Epoch 348,Loss: 3.7605934143066406\n", - "Epoch 349,Loss: 3.7653210163116455\n", - "Epoch 350,Loss: 3.768303155899048\n", - "Epoch 351,Loss: 3.7691171169281006\n", - "Epoch 352,Loss: 3.7605278491973877\n", - "Epoch 353,Loss: 3.771134614944458\n", - "Epoch 354,Loss: 3.766357183456421\n", - "Epoch 355,Loss: 3.7620556354522705\n", 
- "Epoch 356,Loss: 3.763490915298462\n", - "Epoch 357,Loss: 3.7664549350738525\n", - "Epoch 358,Loss: 3.7606000900268555\n", - "Epoch 359,Loss: 3.7632651329040527\n", - "Epoch 360,Loss: 3.745481252670288\n", - "Epoch 361,Loss: 3.7506914138793945\n", - "Epoch 362,Loss: 3.7509055137634277\n", - "Epoch 363,Loss: 3.758342981338501\n", - "Epoch 364,Loss: 3.7418177127838135\n", - "Epoch 365,Loss: 3.7498388290405273\n", - "Epoch 366,Loss: 3.75475811958313\n", - "Epoch 367,Loss: 3.7491402626037598\n", - "Epoch 368,Loss: 3.749390125274658\n", - "Epoch 369,Loss: 3.74155330657959\n", - "Epoch 370,Loss: 3.7474842071533203\n", - "Epoch 371,Loss: 3.75659441947937\n", - "Epoch 372,Loss: 3.7347867488861084\n", - "Epoch 373,Loss: 3.73516845703125\n", - "Epoch 374,Loss: 3.736833333969116\n", - "Epoch 375,Loss: 3.7528369426727295\n", - "Epoch 376,Loss: 3.7481980323791504\n", - "Epoch 377,Loss: 3.7436063289642334\n", - "Epoch 378,Loss: 3.7425522804260254\n", - "Epoch 379,Loss: 3.743187189102173\n", - "Epoch 380,Loss: 3.7335641384124756\n", - "Epoch 381,Loss: 3.7418203353881836\n", - "Epoch 382,Loss: 3.7498104572296143\n", - "Epoch 383,Loss: 3.730151891708374\n", - "Epoch 384,Loss: 3.735663414001465\n", - "Epoch 385,Loss: 3.725689649581909\n", - "Epoch 386,Loss: 3.7350897789001465\n", - "Epoch 387,Loss: 3.7252840995788574\n", - "Epoch 388,Loss: 3.7173573970794678\n", - "Epoch 389,Loss: 3.72982120513916\n", - "Epoch 390,Loss: 3.728365898132324\n", - "Epoch 391,Loss: 3.7212765216827393\n", - "Epoch 392,Loss: 3.7305972576141357\n", - "Epoch 393,Loss: 3.720611572265625\n", - "Epoch 394,Loss: 3.7320303916931152\n", - "Epoch 395,Loss: 3.726435661315918\n", - "Epoch 396,Loss: 3.7120232582092285\n", - "Epoch 397,Loss: 3.721881151199341\n", - "Epoch 398,Loss: 3.7123491764068604\n", - "Epoch 399,Loss: 3.7059667110443115\n", - "step 400: train loss 3.7182, val loss 3.7246\n", - "Epoch 400,Loss: 3.715552806854248\n", - "Epoch 401,Loss: 3.70719575881958\n", - "Epoch 402,Loss: 3.717360258102417\n", - "Epoch 403,Loss: 3.7229161262512207\n", - "Epoch 404,Loss: 3.7143988609313965\n", - "Epoch 405,Loss: 3.7218456268310547\n", - "Epoch 406,Loss: 3.707012414932251\n", - "Epoch 407,Loss: 3.7073845863342285\n", - "Epoch 408,Loss: 3.7069833278656006\n", - "Epoch 409,Loss: 3.7150471210479736\n", - "Epoch 410,Loss: 3.709181547164917\n", - "Epoch 411,Loss: 3.7069830894470215\n", - "Epoch 412,Loss: 3.7000956535339355\n", - "Epoch 413,Loss: 3.6996009349823\n", - "Epoch 414,Loss: 3.709993362426758\n", - "Epoch 415,Loss: 3.6976137161254883\n", - "Epoch 416,Loss: 3.69256329536438\n", - "Epoch 417,Loss: 3.6903862953186035\n", - "Epoch 418,Loss: 3.6960389614105225\n", - "Epoch 419,Loss: 3.6840765476226807\n", - "Epoch 420,Loss: 3.6853392124176025\n", - "Epoch 421,Loss: 3.693702459335327\n", - "Epoch 422,Loss: 3.692162036895752\n", - "Epoch 423,Loss: 3.697707176208496\n", - "Epoch 424,Loss: 3.683466911315918\n", - "Epoch 425,Loss: 3.7038679122924805\n", - "Epoch 426,Loss: 3.700803279876709\n", - "Epoch 427,Loss: 3.678903818130493\n", - "Epoch 428,Loss: 3.6846554279327393\n", - "Epoch 429,Loss: 3.696505308151245\n", - "Epoch 430,Loss: 3.6857666969299316\n", - "Epoch 431,Loss: 3.700955867767334\n", - "Epoch 432,Loss: 3.6856634616851807\n", - "Epoch 433,Loss: 3.6878204345703125\n", - "Epoch 434,Loss: 3.6669888496398926\n", - "Epoch 435,Loss: 3.681831121444702\n", - "Epoch 436,Loss: 3.6712563037872314\n", - "Epoch 437,Loss: 3.674145460128784\n", - "Epoch 438,Loss: 3.6812634468078613\n", - "Epoch 439,Loss: 3.6820902824401855\n", - "Epoch 440,Loss: 
3.682633638381958\n", - "Epoch 441,Loss: 3.674262762069702\n", - "Epoch 442,Loss: 3.6762478351593018\n", - "Epoch 443,Loss: 3.6800148487091064\n", - "Epoch 444,Loss: 3.6821205615997314\n", - "Epoch 445,Loss: 3.679013729095459\n", - "Epoch 446,Loss: 3.663222551345825\n", - "Epoch 447,Loss: 3.674647808074951\n", - "Epoch 448,Loss: 3.674830675125122\n", - "Epoch 449,Loss: 3.6704530715942383\n", - "Epoch 450,Loss: 3.6534314155578613\n", - "Epoch 451,Loss: 3.67617130279541\n", - "Epoch 452,Loss: 3.6654410362243652\n", - "Epoch 453,Loss: 3.6626992225646973\n", - "Epoch 454,Loss: 3.6643929481506348\n", - "Epoch 455,Loss: 3.663473129272461\n", - "Epoch 456,Loss: 3.652484893798828\n", - "Epoch 457,Loss: 3.657602071762085\n", - "Epoch 458,Loss: 3.659010171890259\n", - "Epoch 459,Loss: 3.6692850589752197\n", - "Epoch 460,Loss: 3.6528565883636475\n", - "Epoch 461,Loss: 3.661115884780884\n", - "Epoch 462,Loss: 3.6511964797973633\n", - "Epoch 463,Loss: 3.6498470306396484\n", - "Epoch 464,Loss: 3.6548194885253906\n", - "Epoch 465,Loss: 3.6427958011627197\n", - "Epoch 466,Loss: 3.640429973602295\n", - "Epoch 467,Loss: 3.6433045864105225\n", - "Epoch 468,Loss: 3.6328577995300293\n", - "Epoch 469,Loss: 3.6472063064575195\n", - "Epoch 470,Loss: 3.6393685340881348\n", - "Epoch 471,Loss: 3.643303155899048\n", - "Epoch 472,Loss: 3.635701894760132\n", - "Epoch 473,Loss: 3.6435654163360596\n", - "Epoch 474,Loss: 3.6412038803100586\n", - "Epoch 475,Loss: 3.6328794956207275\n", - "Epoch 476,Loss: 3.6483938694000244\n", - "Epoch 477,Loss: 3.6373813152313232\n", - "Epoch 478,Loss: 3.645163059234619\n", - "Epoch 479,Loss: 3.650409460067749\n", - "Epoch 480,Loss: 3.6378026008605957\n", - "Epoch 481,Loss: 3.651862382888794\n", - "Epoch 482,Loss: 3.618602752685547\n", - "Epoch 483,Loss: 3.653233528137207\n", - "Epoch 484,Loss: 3.6276497840881348\n", - "Epoch 485,Loss: 3.6483981609344482\n", - "Epoch 486,Loss: 3.6308512687683105\n", - "Epoch 487,Loss: 3.6225945949554443\n", - "Epoch 488,Loss: 3.6241683959960938\n", - "Epoch 489,Loss: 3.625032901763916\n", - "Epoch 490,Loss: 3.614368438720703\n", - "Epoch 491,Loss: 3.6298270225524902\n", - "Epoch 492,Loss: 3.638739585876465\n", - "Epoch 493,Loss: 3.6167731285095215\n", - "Epoch 494,Loss: 3.6361453533172607\n", - "Epoch 495,Loss: 3.626582622528076\n", - "Epoch 496,Loss: 3.6187684535980225\n", - "Epoch 497,Loss: 3.625209331512451\n", - "Epoch 498,Loss: 3.621763229370117\n", - "step 499: train loss 3.6193, val loss 3.6217\n", - "Epoch 499,Loss: 3.614203929901123\n", - "TIME jax 5.985821008682251\n" - ] - } - ], - "source": [ - "prng = jax.random.PRNGKey(1337)\n", - "\n", - "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", - "with open('input.txt', 'r', encoding='utf-8') as f:\n", - " text = f.read()\n", - "\n", - "# hyperparameters\n", - "batch_size = 16 # how many independent sequences will we process in parallel?\n", - "block_size = 32 # what is the maximum context length for predictions?\n", - "max_iters = 500\n", - "eval_interval = 100\n", - "learning_rate = 1e-3\n", - "eval_iters = 10\n", - "n_embd = 64\n", - "n_head = 4\n", - "n_layer = 4\n", - "dropout = 0.0\n", - "\n", - "\n", - "# here are all the unique characters that occur in this text\n", - "chars = sorted(list(set(text)))\n", - "vocab_size = len(chars)\n", - "# create a mapping from characters to integers\n", - "stoi = { ch:i for i,ch in enumerate(chars) }\n", - "itos = { i:ch for i,ch in enumerate(chars) }\n", - "encode = lambda s: [stoi[c] for c in s] # 
encoder: take a string, output a list of integers\n", - "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", - "\n", - "# Train and test splits\n", - "data = jnp.array(encode(text), dtype=jnp.int32)\n", - "data = device_put(data)\n", - "n = int(0.9*len(data)) # first 90% will be train, rest val\n", - "train_data = data[:n]\n", - "val_data = data[n:]\n", - "\n", - "def get_batch(split, subkey):\n", - " # generate a small batch of data of inputs x and targets y\n", - " data = train_data if split == 'train' else val_data\n", - " t1 = time.time()\n", - " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n", - " t2 = time.time()\n", - " # x = jnp.stack([data[i:i+block_size] for i in ix])\n", - " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n", - " def slice_data(i):\n", - " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n", - "\n", - " x = jax.vmap(slice_data)(ix)\n", - " y = jax.vmap(slice_data)(ix+1)\n", - " x, y = device_put(x), device_put(y)\n", - " # print('TIME rand idx', t2-t1)\n", - " # print('TIME rand idx fetch', time.time()-t2)\n", - " return x, y\n", - "\n", - "def estimate_loss(params, prng):\n", - " out = {}\n", - " for split in ['train', 'val']:\n", - " losses = jnp.zeros(eval_iters)\n", - " for k in range(eval_iters):\n", - " prng, subkey = random.split(prng)\n", - " X, Y = get_batch(split, subkey)\n", - " logits, loss = model.apply(params, X, Y)\n", - " losses = losses.at[k].set(loss)\n", - " out[split] = losses.mean()\n", - " return out\n", - "\n", - "class BigramLanguageModel(nn.Module):\n", - " vocab_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, idx, targets=None):\n", - " # Token embedding table\n", - " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n", - " logits = embedding_table(idx) # (B,T,C)\n", - "\n", - " if targets is None:\n", - " loss = None\n", - " else:\n", - " B, T, C = logits.shape\n", - " logits = logits.reshape(B*T, C)\n", - " targets = targets.reshape(B*T)\n", - " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n", - "\n", - " return logits, loss\n", - "\n", - " def generate(self, params, key, idx, max_new_tokens):\n", - " for _ in range(max_new_tokens):\n", - " logits, _ = self.apply(params, idx)\n", - " logits = logits[:, -1, :] # (B, C)\n", - " key, subkey = random.split(key)\n", - " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", - " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", - " return idx\n", - "\n", - "# Define the optimizer\n", - "tx = adam(learning_rate)\n", - "\n", - "# Initialize model parameters and optimizer state\n", - "model = BigramLanguageModel(vocab_size)\n", - "\n", - "# Initialize model parameters and optimizer\n", - "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n", - "opt_state = tx.init(params)\n", - "\n", - "# Loss function (assuming you have a batch of data: xb, yb)\n", - "def loss_fn(params, xb, yb):\n", - " logits, loss = model.apply(params, xb, yb)\n", - " return loss\n", - "\n", - "# Update function for a single training step\n", - "@jax.jit\n", - "def update_step(params, opt_state, xb, yb):\n", - " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n", - " updates, opt_state = tx.update(grads, opt_state, params)\n", - " new_params = optax.apply_updates(params, updates)\n", - " return new_params, opt_state, loss\n", - "\n", - "# Training loop (example)\n", - "batch_size = 32\n", - "t = 
time.time()\n", - "for steps in range(max_iters):\n", - " # every once in a while evaluate the loss on train and val sets\n", - " if steps == max_iters - 1 or steps % eval_interval == 0:\n", - " losses = estimate_loss(params, prng)\n", - " print(f\"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", - " prng, subkey = random.split(prng)\n", - " xb, yb = get_batch('train', subkey)\n", - " t2 = time.time()\n", - " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n", - " print(f\"Epoch {steps},Loss: {loss}\")\n", - "print('TIME jax', time.time()-t)" - ] - }, - { - "cell_type": "code", - "source": [ - "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))" - ], - "metadata": { - "id": "IFG4kdaRF7b1" - }, - "execution_count": 49, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def generate(params, flax_apply_jitted, key, idx, max_new_tokens):\n", - " for _ in range(max_new_tokens):\n", - " logits, _ = flax_apply_jitted(params, idx, None)\n", - " logits = logits[:, -1, :] # (B, C)\n", - " key, subkey = random.split(key)\n", - " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n", - " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n", - " return idx\n", - "\n", - "%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ly4f5Bm-F5t8", - "outputId": "37bb46de-bd9a-4fd8-97d2-b6488d3cccbd" - }, - "execution_count": 54, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "yD.P.e'wn,CZsvq gPrf$f&W3aypokkuSEz?Paw:YCj?M;x\n", - "pctexMadJMlTZr,CyhaRoYRJUfrsld,bqlwXxclCHIWu'FYEBldJrby;b!HR'Frcr,?&Nui3;woGFdW psZYhosYO!hHPv:F3uMAHbGoslLIWXd;woxmBMTe UpY ZEP\n", - "tTk?BlWOPrZP.zNS pjFR,OxrO?!$wNDsXCd;il:c!'lal'uOPGCJeDusSf,E.XgunoK-rJP-ky oHsxxwu,:i3UJgZbBMO;s:\n", - "oPALGSTE-heWO,tcI3$VaeVY JsTPqaqT-ebedAWhoEv:WFiCykXrvetkGbw'3-N!'oW\n", - "n\n", - "qi:FgOyU3Xd wrr gVItNvRo,JbtDAvcfHSKDWh.caNKrf CMr IGs?lbiNDerJg'cHB:rRwAuGq&UDUhOdnmc:&jUSZCuG?mF.An--,EMDfCHfITHs ztXPL U3iSE--AsTxeqf??imOUQfArTnZ.Hgv\n", - "CPU times: user 1.02 s, sys: 7.95 ms, total: 1.03 s\n", - "Wall time: 1.03 s\n" - ] - } - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [], - "toc_visible": true, - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -}
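
Note on the get_batch cell in the deleted notebook: the commented-out jnp.stack approach indexes the data array in a Python loop, which launches one small device op per sequence and cannot be traced with dynamic start indices, so the notebook vmaps a lax.dynamic_slice over the sampled starts instead. A minimal standalone sketch of that same pattern, using a hypothetical gather_windows helper and assuming a 1-D token array:

    import jax
    import jax.numpy as jnp

    def gather_windows(data, starts, block_size):
        # one dynamic_slice per start index, vectorised with vmap;
        # block_size is a static Python int, so every slice has a fixed shape
        slice_one = lambda i: jax.lax.dynamic_slice(data, (i,), (block_size,))
        return jax.vmap(slice_one)(starts)

    data = jnp.arange(100)
    starts = jnp.array([0, 7, 42])
    x = gather_windows(data, starts, 8)       # inputs,  shape (3, 8)
    y = gather_windows(data, starts + 1, 8)   # targets, shifted by one token

Because the slice length is static, the vmapped call is vectorised over the whole batch and remains jit-compatible.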
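
The loss inside BigramLanguageModel.__call__ is written out explicitly as -sum(one_hot(targets) * log_softmax(logits)). An equivalent formulation, sketched below with a hypothetical bigram_loss helper (assuming an optax version that provides softmax_cross_entropy_with_integer_labels), avoids materialising the (B*T, C) one-hot matrix:

    import optax

    def bigram_loss(logits, targets):
        # logits: (B, T, C) scores over the vocabulary; targets: (B, T) integer ids
        B, T, C = logits.shape
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape(B * T, C), targets.reshape(B * T))
        return per_token.mean()

Up to numerical precision the two forms give the same scalar, so swapping it in would not change the loss values logged above.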
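
One caveat on the timed generation cell: flax_apply_jitted is called with an idx whose length grows by one token per step, so jit retraces and recompiles at every new sequence length, which likely dominates the ~1 s wall time. For this bigram model the next-token distribution depends only on the last token, so a fixed-shape variant compiles once; the sketch below (hypothetical generate_fixed_shape, valid only under that bigram assumption; a model with real context would instead need a padded buffer) shows the idea:

    import jax.numpy as jnp
    import jax.random as random

    def generate_fixed_shape(params, flax_apply_jitted, key, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # feed only the last token: the (B, 1) input shape never changes,
            # so the jitted apply is traced exactly once
            logits, _ = flax_apply_jitted(params, idx[:, -1:], None)
            logits = logits[:, -1, :]                                # (B, C)
            key, subkey = random.split(key)
            idx_next = random.categorical(subkey, logits)[:, None]   # (B, 1)
            idx = jnp.concatenate((idx, idx_next), axis=1)           # (B, T+1)
        return idx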