diff --git a/colabs/llms /jax_gpt_dev_bigram.ipynb b/colabs/llms /jax_gpt_dev_bigram.ipynb
deleted file mode 100644
index 7f6eeb4..0000000
--- a/colabs/llms /jax_gpt_dev_bigram.ipynb
+++ /dev/null
@@ -1,1678 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "\n",
- "# Changed Andrej Karpathy's minGPT from pytorch to Jax code (see [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT)."
- ],
- "metadata": {
- "id": "vv5qRLYk-aNq"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NIbX_6V1ELk2"
- },
- "source": [
- "## import"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "a6NK_zR1EN1I"
- },
- "outputs": [],
- "source": [
- "import time\n",
- "import jax\n",
- "from jax import device_put\n",
- "import jax.numpy as jnp\n",
- "from jax import lax\n",
- "import jax.random as random\n",
- "import jax.nn as jnn\n",
- "from jax.nn.initializers import normal\n",
- "\n",
- "import flax.linen as nn\n",
- "\n",
- "import optax\n",
- "from optax import adam"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wJpXpmjEYC_T"
- },
- "source": [
- "## Building a GPT"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "kqorRX1vF3KP"
- },
- "source": [
- "### load data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "h5hjCcLDr2WC",
- "outputId": "18e575dc-0dcc-4519-d97d-fd83c49a9e38"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "--2024-04-02 04:10:38-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
- "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...\n",
- "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 1115394 (1.1M) [text/plain]\n",
- "Saving to: ‘input.txt’\n",
- "\n",
- "\rinput.txt 0%[ ] 0 --.-KB/s \rinput.txt 100%[===================>] 1.06M --.-KB/s in 0.04s \n",
- "\n",
- "2024-04-02 04:10:39 (25.5 MB/s) - ‘input.txt’ saved [1115394/1115394]\n",
- "\n"
- ]
- }
- ],
- "source": [
- "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
- "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "O6medjfRsLD9"
- },
- "outputs": [],
- "source": [
- "# read it in to inspect it\n",
- "with open('input.txt', 'r', encoding='utf-8') as f:\n",
- " text = f.read()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "6xWI_VyAsN8F",
- "outputId": "c10c8eb5-547b-4357-b817-3820cc3b1df8"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "length of dataset in characters: 1115394\n"
- ]
- }
- ],
- "source": [
- "print(\"length of dataset in characters: \", len(text))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "2c5V0FvqseE0",
- "outputId": "6d410839-095a-4cb7-87f2-e4fd75200c03"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "First Citizen:\n",
- "Before we proceed any further, hear me speak.\n",
- "\n",
- "All:\n",
- "Speak, speak.\n",
- "\n",
- "First Citizen:\n",
- "You are all resolved rather to die than to famish?\n",
- "\n",
- "All:\n",
- "Resolved. resolved.\n",
- "\n",
- "First Citizen:\n",
- "First, you know Caius Marcius is chief enemy to the people.\n",
- "\n",
- "All:\n",
- "We know't, we know't.\n",
- "\n",
- "First Citizen:\n",
- "Let us kill him, and we'll have corn at our own price.\n",
- "Is't a verdict?\n",
- "\n",
- "All:\n",
- "No more talking on't; let it be done: away, away!\n",
- "\n",
- "Second Citizen:\n",
- "One word, good citizens.\n",
- "\n",
- "First Citizen:\n",
- "We are accounted poor citizens, the patricians good.\n",
- "What authority surfeits on would relieve us: if they\n",
- "would yield us but the superfluity, while it were\n",
- "wholesome, we might guess they relieved us humanely;\n",
- "but they think we are too dear: the leanness that\n",
- "afflicts us, the object of our misery, is as an\n",
- "inventory to particularise their abundance; our\n",
- "sufferance is a gain to them Let us revenge this with\n",
- "our pikes, ere we become rakes: for the gods know I\n",
- "speak this in hunger for bread, not in thirst for revenge.\n",
- "\n",
- "\n"
- ]
- }
- ],
- "source": [
- "# let's look at the first 1000 characters\n",
- "print(text[:1000])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "0e-Rbyr8sfM8",
- "outputId": "a34b8a27-8433-4613-8bf4-c4dee3596867"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n",
- " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
- "65\n"
- ]
- }
- ],
- "source": [
- "# here are all the unique characters that occur in this text\n",
- "chars = sorted(list(set(text)))\n",
- "vocab_size = len(chars)\n",
- "print(''.join(chars))\n",
- "print(vocab_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "qFaINwGqD1Bm",
- "outputId": "b7da73de-d04b-4019-e3f5-76e5980bc5a4"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "True"
- ]
- },
- "metadata": {},
- "execution_count": 8
- }
- ],
- "source": [
- "'!' in chars"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Yw1LKNCgwjj1",
- "outputId": "c8d96823-504d-4465-b091-72e67be69666"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
- "hii there\n"
- ]
- }
- ],
- "source": [
- "# create a mapping from characters to integers\n",
- "stoi = { ch:i for i,ch in enumerate(chars) }\n",
- "itos = { i:ch for i,ch in enumerate(chars) }\n",
- "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
- "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
- "\n",
- "print(encode(\"hii there\"))\n",
- "print(decode(encode(\"hii there\")))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "YJb0OXPwzvqg",
- "outputId": "0ab5bbf0-e783-4fe6-e022-cd12e910570c"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- ":8: UserWarning: Explicitly requested dtype requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
- " data = jnp.array(encode(text), dtype=jnp.int64)\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "(1115394,) int32\n",
- "[18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 14 43 44 53 56 43 1 61 43\n",
- " 1 54 56 53 41 43 43 42 1 39 52 63 1 44 59 56 58 46 43 56 6 1 46 43\n",
- " 39 56 1 51 43 1 57 54 43 39 49 8 0 0 13 50 50 10 0 31 54 43 39 49\n",
- " 6 1 57 54 43 39 49 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10\n",
- " 0 37 53 59 1 39 56 43 1 39 50 50 1 56 43 57 53 50 60 43 42 1 56 39\n",
- " 58 46 43 56 1 58 53 1 42 47 43 1 58 46 39 52 1 58 53 1 44 39 51 47\n",
- " 57 46 12 0 0 13 50 50 10 0 30 43 57 53 50 60 43 42 8 1 56 43 57 53\n",
- " 50 60 43 42 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 18 47\n",
- " 56 57 58 6 1 63 53 59 1 49 52 53 61 1 15 39 47 59 57 1 25 39 56 41\n",
- " 47 59 57 1 47 57 1 41 46 47 43 44 1 43 52 43 51 63 1 58 53 1 58 46\n",
- " 43 1 54 43 53 54 50 43 8 0 0 13 50 50 10 0 35 43 1 49 52 53 61 5\n",
- " 58 6 1 61 43 1 49 52 53 61 5 58 8 0 0 18 47 56 57 58 1 15 47 58\n",
- " 47 64 43 52 10 0 24 43 58 1 59 57 1 49 47 50 50 1 46 47 51 6 1 39\n",
- " 52 42 1 61 43 5 50 50 1 46 39 60 43 1 41 53 56 52 1 39 58 1 53 59\n",
- " 56 1 53 61 52 1 54 56 47 41 43 8 0 21 57 5 58 1 39 1 60 43 56 42\n",
- " 47 41 58 12 0 0 13 50 50 10 0 26 53 1 51 53 56 43 1 58 39 50 49 47\n",
- " 52 45 1 53 52 5 58 11 1 50 43 58 1 47 58 1 40 43 1 42 53 52 43 10\n",
- " 1 39 61 39 63 6 1 39 61 39 63 2 0 0 31 43 41 53 52 42 1 15 47 58\n",
- " 47 64 43 52 10 0 27 52 43 1 61 53 56 42 6 1 45 53 53 42 1 41 47 58\n",
- " 47 64 43 52 57 8 0 0 18 47 56 57 58 1 15 47 58 47 64 43 52 10 0 35\n",
- " 43 1 39 56 43 1 39 41 41 53 59 52 58 43 42 1 54 53 53 56 1 41 47 58\n",
- " 47 64 43 52 57 6 1 58 46 43 1 54 39 58 56 47 41 47 39 52 57 1 45 53\n",
- " 53 42 8 0 35 46 39 58 1 39 59 58 46 53 56 47 58 63 1 57 59 56 44 43\n",
- " 47 58 57 1 53 52 1 61 53 59 50 42 1 56 43 50 47 43 60 43 1 59 57 10\n",
- " 1 47 44 1 58 46 43 63 0 61 53 59 50 42 1 63 47 43 50 42 1 59 57 1\n",
- " 40 59 58 1 58 46 43 1 57 59 54 43 56 44 50 59 47 58 63 6 1 61 46 47\n",
- " 50 43 1 47 58 1 61 43 56 43 0 61 46 53 50 43 57 53 51 43 6 1 61 43\n",
- " 1 51 47 45 46 58 1 45 59 43 57 57 1 58 46 43 63 1 56 43 50 47 43 60\n",
- " 43 42 1 59 57 1 46 59 51 39 52 43 50 63 11 0 40 59 58 1 58 46 43 63\n",
- " 1 58 46 47 52 49 1 61 43 1 39 56 43 1 58 53 53 1 42 43 39 56 10 1\n",
- " 58 46 43 1 50 43 39 52 52 43 57 57 1 58 46 39 58 0 39 44 44 50 47 41\n",
- " 58 57 1 59 57 6 1 58 46 43 1 53 40 48 43 41 58 1 53 44 1 53 59 56\n",
- " 1 51 47 57 43 56 63 6 1 47 57 1 39 57 1 39 52 0 47 52 60 43 52 58\n",
- " 53 56 63 1 58 53 1 54 39 56 58 47 41 59 50 39 56 47 57 43 1 58 46 43\n",
- " 47 56 1 39 40 59 52 42 39 52 41 43 11 1 53 59 56 0 57 59 44 44 43 56\n",
- " 39 52 41 43 1 47 57 1 39 1 45 39 47 52 1 58 53 1 58 46 43 51 1 24\n",
- " 43 58 1 59 57 1 56 43 60 43 52 45 43 1 58 46 47 57 1 61 47 58 46 0\n",
- " 53 59 56 1 54 47 49 43 57 6 1 43 56 43 1 61 43 1 40 43 41 53 51 43\n",
- " 1 56 39 49 43 57 10 1 44 53 56 1 58 46 43 1 45 53 42 57 1 49 52 53\n",
- " 61 1 21 0 57 54 43 39 49 1 58 46 47 57 1 47 52 1 46 59 52 45 43 56\n",
- " 1 44 53 56 1 40 56 43 39 42 6 1 52 53 58 1 47 52 1 58 46 47 56 57\n",
- " 58 1 44 53 56 1 56 43 60 43 52 45 43 8 0 0]\n"
- ]
- }
- ],
- "source": [
- "# # let's now encode the entire text dataset and store it into a torch.Tensor\n",
- "# import torch # we use PyTorch: https://pytorch.org\n",
- "# data = torch.tensor(encode(text), dtype=torch.long)\n",
- "# print(data.shape, data.dtype)\n",
- "# print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this\n",
- "\n",
- "# Assuming `encode` is a function defined elsewhere\n",
- "data = jnp.array(encode(text), dtype=jnp.int64)\n",
- "print(data.shape, data.dtype)\n",
- "print(data[:1000])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pBlRYpvhF9xj"
- },
- "source": [
- "#### split into train and test"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "id": "f_WIXqxz0lU5"
- },
- "outputs": [],
- "source": [
- "# Let's now split up the data into train and validation sets\n",
- "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
- "train_data = data[:n]\n",
- "val_data = data[n:]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "bCAY6ej2DQUm",
- "outputId": "6114a515-ebab-45be-fa5c-5f58c8560d29"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "(111540,)"
- ]
- },
- "metadata": {},
- "execution_count": 12
- }
- ],
- "source": [
- "val_data.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "TD5Bj8Y6IAD4",
- "outputId": "5087c2d6-c461-4dbe-989a-ff6fd49a9187"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "Array([18, 47, 56, 57, 58, 1, 15, 47, 58], dtype=int32)"
- ]
- },
- "metadata": {},
- "execution_count": 13
- }
- ],
- "source": [
- "block_size = 8\n",
- "train_data[:block_size+1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nHbUA5zZGBqn"
- },
- "source": [
- "#### build feature input x and target output y"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "9HXDe8vGJCEn",
- "outputId": "af9aab79-9594-47ac-8c33-62a92d7138c5"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "0\n",
- "when input is [18] the target: 47\n",
- "1\n",
- "when input is [18 47] the target: 56\n",
- "2\n",
- "when input is [18 47 56] the target: 57\n",
- "3\n",
- "when input is [18 47 56 57] the target: 58\n",
- "4\n",
- "when input is [18 47 56 57 58] the target: 1\n",
- "5\n",
- "when input is [18 47 56 57 58 1] the target: 15\n",
- "6\n",
- "when input is [18 47 56 57 58 1 15] the target: 47\n",
- "7\n",
- "when input is [18 47 56 57 58 1 15 47] the target: 58\n"
- ]
- }
- ],
- "source": [
- "x = train_data[:block_size]\n",
- "y = train_data[1:block_size+1]\n",
- "for t in range(block_size):\n",
- " print(t)\n",
- " context = x[:t+1]\n",
- " target = y[t]\n",
- " print(f\"when input is {context} the target: {target}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Q3k1Czf7LuA9",
- "outputId": "35e955e0-4d8e-48a2-a2cf-f7f0c341a925"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "inputs:\n",
- "(4, 8)\n",
- "[[ 1 51 39 52 58 50 43 0]\n",
- " [25 17 27 10 0 27 6 1]\n",
- " [47 51 8 0 14 59 58 1]\n",
- " [ 1 57 59 41 46 1 50 43]]\n",
- "targets:\n",
- "(4, 8)\n",
- "[[51 39 52 58 50 43 0 53]\n",
- " [17 27 10 0 27 6 1 50]\n",
- " [51 8 0 14 59 58 1 46]\n",
- " [57 59 41 46 1 50 43 52]]\n",
- "----\n",
- "when input is [1] the target: 51\n",
- "when input is [1, 51] the target: 39\n",
- "when input is [1, 51, 39] the target: 52\n",
- "when input is [1, 51, 39, 52] the target: 58\n",
- "when input is [1, 51, 39, 52, 58] the target: 50\n",
- "when input is [1, 51, 39, 52, 58, 50] the target: 43\n",
- "when input is [1, 51, 39, 52, 58, 50, 43] the target: 0\n",
- "when input is [1, 51, 39, 52, 58, 50, 43, 0] the target: 53\n",
- "when input is [25] the target: 17\n",
- "when input is [25, 17] the target: 27\n",
- "when input is [25, 17, 27] the target: 10\n",
- "when input is [25, 17, 27, 10] the target: 0\n",
- "when input is [25, 17, 27, 10, 0] the target: 27\n",
- "when input is [25, 17, 27, 10, 0, 27] the target: 6\n",
- "when input is [25, 17, 27, 10, 0, 27, 6] the target: 1\n",
- "when input is [25, 17, 27, 10, 0, 27, 6, 1] the target: 50\n",
- "when input is [47] the target: 51\n",
- "when input is [47, 51] the target: 8\n",
- "when input is [47, 51, 8] the target: 0\n",
- "when input is [47, 51, 8, 0] the target: 14\n",
- "when input is [47, 51, 8, 0, 14] the target: 59\n",
- "when input is [47, 51, 8, 0, 14, 59] the target: 58\n",
- "when input is [47, 51, 8, 0, 14, 59, 58] the target: 1\n",
- "when input is [47, 51, 8, 0, 14, 59, 58, 1] the target: 46\n",
- "when input is [1] the target: 57\n",
- "when input is [1, 57] the target: 59\n",
- "when input is [1, 57, 59] the target: 41\n",
- "when input is [1, 57, 59, 41] the target: 46\n",
- "when input is [1, 57, 59, 41, 46] the target: 1\n",
- "when input is [1, 57, 59, 41, 46, 1] the target: 50\n",
- "when input is [1, 57, 59, 41, 46, 1, 50] the target: 43\n",
- "when input is [1, 57, 59, 41, 46, 1, 50, 43] the target: 52\n"
- ]
- }
- ],
- "source": [
- "prng = jax.random.PRNGKey(1337)\n",
- "batch_size = 4 # how many independent sequences will we process in parallel?\n",
- "block_size = 8 # what is the maximum context length for predictions?\n",
- "ix = random.randint(random.PRNGKey(0), (batch_size,), 0, len(data) - block_size)\n",
- "\n",
- "def get_batch(split, subkey):\n",
- " # generate a small batch of data of inputs x and targets y\n",
- " data = train_data if split == 'train' else val_data\n",
- " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n",
- " # x = jnp.stack([data[i:i+block_size] for i in ix])\n",
- " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n",
- " # optimize the above code ^^.\n",
- " # speed up by using dynamic_slice and vmap\n",
- " def slice_data(i):\n",
- " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n",
- "\n",
- " x = jax.vmap(slice_data)(ix)\n",
- " y = jax.vmap(slice_data)(ix+1)\n",
- " x, y = device_put(x), device_put(y)\n",
- " return x, y\n",
- "\n",
- "xb, yb = get_batch('train', prng)\n",
- "print('inputs:')\n",
- "print(xb.shape)\n",
- "print(xb)\n",
- "print('targets:')\n",
- "print(yb.shape)\n",
- "print(yb)\n",
- "\n",
- "print('----')\n",
- "\n",
- "for b in range(batch_size): # batch dimension\n",
- " for t in range(block_size): # time dimension\n",
- " context = xb[b, :t+1]\n",
- " target = yb[b,t]\n",
- " print(f\"when input is {context.tolist()} the target: {target}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "qpyyAeIzQjlO",
- "outputId": "ecc878aa-2650-4cad-b32e-1d18c946d0a1"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "[[ 1 51 39 52 58 50 43 0]\n",
- " [25 17 27 10 0 27 6 1]\n",
- " [47 51 8 0 14 59 58 1]\n",
- " [ 1 57 59 41 46 1 50 43]]\n"
- ]
- }
- ],
- "source": [
- "print(xb) # our input to the transformer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yZF__SLaGJs9"
- },
- "source": [
- "### Baseline model: Bigram"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "nql_1ER53oCf",
- "outputId": "cab32afd-03fd-4d3b-fd94-74bcbe2a4af9"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "(32, 65)\n",
- "4.1973763\n"
- ]
- }
- ],
- "source": [
- "class BigramLanguageModel(nn.Module):\n",
- " vocab_size: int\n",
- "\n",
- " @nn.compact\n",
- " def __call__(self, idx, targets=None):\n",
- " # Token embedding table\n",
- " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n",
- " logits = embedding_table(idx) # (B,T,C)\n",
- "\n",
- " if targets is None:\n",
- " loss = None\n",
- " else:\n",
- " B, T, C = logits.shape\n",
- " logits = logits.reshape(B*T, C)\n",
- " targets = targets.reshape(B*T)\n",
- " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n",
- "\n",
- " return logits, loss\n",
- "\n",
- "# Example usage\n",
- "model = BigramLanguageModel(vocab_size)\n",
- "\n",
- "# Initialize model parameters and optimizer\n",
- "key = random.PRNGKey(1337)\n",
- "params = model.init(key, jnp.ones((1, 1), jnp.int32))\n",
- "\n",
- "# jax jit the model apply to speed up\n",
- "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))\n",
- "\n",
- "logits, loss = flax_apply_jitted(params, xb, yb)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "#### training"
- ],
- "metadata": {
- "id": "wXW3MFQqA7AD"
- }
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {
- "id": "eTyJ8qAaDdiF",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "0ddc5e2f-273b-4bdb-c01e-1011e8aab9e5"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Tracedwith Tracedwith\n",
- "Epoch 0,Loss: 4.176701545715332, sample batch: 0.006464958190917969, forward pass and backward grad: 0.3173556327819824\n",
- "Epoch 1,Loss: 4.168148040771484, sample batch: 0.006571531295776367, forward pass and backward grad: 0.007528066635131836\n",
- "Epoch 2,Loss: 4.176245212554932, sample batch: 0.006763458251953125, forward pass and backward grad: 0.007700681686401367\n",
- "Epoch 3,Loss: 4.169358730316162, sample batch: 0.007287740707397461, forward pass and backward grad: 0.009604454040527344\n",
- "Epoch 4,Loss: 4.175686836242676, sample batch: 0.009287357330322266, forward pass and backward grad: 0.010252714157104492\n",
- "Epoch 5,Loss: 4.170393943786621, sample batch: 0.0073015689849853516, forward pass and backward grad: 0.008276939392089844\n",
- "Epoch 6,Loss: 4.169349193572998, sample batch: 0.006832599639892578, forward pass and backward grad: 0.007781982421875\n",
- "Epoch 7,Loss: 4.162716388702393, sample batch: 0.007829904556274414, forward pass and backward grad: 0.008838653564453125\n",
- "Epoch 8,Loss: 4.157233238220215, sample batch: 0.007964611053466797, forward pass and backward grad: 0.008940458297729492\n",
- "Epoch 9,Loss: 4.163629055023193, sample batch: 0.0069615840911865234, forward pass and backward grad: 0.007949113845825195\n",
- "Epoch 10,Loss: 4.16217565536499, sample batch: 0.007250785827636719, forward pass and backward grad: 0.008380413055419922\n",
- "Epoch 11,Loss: 4.1645026206970215, sample batch: 0.007141828536987305, forward pass and backward grad: 0.008124113082885742\n",
- "Epoch 12,Loss: 4.1598124504089355, sample batch: 0.006840944290161133, forward pass and backward grad: 0.007810831069946289\n",
- "Epoch 13,Loss: 4.157349586486816, sample batch: 0.006644248962402344, forward pass and backward grad: 0.0076143741607666016\n",
- "Epoch 14,Loss: 4.163051128387451, sample batch: 0.007123470306396484, forward pass and backward grad: 0.008063554763793945\n",
- "Epoch 15,Loss: 4.15546178817749, sample batch: 0.005999565124511719, forward pass and backward grad: 0.006939888000488281\n",
- "Epoch 16,Loss: 4.154058456420898, sample batch: 0.005856513977050781, forward pass and backward grad: 0.0068018436431884766\n",
- "Epoch 17,Loss: 4.14748477935791, sample batch: 0.010221004486083984, forward pass and backward grad: 0.011186599731445312\n",
- "Epoch 18,Loss: 4.153148174285889, sample batch: 0.006284952163696289, forward pass and backward grad: 0.0072596073150634766\n",
- "Epoch 19,Loss: 4.15305757522583, sample batch: 0.005934238433837891, forward pass and backward grad: 0.0069620609283447266\n",
- "Epoch 20,Loss: 4.148075580596924, sample batch: 0.005886554718017578, forward pass and backward grad: 0.006838560104370117\n",
- "Epoch 21,Loss: 4.1545729637146, sample batch: 0.006558895111083984, forward pass and backward grad: 0.007547140121459961\n",
- "Epoch 22,Loss: 4.1454854011535645, sample batch: 0.0068204402923583984, forward pass and backward grad: 0.007781505584716797\n",
- "Epoch 23,Loss: 4.146124839782715, sample batch: 0.006976604461669922, forward pass and backward grad: 0.007932424545288086\n",
- "Epoch 24,Loss: 4.152686595916748, sample batch: 0.0068511962890625, forward pass and backward grad: 0.007791042327880859\n",
- "Epoch 25,Loss: 4.148232460021973, sample batch: 0.012384414672851562, forward pass and backward grad: 0.013611793518066406\n",
- "Epoch 26,Loss: 4.143317222595215, sample batch: 0.00710606575012207, forward pass and backward grad: 0.00807332992553711\n",
- "Epoch 27,Loss: 4.135959625244141, sample batch: 0.0061304569244384766, forward pass and backward grad: 0.007124662399291992\n",
- "Epoch 28,Loss: 4.133705139160156, sample batch: 0.00642704963684082, forward pass and backward grad: 0.007401466369628906\n",
- "Epoch 29,Loss: 4.144227504730225, sample batch: 0.008374929428100586, forward pass and backward grad: 0.009456157684326172\n",
- "Epoch 30,Loss: 4.137173652648926, sample batch: 0.007355928421020508, forward pass and backward grad: 0.008334636688232422\n",
- "Epoch 31,Loss: 4.137599945068359, sample batch: 0.007583141326904297, forward pass and backward grad: 0.008576154708862305\n",
- "Epoch 32,Loss: 4.1330671310424805, sample batch: 0.006922483444213867, forward pass and backward grad: 0.007891654968261719\n",
- "Epoch 33,Loss: 4.140651702880859, sample batch: 0.00725555419921875, forward pass and backward grad: 0.008275270462036133\n",
- "Epoch 34,Loss: 4.132898807525635, sample batch: 0.0066339969635009766, forward pass and backward grad: 0.00758814811706543\n",
- "Epoch 35,Loss: 4.1298909187316895, sample batch: 0.00668644905090332, forward pass and backward grad: 0.007630586624145508\n",
- "Epoch 36,Loss: 4.1282429695129395, sample batch: 0.006491661071777344, forward pass and backward grad: 0.007435798645019531\n",
- "Epoch 37,Loss: 4.1404337882995605, sample batch: 0.006692171096801758, forward pass and backward grad: 0.0076482295989990234\n",
- "Epoch 38,Loss: 4.124664783477783, sample batch: 0.007221221923828125, forward pass and backward grad: 0.008726119995117188\n",
- "Epoch 39,Loss: 4.121993064880371, sample batch: 0.010175466537475586, forward pass and backward grad: 0.011160850524902344\n",
- "Epoch 40,Loss: 4.131916522979736, sample batch: 0.0070574283599853516, forward pass and backward grad: 0.008001327514648438\n",
- "Epoch 41,Loss: 4.121824741363525, sample batch: 0.006624698638916016, forward pass and backward grad: 0.007588624954223633\n",
- "Epoch 42,Loss: 4.124565124511719, sample batch: 0.007647991180419922, forward pass and backward grad: 0.008681535720825195\n",
- "Epoch 43,Loss: 4.122013568878174, sample batch: 0.006825923919677734, forward pass and backward grad: 0.0077838897705078125\n",
- "Epoch 44,Loss: 4.124299049377441, sample batch: 0.00981760025024414, forward pass and backward grad: 0.010839700698852539\n",
- "Epoch 45,Loss: 4.116213798522949, sample batch: 0.00638580322265625, forward pass and backward grad: 0.008318901062011719\n",
- "Epoch 46,Loss: 4.118075847625732, sample batch: 0.00923776626586914, forward pass and backward grad: 0.010228395462036133\n",
- "Epoch 47,Loss: 4.114686489105225, sample batch: 0.007014036178588867, forward pass and backward grad: 0.007985591888427734\n",
- "Epoch 48,Loss: 4.117861270904541, sample batch: 0.007818460464477539, forward pass and backward grad: 0.008879899978637695\n",
- "Epoch 49,Loss: 4.111810207366943, sample batch: 0.015448570251464844, forward pass and backward grad: 0.016459941864013672\n",
- "Epoch 50,Loss: 4.1100616455078125, sample batch: 0.00766754150390625, forward pass and backward grad: 0.00865626335144043\n",
- "Epoch 51,Loss: 4.106339931488037, sample batch: 0.0077228546142578125, forward pass and backward grad: 0.008730649948120117\n",
- "Epoch 52,Loss: 4.106321811676025, sample batch: 0.00809621810913086, forward pass and backward grad: 0.009130001068115234\n",
- "Epoch 53,Loss: 4.105854034423828, sample batch: 0.007456302642822266, forward pass and backward grad: 0.008474588394165039\n",
- "Epoch 54,Loss: 4.116119384765625, sample batch: 0.0074770450592041016, forward pass and backward grad: 0.008490324020385742\n",
- "Epoch 55,Loss: 4.107293605804443, sample batch: 0.008399009704589844, forward pass and backward grad: 0.009408950805664062\n",
- "Epoch 56,Loss: 4.112220764160156, sample batch: 0.007843494415283203, forward pass and backward grad: 0.008843421936035156\n",
- "Epoch 57,Loss: 4.106284141540527, sample batch: 0.007303953170776367, forward pass and backward grad: 0.008324861526489258\n",
- "Epoch 58,Loss: 4.101660251617432, sample batch: 0.009732961654663086, forward pass and backward grad: 0.010820388793945312\n",
- "Epoch 59,Loss: 4.100407123565674, sample batch: 0.0070569515228271484, forward pass and backward grad: 0.008041620254516602\n",
- "Epoch 60,Loss: 4.098540306091309, sample batch: 0.007241725921630859, forward pass and backward grad: 0.00823354721069336\n",
- "Epoch 61,Loss: 4.097070217132568, sample batch: 0.008341073989868164, forward pass and backward grad: 0.009309768676757812\n",
- "Epoch 62,Loss: 4.0942816734313965, sample batch: 0.0075092315673828125, forward pass and backward grad: 0.008522748947143555\n",
- "Epoch 63,Loss: 4.098483562469482, sample batch: 0.007592439651489258, forward pass and backward grad: 0.008698225021362305\n",
- "Epoch 64,Loss: 4.098681449890137, sample batch: 0.007315397262573242, forward pass and backward grad: 0.008310079574584961\n",
- "Epoch 65,Loss: 4.09865140914917, sample batch: 0.007418632507324219, forward pass and backward grad: 0.008446455001831055\n",
- "Epoch 66,Loss: 4.09343957901001, sample batch: 0.008169412612915039, forward pass and backward grad: 0.009173154830932617\n",
- "Epoch 67,Loss: 4.098697185516357, sample batch: 0.0086212158203125, forward pass and backward grad: 0.009929418563842773\n",
- "Epoch 68,Loss: 4.089414596557617, sample batch: 0.008454084396362305, forward pass and backward grad: 0.009839296340942383\n",
- "Epoch 69,Loss: 4.089560031890869, sample batch: 0.007783412933349609, forward pass and backward grad: 0.008769750595092773\n",
- "Epoch 70,Loss: 4.0868353843688965, sample batch: 0.007253885269165039, forward pass and backward grad: 0.008234262466430664\n",
- "Epoch 71,Loss: 4.0886335372924805, sample batch: 0.007701396942138672, forward pass and backward grad: 0.00868678092956543\n",
- "Epoch 72,Loss: 4.08442497253418, sample batch: 0.013611316680908203, forward pass and backward grad: 0.014733552932739258\n",
- "Epoch 73,Loss: 4.0852155685424805, sample batch: 0.007182121276855469, forward pass and backward grad: 0.008143424987792969\n",
- "Epoch 74,Loss: 4.077157974243164, sample batch: 0.006951570510864258, forward pass and backward grad: 0.007899045944213867\n",
- "Epoch 75,Loss: 4.083674430847168, sample batch: 0.006950855255126953, forward pass and backward grad: 0.00792551040649414\n",
- "Epoch 76,Loss: 4.080029487609863, sample batch: 0.007283687591552734, forward pass and backward grad: 0.008302688598632812\n",
- "Epoch 77,Loss: 4.075239658355713, sample batch: 0.006971836090087891, forward pass and backward grad: 0.008055925369262695\n",
- "Epoch 78,Loss: 4.077165603637695, sample batch: 0.008187294006347656, forward pass and backward grad: 0.00914311408996582\n",
- "Epoch 79,Loss: 4.082056522369385, sample batch: 0.007593870162963867, forward pass and backward grad: 0.009018421173095703\n",
- "Epoch 80,Loss: 4.08469820022583, sample batch: 0.009556293487548828, forward pass and backward grad: 0.010587453842163086\n",
- "Epoch 81,Loss: 4.074877738952637, sample batch: 0.007441520690917969, forward pass and backward grad: 0.008409261703491211\n",
- "Epoch 82,Loss: 4.065997123718262, sample batch: 0.008413314819335938, forward pass and backward grad: 0.009436368942260742\n",
- "Epoch 83,Loss: 4.072015762329102, sample batch: 0.009505033493041992, forward pass and backward grad: 0.010768413543701172\n",
- "Epoch 84,Loss: 4.065585136413574, sample batch: 0.008172035217285156, forward pass and backward grad: 0.009321451187133789\n",
- "Epoch 85,Loss: 4.070855617523193, sample batch: 0.007485151290893555, forward pass and backward grad: 0.008470535278320312\n",
- "Epoch 86,Loss: 4.06818151473999, sample batch: 0.007288455963134766, forward pass and backward grad: 0.008320093154907227\n",
- "Epoch 87,Loss: 4.072195529937744, sample batch: 0.007587909698486328, forward pass and backward grad: 0.008961677551269531\n",
- "Epoch 88,Loss: 4.065299987792969, sample batch: 0.0156404972076416, forward pass and backward grad: 0.017002105712890625\n",
- "Epoch 89,Loss: 4.059263229370117, sample batch: 0.008633852005004883, forward pass and backward grad: 0.009637594223022461\n",
- "Epoch 90,Loss: 4.064813613891602, sample batch: 0.006797313690185547, forward pass and backward grad: 0.007784843444824219\n",
- "Epoch 91,Loss: 4.0675811767578125, sample batch: 0.010977745056152344, forward pass and backward grad: 0.012456655502319336\n",
- "Epoch 92,Loss: 4.062394142150879, sample batch: 0.012706518173217773, forward pass and backward grad: 0.014220237731933594\n",
- "Epoch 93,Loss: 4.053857326507568, sample batch: 0.013547658920288086, forward pass and backward grad: 0.01579761505126953\n",
- "Epoch 94,Loss: 4.056527614593506, sample batch: 0.00801706314086914, forward pass and backward grad: 0.009006500244140625\n",
- "Epoch 95,Loss: 4.060898780822754, sample batch: 0.00725102424621582, forward pass and backward grad: 0.008263349533081055\n",
- "Epoch 96,Loss: 4.057039737701416, sample batch: 0.00706171989440918, forward pass and backward grad: 0.008035659790039062\n",
- "Epoch 97,Loss: 4.053262710571289, sample batch: 0.007007598876953125, forward pass and backward grad: 0.008001327514648438\n",
- "Epoch 98,Loss: 4.054243564605713, sample batch: 0.008085012435913086, forward pass and backward grad: 0.00907278060913086\n",
- "Epoch 99,Loss: 4.056581974029541, sample batch: 0.0069582462310791016, forward pass and backward grad: 0.007887125015258789\n"
- ]
- }
- ],
- "source": [
- "# Define the optimizer\n",
- "learning_rate = 1e-3 # Adjust as needed\n",
- "tx = adam(learning_rate)\n",
- "\n",
- "# Initialize model parameters and optimizer state\n",
- "prng = random.PRNGKey(1337)\n",
- "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n",
- "opt_state = tx.init(params)\n",
- "\n",
- "# Loss function (assuming you have a batch of data: xb, yb)\n",
- "def loss_fn(params, xb, yb):\n",
- " print(xb, yb)\n",
- " logits, loss = model.apply(params, xb, yb)\n",
- " return loss\n",
- "\n",
- "# Update function for a single training step\n",
- "@jax.jit\n",
- "def update_step(params, opt_state, xb, yb):\n",
- " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n",
- " updates, opt_state = tx.update(grads, opt_state, params)\n",
- " new_params = optax.apply_updates(params, updates)\n",
- " return new_params, opt_state, loss\n",
- "\n",
- "# Training loop (example)\n",
- "batch_size = 32\n",
- "for steps in range(100):\n",
- " t1 = time.time()\n",
- " prng, subkey = random.split(prng)\n",
- " xb, yb = get_batch('train', subkey)\n",
- " t2 = time.time()\n",
- " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n",
- " print(f\"Epoch {steps},Loss: {loss}, sample batch: {t2-t1}, forward pass and backward grad: {time.time()-t1}\")\n",
- "\n",
- "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(jax.lax.stop_gradient(params), xb, yb))"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "#### inference\n",
- "Slower than torch code\n",
- " 1. logits[:, -1, :]\n",
- " 2. random.categorical "
- ],
- "metadata": {
- "id": "DwB7mg9tBt0D"
- }
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "EcVIDWAZEtjN",
- "outputId": "42718af7-305d-4216-99a3-d51504a2722b"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n",
- "yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x\n",
- "pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJrbX;b!HR'Fr;rI?&Nui3;woGFdW pAZYho3YO!hHPv:F3uMAHbG:slLyWXd;woxmBMTexUpY ZEP\n",
- "tTk?BlWOP&ZP.zNS YjFV,OxrO?!$wNDsXCd;iM:c!elaw'uOPGCJJDBsSf,E.XguCoK-rJP-kybvHsxxwu,:i3UJgZbBMO;s:coPALGSTE-hJWOStcI3$VaeVYfJsTPqaqT-ebJqAWy\n",
- "Ev:WFmCykXrvetkGbw-3-N!'oW\n",
- "nKqi:FgOyU3XdQwNr gVItNvRo,JbtDAvcfHSKDkh.caNKrf CMrJIGs?lbiNDbgJg'cHB:rRwAuGq&UDPhOdnmc:&jU,ZCuG?mF.An-r,EMDfCHfITHsvztXPL U3iSE-dAsTxeqf??i\n",
- "OUQfArTnZ.Hgv\n",
- "CPU times: user 1.03 s, sys: 9 ms, total: 1.04 s\n",
- "Wall time: 1.05 s\n"
- ]
- }
- ],
- "source": [
- "def generate(params, flax_apply_jitted, key, idx, max_new_tokens):\n",
- " for _ in range(max_new_tokens):\n",
- " logits, _ = flax_apply_jitted(params, idx, None)\n",
- " logits = logits[:, -1, :] # (B, C)\n",
- " key, subkey = random.split(key)\n",
- " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n",
- " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n",
- " return idx\n",
- "\n",
- "%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "# Try speed up with stop gradient\n",
- "t1=time.time()\n",
- "print(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))\n",
- "print('TIME total', time.time()-t1)"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "vfX1ymJPDoEF",
- "outputId": "7d908357-0706-446d-bee4-cb84791872f9"
- },
- "execution_count": 38,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n",
- "yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x\n",
- "pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJ\n",
- "TIME total 0.3211863040924072\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yNc7jn8wy5Kx"
- },
- "source": [
- "#### put together"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 43,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "1sBa-IFGy90b",
- "outputId": "7b559e10-3b16-4ebd-e4f5-fbcdacaa5970"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "step 0: train loss 4.1747, val loss 4.1776\n",
- "Epoch 0,Loss: 4.176701545715332\n",
- "Epoch 1,Loss: 4.168148040771484\n",
- "Epoch 2,Loss: 4.176245212554932\n",
- "Epoch 3,Loss: 4.169358730316162\n",
- "Epoch 4,Loss: 4.175686836242676\n",
- "Epoch 5,Loss: 4.170393943786621\n",
- "Epoch 6,Loss: 4.169349193572998\n",
- "Epoch 7,Loss: 4.162716388702393\n",
- "Epoch 8,Loss: 4.157233238220215\n",
- "Epoch 9,Loss: 4.163629055023193\n",
- "Epoch 10,Loss: 4.16217565536499\n",
- "Epoch 11,Loss: 4.1645026206970215\n",
- "Epoch 12,Loss: 4.1598124504089355\n",
- "Epoch 13,Loss: 4.157349586486816\n",
- "Epoch 14,Loss: 4.163051128387451\n",
- "Epoch 15,Loss: 4.15546178817749\n",
- "Epoch 16,Loss: 4.154058456420898\n",
- "Epoch 17,Loss: 4.14748477935791\n",
- "Epoch 18,Loss: 4.153148174285889\n",
- "Epoch 19,Loss: 4.15305757522583\n",
- "Epoch 20,Loss: 4.148075580596924\n",
- "Epoch 21,Loss: 4.1545729637146\n",
- "Epoch 22,Loss: 4.1454854011535645\n",
- "Epoch 23,Loss: 4.146124839782715\n",
- "Epoch 24,Loss: 4.152686595916748\n",
- "Epoch 25,Loss: 4.148232460021973\n",
- "Epoch 26,Loss: 4.143317222595215\n",
- "Epoch 27,Loss: 4.135959625244141\n",
- "Epoch 28,Loss: 4.133705139160156\n",
- "Epoch 29,Loss: 4.144227504730225\n",
- "Epoch 30,Loss: 4.137173652648926\n",
- "Epoch 31,Loss: 4.137599945068359\n",
- "Epoch 32,Loss: 4.1330671310424805\n",
- "Epoch 33,Loss: 4.140651702880859\n",
- "Epoch 34,Loss: 4.132898807525635\n",
- "Epoch 35,Loss: 4.1298909187316895\n",
- "Epoch 36,Loss: 4.1282429695129395\n",
- "Epoch 37,Loss: 4.1404337882995605\n",
- "Epoch 38,Loss: 4.124664783477783\n",
- "Epoch 39,Loss: 4.121993064880371\n",
- "Epoch 40,Loss: 4.131916522979736\n",
- "Epoch 41,Loss: 4.121824741363525\n",
- "Epoch 42,Loss: 4.124565124511719\n",
- "Epoch 43,Loss: 4.122013568878174\n",
- "Epoch 44,Loss: 4.124299049377441\n",
- "Epoch 45,Loss: 4.116213798522949\n",
- "Epoch 46,Loss: 4.118075847625732\n",
- "Epoch 47,Loss: 4.114686489105225\n",
- "Epoch 48,Loss: 4.117861270904541\n",
- "Epoch 49,Loss: 4.111810207366943\n",
- "Epoch 50,Loss: 4.1100616455078125\n",
- "Epoch 51,Loss: 4.106339931488037\n",
- "Epoch 52,Loss: 4.106321811676025\n",
- "Epoch 53,Loss: 4.105854034423828\n",
- "Epoch 54,Loss: 4.116119384765625\n",
- "Epoch 55,Loss: 4.107293605804443\n",
- "Epoch 56,Loss: 4.112220764160156\n",
- "Epoch 57,Loss: 4.106284141540527\n",
- "Epoch 58,Loss: 4.101660251617432\n",
- "Epoch 59,Loss: 4.100407123565674\n",
- "Epoch 60,Loss: 4.098540306091309\n",
- "Epoch 61,Loss: 4.097070217132568\n",
- "Epoch 62,Loss: 4.0942816734313965\n",
- "Epoch 63,Loss: 4.098483562469482\n",
- "Epoch 64,Loss: 4.098681449890137\n",
- "Epoch 65,Loss: 4.09865140914917\n",
- "Epoch 66,Loss: 4.09343957901001\n",
- "Epoch 67,Loss: 4.098697185516357\n",
- "Epoch 68,Loss: 4.089414596557617\n",
- "Epoch 69,Loss: 4.089560031890869\n",
- "Epoch 70,Loss: 4.0868353843688965\n",
- "Epoch 71,Loss: 4.0886335372924805\n",
- "Epoch 72,Loss: 4.08442497253418\n",
- "Epoch 73,Loss: 4.0852155685424805\n",
- "Epoch 74,Loss: 4.077157974243164\n",
- "Epoch 75,Loss: 4.083674430847168\n",
- "Epoch 76,Loss: 4.080029487609863\n",
- "Epoch 77,Loss: 4.075239658355713\n",
- "Epoch 78,Loss: 4.077165603637695\n",
- "Epoch 79,Loss: 4.082056522369385\n",
- "Epoch 80,Loss: 4.08469820022583\n",
- "Epoch 81,Loss: 4.074877738952637\n",
- "Epoch 82,Loss: 4.065997123718262\n",
- "Epoch 83,Loss: 4.072015762329102\n",
- "Epoch 84,Loss: 4.065585136413574\n",
- "Epoch 85,Loss: 4.070855617523193\n",
- "Epoch 86,Loss: 4.06818151473999\n",
- "Epoch 87,Loss: 4.072195529937744\n",
- "Epoch 88,Loss: 4.065299987792969\n",
- "Epoch 89,Loss: 4.059263229370117\n",
- "Epoch 90,Loss: 4.064813613891602\n",
- "Epoch 91,Loss: 4.0675811767578125\n",
- "Epoch 92,Loss: 4.062394142150879\n",
- "Epoch 93,Loss: 4.053857326507568\n",
- "Epoch 94,Loss: 4.056527614593506\n",
- "Epoch 95,Loss: 4.060898780822754\n",
- "Epoch 96,Loss: 4.057039737701416\n",
- "Epoch 97,Loss: 4.053262710571289\n",
- "Epoch 98,Loss: 4.054243564605713\n",
- "Epoch 99,Loss: 4.056581974029541\n",
- "step 100: train loss 4.0528, val loss 4.0548\n",
- "Epoch 100,Loss: 4.046238899230957\n",
- "Epoch 101,Loss: 4.049732685089111\n",
- "Epoch 102,Loss: 4.0535478591918945\n",
- "Epoch 103,Loss: 4.050480365753174\n",
- "Epoch 104,Loss: 4.053426265716553\n",
- "Epoch 105,Loss: 4.042150974273682\n",
- "Epoch 106,Loss: 4.0471320152282715\n",
- "Epoch 107,Loss: 4.044474124908447\n",
- "Epoch 108,Loss: 4.040208339691162\n",
- "Epoch 109,Loss: 4.046355247497559\n",
- "Epoch 110,Loss: 4.039368629455566\n",
- "Epoch 111,Loss: 4.04256010055542\n",
- "Epoch 112,Loss: 4.035763740539551\n",
- "Epoch 113,Loss: 4.035247325897217\n",
- "Epoch 114,Loss: 4.035799026489258\n",
- "Epoch 115,Loss: 4.029808521270752\n",
- "Epoch 116,Loss: 4.034291744232178\n",
- "Epoch 117,Loss: 4.03066349029541\n",
- "Epoch 118,Loss: 4.036454200744629\n",
- "Epoch 119,Loss: 4.026322364807129\n",
- "Epoch 120,Loss: 4.024256229400635\n",
- "Epoch 121,Loss: 4.022698402404785\n",
- "Epoch 122,Loss: 4.022899150848389\n",
- "Epoch 123,Loss: 4.025989532470703\n",
- "Epoch 124,Loss: 4.022078990936279\n",
- "Epoch 125,Loss: 4.024803638458252\n",
- "Epoch 126,Loss: 4.015080451965332\n",
- "Epoch 127,Loss: 4.0170817375183105\n",
- "Epoch 128,Loss: 4.021418571472168\n",
- "Epoch 129,Loss: 4.012678146362305\n",
- "Epoch 130,Loss: 4.012955188751221\n",
- "Epoch 131,Loss: 4.0082855224609375\n",
- "Epoch 132,Loss: 4.011625289916992\n",
- "Epoch 133,Loss: 4.0081305503845215\n",
- "Epoch 134,Loss: 4.015928745269775\n",
- "Epoch 135,Loss: 4.0120649337768555\n",
- "Epoch 136,Loss: 4.00533390045166\n",
- "Epoch 137,Loss: 4.013724327087402\n",
- "Epoch 138,Loss: 4.004214763641357\n",
- "Epoch 139,Loss: 4.007634162902832\n",
- "Epoch 140,Loss: 3.9990668296813965\n",
- "Epoch 141,Loss: 4.004575252532959\n",
- "Epoch 142,Loss: 3.9970436096191406\n",
- "Epoch 143,Loss: 4.000253677368164\n",
- "Epoch 144,Loss: 3.993255615234375\n",
- "Epoch 145,Loss: 4.004382610321045\n",
- "Epoch 146,Loss: 3.994224786758423\n",
- "Epoch 147,Loss: 3.99639892578125\n",
- "Epoch 148,Loss: 3.9898176193237305\n",
- "Epoch 149,Loss: 3.9965922832489014\n",
- "Epoch 150,Loss: 3.9880197048187256\n",
- "Epoch 151,Loss: 3.990727663040161\n",
- "Epoch 152,Loss: 3.984743118286133\n",
- "Epoch 153,Loss: 3.9868345260620117\n",
- "Epoch 154,Loss: 3.9808623790740967\n",
- "Epoch 155,Loss: 3.979079246520996\n",
- "Epoch 156,Loss: 3.9874355792999268\n",
- "Epoch 157,Loss: 3.9819841384887695\n",
- "Epoch 158,Loss: 3.9719669818878174\n",
- "Epoch 159,Loss: 3.9836819171905518\n",
- "Epoch 160,Loss: 3.9836788177490234\n",
- "Epoch 161,Loss: 3.9711480140686035\n",
- "Epoch 162,Loss: 3.9807724952697754\n",
- "Epoch 163,Loss: 3.974381446838379\n",
- "Epoch 164,Loss: 3.975137948989868\n",
- "Epoch 165,Loss: 3.970792293548584\n",
- "Epoch 166,Loss: 3.9664454460144043\n",
- "Epoch 167,Loss: 3.972529888153076\n",
- "Epoch 168,Loss: 3.970104217529297\n",
- "Epoch 169,Loss: 3.9673550128936768\n",
- "Epoch 170,Loss: 3.974461793899536\n",
- "Epoch 171,Loss: 3.966738700866699\n",
- "Epoch 172,Loss: 3.968355178833008\n",
- "Epoch 173,Loss: 3.961606740951538\n",
- "Epoch 174,Loss: 3.9639313220977783\n",
- "Epoch 175,Loss: 3.9576148986816406\n",
- "Epoch 176,Loss: 3.966170072555542\n",
- "Epoch 177,Loss: 3.9580230712890625\n",
- "Epoch 178,Loss: 3.9642348289489746\n",
- "Epoch 179,Loss: 3.9687607288360596\n",
- "Epoch 180,Loss: 3.9539127349853516\n",
- "Epoch 181,Loss: 3.9508092403411865\n",
- "Epoch 182,Loss: 3.964644193649292\n",
- "Epoch 183,Loss: 3.9514272212982178\n",
- "Epoch 184,Loss: 3.9511210918426514\n",
- "Epoch 185,Loss: 3.953274965286255\n",
- "Epoch 186,Loss: 3.9493982791900635\n",
- "Epoch 187,Loss: 3.9452455043792725\n",
- "Epoch 188,Loss: 3.9493496417999268\n",
- "Epoch 189,Loss: 3.9497110843658447\n",
- "Epoch 190,Loss: 3.934187650680542\n",
- "Epoch 191,Loss: 3.942565679550171\n",
- "Epoch 192,Loss: 3.939657211303711\n",
- "Epoch 193,Loss: 3.950552225112915\n",
- "Epoch 194,Loss: 3.9368786811828613\n",
- "Epoch 195,Loss: 3.938438892364502\n",
- "Epoch 196,Loss: 3.9386792182922363\n",
- "Epoch 197,Loss: 3.937180519104004\n",
- "Epoch 198,Loss: 3.9317572116851807\n",
- "Epoch 199,Loss: 3.938286781311035\n",
- "step 200: train loss 3.9314, val loss 3.9342\n",
- "Epoch 200,Loss: 3.931473731994629\n",
- "Epoch 201,Loss: 3.9273831844329834\n",
- "Epoch 202,Loss: 3.9355266094207764\n",
- "Epoch 203,Loss: 3.9317986965179443\n",
- "Epoch 204,Loss: 3.9241442680358887\n",
- "Epoch 205,Loss: 3.927114963531494\n",
- "Epoch 206,Loss: 3.927417516708374\n",
- "Epoch 207,Loss: 3.918044090270996\n",
- "Epoch 208,Loss: 3.9226672649383545\n",
- "Epoch 209,Loss: 3.9170031547546387\n",
- "Epoch 210,Loss: 3.9196341037750244\n",
- "Epoch 211,Loss: 3.9205477237701416\n",
- "Epoch 212,Loss: 3.9187772274017334\n",
- "Epoch 213,Loss: 3.914111852645874\n",
- "Epoch 214,Loss: 3.915701389312744\n",
- "Epoch 215,Loss: 3.912682056427002\n",
- "Epoch 216,Loss: 3.9132320880889893\n",
- "Epoch 217,Loss: 3.9093964099884033\n",
- "Epoch 218,Loss: 3.910304546356201\n",
- "Epoch 219,Loss: 3.902721643447876\n",
- "Epoch 220,Loss: 3.9013726711273193\n",
- "Epoch 221,Loss: 3.9074864387512207\n",
- "Epoch 222,Loss: 3.9011595249176025\n",
- "Epoch 223,Loss: 3.9110357761383057\n",
- "Epoch 224,Loss: 3.919271469116211\n",
- "Epoch 225,Loss: 3.8986053466796875\n",
- "Epoch 226,Loss: 3.913900375366211\n",
- "Epoch 227,Loss: 3.9005937576293945\n",
- "Epoch 228,Loss: 3.894559383392334\n",
- "Epoch 229,Loss: 3.9020140171051025\n",
- "Epoch 230,Loss: 3.9012680053710938\n",
- "Epoch 231,Loss: 3.890580177307129\n",
- "Epoch 232,Loss: 3.901439666748047\n",
- "Epoch 233,Loss: 3.8912415504455566\n",
- "Epoch 234,Loss: 3.8939270973205566\n",
- "Epoch 235,Loss: 3.8888893127441406\n",
- "Epoch 236,Loss: 3.8912696838378906\n",
- "Epoch 237,Loss: 3.8963351249694824\n",
- "Epoch 238,Loss: 3.8835411071777344\n",
- "Epoch 239,Loss: 3.8910605907440186\n",
- "Epoch 240,Loss: 3.891777992248535\n",
- "Epoch 241,Loss: 3.8892035484313965\n",
- "Epoch 242,Loss: 3.8773059844970703\n",
- "Epoch 243,Loss: 3.883334159851074\n",
- "Epoch 244,Loss: 3.8824222087860107\n",
- "Epoch 245,Loss: 3.8792808055877686\n",
- "Epoch 246,Loss: 3.8650636672973633\n",
- "Epoch 247,Loss: 3.882506847381592\n",
- "Epoch 248,Loss: 3.8686697483062744\n",
- "Epoch 249,Loss: 3.871563196182251\n",
- "Epoch 250,Loss: 3.8801050186157227\n",
- "Epoch 251,Loss: 3.876453161239624\n",
- "Epoch 252,Loss: 3.878538131713867\n",
- "Epoch 253,Loss: 3.8688294887542725\n",
- "Epoch 254,Loss: 3.869239568710327\n",
- "Epoch 255,Loss: 3.8656933307647705\n",
- "Epoch 256,Loss: 3.8676295280456543\n",
- "Epoch 257,Loss: 3.868959903717041\n",
- "Epoch 258,Loss: 3.860564947128296\n",
- "Epoch 259,Loss: 3.8659508228302\n",
- "Epoch 260,Loss: 3.870162010192871\n",
- "Epoch 261,Loss: 3.862337112426758\n",
- "Epoch 262,Loss: 3.8631129264831543\n",
- "Epoch 263,Loss: 3.8556807041168213\n",
- "Epoch 264,Loss: 3.859565496444702\n",
- "Epoch 265,Loss: 3.8612101078033447\n",
- "Epoch 266,Loss: 3.860574245452881\n",
- "Epoch 267,Loss: 3.863940954208374\n",
- "Epoch 268,Loss: 3.846076726913452\n",
- "Epoch 269,Loss: 3.858536958694458\n",
- "Epoch 270,Loss: 3.8503029346466064\n",
- "Epoch 271,Loss: 3.857239007949829\n",
- "Epoch 272,Loss: 3.853991746902466\n",
- "Epoch 273,Loss: 3.8525960445404053\n",
- "Epoch 274,Loss: 3.8553481101989746\n",
- "Epoch 275,Loss: 3.845716953277588\n",
- "Epoch 276,Loss: 3.846571207046509\n",
- "Epoch 277,Loss: 3.8519811630249023\n",
- "Epoch 278,Loss: 3.8430190086364746\n",
- "Epoch 279,Loss: 3.8463683128356934\n",
- "Epoch 280,Loss: 3.839120864868164\n",
- "Epoch 281,Loss: 3.8422346115112305\n",
- "Epoch 282,Loss: 3.8394320011138916\n",
- "Epoch 283,Loss: 3.8440675735473633\n",
- "Epoch 284,Loss: 3.8503305912017822\n",
- "Epoch 285,Loss: 3.838545322418213\n",
- "Epoch 286,Loss: 3.829591751098633\n",
- "Epoch 287,Loss: 3.8370203971862793\n",
- "Epoch 288,Loss: 3.8436241149902344\n",
- "Epoch 289,Loss: 3.8240110874176025\n",
- "Epoch 290,Loss: 3.83495831489563\n",
- "Epoch 291,Loss: 3.8333072662353516\n",
- "Epoch 292,Loss: 3.833137035369873\n",
- "Epoch 293,Loss: 3.8384203910827637\n",
- "Epoch 294,Loss: 3.840672731399536\n",
- "Epoch 295,Loss: 3.83957576751709\n",
- "Epoch 296,Loss: 3.8232216835021973\n",
- "Epoch 297,Loss: 3.8206679821014404\n",
- "Epoch 298,Loss: 3.8219101428985596\n",
- "Epoch 299,Loss: 3.834547758102417\n",
- "step 300: train loss 3.8211, val loss 3.8261\n",
- "Epoch 300,Loss: 3.8192684650421143\n",
- "Epoch 301,Loss: 3.8214690685272217\n",
- "Epoch 302,Loss: 3.8147058486938477\n",
- "Epoch 303,Loss: 3.8182125091552734\n",
- "Epoch 304,Loss: 3.822451114654541\n",
- "Epoch 305,Loss: 3.8188161849975586\n",
- "Epoch 306,Loss: 3.8196117877960205\n",
- "Epoch 307,Loss: 3.814863681793213\n",
- "Epoch 308,Loss: 3.812788724899292\n",
- "Epoch 309,Loss: 3.799957036972046\n",
- "Epoch 310,Loss: 3.8292646408081055\n",
- "Epoch 311,Loss: 3.8186419010162354\n",
- "Epoch 312,Loss: 3.8149664402008057\n",
- "Epoch 313,Loss: 3.805006742477417\n",
- "Epoch 314,Loss: 3.801513195037842\n",
- "Epoch 315,Loss: 3.79679536819458\n",
- "Epoch 316,Loss: 3.79837703704834\n",
- "Epoch 317,Loss: 3.7946202754974365\n",
- "Epoch 318,Loss: 3.812180995941162\n",
- "Epoch 319,Loss: 3.8108162879943848\n",
- "Epoch 320,Loss: 3.790243625640869\n",
- "Epoch 321,Loss: 3.8042056560516357\n",
- "Epoch 322,Loss: 3.7979207038879395\n",
- "Epoch 323,Loss: 3.790651321411133\n",
- "Epoch 324,Loss: 3.7874715328216553\n",
- "Epoch 325,Loss: 3.79123854637146\n",
- "Epoch 326,Loss: 3.794536828994751\n",
- "Epoch 327,Loss: 3.784797191619873\n",
- "Epoch 328,Loss: 3.790224552154541\n",
- "Epoch 329,Loss: 3.782738208770752\n",
- "Epoch 330,Loss: 3.794175148010254\n",
- "Epoch 331,Loss: 3.786827564239502\n",
- "Epoch 332,Loss: 3.7888123989105225\n",
- "Epoch 333,Loss: 3.783346176147461\n",
- "Epoch 334,Loss: 3.793034076690674\n",
- "Epoch 335,Loss: 3.7874460220336914\n",
- "Epoch 336,Loss: 3.7902700901031494\n",
- "Epoch 337,Loss: 3.778681755065918\n",
- "Epoch 338,Loss: 3.7811100482940674\n",
- "Epoch 339,Loss: 3.76914644241333\n",
- "Epoch 340,Loss: 3.789398193359375\n",
- "Epoch 341,Loss: 3.7791271209716797\n",
- "Epoch 342,Loss: 3.770552635192871\n",
- "Epoch 343,Loss: 3.766422986984253\n",
- "Epoch 344,Loss: 3.7771592140197754\n",
- "Epoch 345,Loss: 3.7636663913726807\n",
- "Epoch 346,Loss: 3.787656545639038\n",
- "Epoch 347,Loss: 3.764472723007202\n",
- "Epoch 348,Loss: 3.7605934143066406\n",
- "Epoch 349,Loss: 3.7653210163116455\n",
- "Epoch 350,Loss: 3.768303155899048\n",
- "Epoch 351,Loss: 3.7691171169281006\n",
- "Epoch 352,Loss: 3.7605278491973877\n",
- "Epoch 353,Loss: 3.771134614944458\n",
- "Epoch 354,Loss: 3.766357183456421\n",
- "Epoch 355,Loss: 3.7620556354522705\n",
- "Epoch 356,Loss: 3.763490915298462\n",
- "Epoch 357,Loss: 3.7664549350738525\n",
- "Epoch 358,Loss: 3.7606000900268555\n",
- "Epoch 359,Loss: 3.7632651329040527\n",
- "Epoch 360,Loss: 3.745481252670288\n",
- "Epoch 361,Loss: 3.7506914138793945\n",
- "Epoch 362,Loss: 3.7509055137634277\n",
- "Epoch 363,Loss: 3.758342981338501\n",
- "Epoch 364,Loss: 3.7418177127838135\n",
- "Epoch 365,Loss: 3.7498388290405273\n",
- "Epoch 366,Loss: 3.75475811958313\n",
- "Epoch 367,Loss: 3.7491402626037598\n",
- "Epoch 368,Loss: 3.749390125274658\n",
- "Epoch 369,Loss: 3.74155330657959\n",
- "Epoch 370,Loss: 3.7474842071533203\n",
- "Epoch 371,Loss: 3.75659441947937\n",
- "Epoch 372,Loss: 3.7347867488861084\n",
- "Epoch 373,Loss: 3.73516845703125\n",
- "Epoch 374,Loss: 3.736833333969116\n",
- "Epoch 375,Loss: 3.7528369426727295\n",
- "Epoch 376,Loss: 3.7481980323791504\n",
- "Epoch 377,Loss: 3.7436063289642334\n",
- "Epoch 378,Loss: 3.7425522804260254\n",
- "Epoch 379,Loss: 3.743187189102173\n",
- "Epoch 380,Loss: 3.7335641384124756\n",
- "Epoch 381,Loss: 3.7418203353881836\n",
- "Epoch 382,Loss: 3.7498104572296143\n",
- "Epoch 383,Loss: 3.730151891708374\n",
- "Epoch 384,Loss: 3.735663414001465\n",
- "Epoch 385,Loss: 3.725689649581909\n",
- "Epoch 386,Loss: 3.7350897789001465\n",
- "Epoch 387,Loss: 3.7252840995788574\n",
- "Epoch 388,Loss: 3.7173573970794678\n",
- "Epoch 389,Loss: 3.72982120513916\n",
- "Epoch 390,Loss: 3.728365898132324\n",
- "Epoch 391,Loss: 3.7212765216827393\n",
- "Epoch 392,Loss: 3.7305972576141357\n",
- "Epoch 393,Loss: 3.720611572265625\n",
- "Epoch 394,Loss: 3.7320303916931152\n",
- "Epoch 395,Loss: 3.726435661315918\n",
- "Epoch 396,Loss: 3.7120232582092285\n",
- "Epoch 397,Loss: 3.721881151199341\n",
- "Epoch 398,Loss: 3.7123491764068604\n",
- "Epoch 399,Loss: 3.7059667110443115\n",
- "step 400: train loss 3.7182, val loss 3.7246\n",
- "Epoch 400,Loss: 3.715552806854248\n",
- "Epoch 401,Loss: 3.70719575881958\n",
- "Epoch 402,Loss: 3.717360258102417\n",
- "Epoch 403,Loss: 3.7229161262512207\n",
- "Epoch 404,Loss: 3.7143988609313965\n",
- "Epoch 405,Loss: 3.7218456268310547\n",
- "Epoch 406,Loss: 3.707012414932251\n",
- "Epoch 407,Loss: 3.7073845863342285\n",
- "Epoch 408,Loss: 3.7069833278656006\n",
- "Epoch 409,Loss: 3.7150471210479736\n",
- "Epoch 410,Loss: 3.709181547164917\n",
- "Epoch 411,Loss: 3.7069830894470215\n",
- "Epoch 412,Loss: 3.7000956535339355\n",
- "Epoch 413,Loss: 3.6996009349823\n",
- "Epoch 414,Loss: 3.709993362426758\n",
- "Epoch 415,Loss: 3.6976137161254883\n",
- "Epoch 416,Loss: 3.69256329536438\n",
- "Epoch 417,Loss: 3.6903862953186035\n",
- "Epoch 418,Loss: 3.6960389614105225\n",
- "Epoch 419,Loss: 3.6840765476226807\n",
- "Epoch 420,Loss: 3.6853392124176025\n",
- "Epoch 421,Loss: 3.693702459335327\n",
- "Epoch 422,Loss: 3.692162036895752\n",
- "Epoch 423,Loss: 3.697707176208496\n",
- "Epoch 424,Loss: 3.683466911315918\n",
- "Epoch 425,Loss: 3.7038679122924805\n",
- "Epoch 426,Loss: 3.700803279876709\n",
- "Epoch 427,Loss: 3.678903818130493\n",
- "Epoch 428,Loss: 3.6846554279327393\n",
- "Epoch 429,Loss: 3.696505308151245\n",
- "Epoch 430,Loss: 3.6857666969299316\n",
- "Epoch 431,Loss: 3.700955867767334\n",
- "Epoch 432,Loss: 3.6856634616851807\n",
- "Epoch 433,Loss: 3.6878204345703125\n",
- "Epoch 434,Loss: 3.6669888496398926\n",
- "Epoch 435,Loss: 3.681831121444702\n",
- "Epoch 436,Loss: 3.6712563037872314\n",
- "Epoch 437,Loss: 3.674145460128784\n",
- "Epoch 438,Loss: 3.6812634468078613\n",
- "Epoch 439,Loss: 3.6820902824401855\n",
- "Epoch 440,Loss: 3.682633638381958\n",
- "Epoch 441,Loss: 3.674262762069702\n",
- "Epoch 442,Loss: 3.6762478351593018\n",
- "Epoch 443,Loss: 3.6800148487091064\n",
- "Epoch 444,Loss: 3.6821205615997314\n",
- "Epoch 445,Loss: 3.679013729095459\n",
- "Epoch 446,Loss: 3.663222551345825\n",
- "Epoch 447,Loss: 3.674647808074951\n",
- "Epoch 448,Loss: 3.674830675125122\n",
- "Epoch 449,Loss: 3.6704530715942383\n",
- "Epoch 450,Loss: 3.6534314155578613\n",
- "Epoch 451,Loss: 3.67617130279541\n",
- "Epoch 452,Loss: 3.6654410362243652\n",
- "Epoch 453,Loss: 3.6626992225646973\n",
- "Epoch 454,Loss: 3.6643929481506348\n",
- "Epoch 455,Loss: 3.663473129272461\n",
- "Epoch 456,Loss: 3.652484893798828\n",
- "Epoch 457,Loss: 3.657602071762085\n",
- "Epoch 458,Loss: 3.659010171890259\n",
- "Epoch 459,Loss: 3.6692850589752197\n",
- "Epoch 460,Loss: 3.6528565883636475\n",
- "Epoch 461,Loss: 3.661115884780884\n",
- "Epoch 462,Loss: 3.6511964797973633\n",
- "Epoch 463,Loss: 3.6498470306396484\n",
- "Epoch 464,Loss: 3.6548194885253906\n",
- "Epoch 465,Loss: 3.6427958011627197\n",
- "Epoch 466,Loss: 3.640429973602295\n",
- "Epoch 467,Loss: 3.6433045864105225\n",
- "Epoch 468,Loss: 3.6328577995300293\n",
- "Epoch 469,Loss: 3.6472063064575195\n",
- "Epoch 470,Loss: 3.6393685340881348\n",
- "Epoch 471,Loss: 3.643303155899048\n",
- "Epoch 472,Loss: 3.635701894760132\n",
- "Epoch 473,Loss: 3.6435654163360596\n",
- "Epoch 474,Loss: 3.6412038803100586\n",
- "Epoch 475,Loss: 3.6328794956207275\n",
- "Epoch 476,Loss: 3.6483938694000244\n",
- "Epoch 477,Loss: 3.6373813152313232\n",
- "Epoch 478,Loss: 3.645163059234619\n",
- "Epoch 479,Loss: 3.650409460067749\n",
- "Epoch 480,Loss: 3.6378026008605957\n",
- "Epoch 481,Loss: 3.651862382888794\n",
- "Epoch 482,Loss: 3.618602752685547\n",
- "Epoch 483,Loss: 3.653233528137207\n",
- "Epoch 484,Loss: 3.6276497840881348\n",
- "Epoch 485,Loss: 3.6483981609344482\n",
- "Epoch 486,Loss: 3.6308512687683105\n",
- "Epoch 487,Loss: 3.6225945949554443\n",
- "Epoch 488,Loss: 3.6241683959960938\n",
- "Epoch 489,Loss: 3.625032901763916\n",
- "Epoch 490,Loss: 3.614368438720703\n",
- "Epoch 491,Loss: 3.6298270225524902\n",
- "Epoch 492,Loss: 3.638739585876465\n",
- "Epoch 493,Loss: 3.6167731285095215\n",
- "Epoch 494,Loss: 3.6361453533172607\n",
- "Epoch 495,Loss: 3.626582622528076\n",
- "Epoch 496,Loss: 3.6187684535980225\n",
- "Epoch 497,Loss: 3.625209331512451\n",
- "Epoch 498,Loss: 3.621763229370117\n",
- "step 499: train loss 3.6193, val loss 3.6217\n",
- "Epoch 499,Loss: 3.614203929901123\n",
- "TIME jax 5.985821008682251\n"
- ]
- }
- ],
- "source": [
- "prng = jax.random.PRNGKey(1337)\n",
- "\n",
- "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
- "with open('input.txt', 'r', encoding='utf-8') as f:\n",
- " text = f.read()\n",
- "\n",
- "# hyperparameters\n",
- "batch_size = 16 # how many independent sequences will we process in parallel?\n",
- "block_size = 32 # what is the maximum context length for predictions?\n",
- "max_iters = 500\n",
- "eval_interval = 100\n",
- "learning_rate = 1e-3\n",
- "eval_iters = 10\n",
- "n_embd = 64\n",
- "n_head = 4\n",
- "n_layer = 4\n",
- "dropout = 0.0\n",
- "\n",
- "\n",
- "# here are all the unique characters that occur in this text\n",
- "chars = sorted(list(set(text)))\n",
- "vocab_size = len(chars)\n",
- "# create a mapping from characters to integers\n",
- "stoi = { ch:i for i,ch in enumerate(chars) }\n",
- "itos = { i:ch for i,ch in enumerate(chars) }\n",
- "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
- "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
- "\n",
- "# Train and test splits\n",
- "data = jnp.array(encode(text), dtype=jnp.int32)\n",
- "data = device_put(data)\n",
- "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
- "train_data = data[:n]\n",
- "val_data = data[n:]\n",
- "\n",
- "def get_batch(split, subkey):\n",
- " # generate a small batch of data of inputs x and targets y\n",
- " data = train_data if split == 'train' else val_data\n",
- " t1 = time.time()\n",
- " ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)\n",
- " t2 = time.time()\n",
- " # x = jnp.stack([data[i:i+block_size] for i in ix])\n",
- " # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])\n",
- " def slice_data(i):\n",
- " return jax.lax.dynamic_slice(data, (i,), (block_size,))\n",
- "\n",
- " x = jax.vmap(slice_data)(ix)\n",
- " y = jax.vmap(slice_data)(ix+1)\n",
- " x, y = device_put(x), device_put(y)\n",
- " # print('TIME rand idx', t2-t1)\n",
- " # print('TIME rand idx fetch', time.time()-t2)\n",
- " return x, y\n",
- "\n",
- "def estimate_loss(params, prng):\n",
- " out = {}\n",
- " for split in ['train', 'val']:\n",
- " losses = jnp.zeros(eval_iters)\n",
- " for k in range(eval_iters):\n",
- " prng, subkey = random.split(prng)\n",
- " X, Y = get_batch(split, subkey)\n",
- " logits, loss = model.apply(params, X, Y)\n",
- " losses = losses.at[k].set(loss)\n",
- " out[split] = losses.mean()\n",
- " return out\n",
- "\n",
- "class BigramLanguageModel(nn.Module):\n",
- " vocab_size: int\n",
- "\n",
- " @nn.compact\n",
- " def __call__(self, idx, targets=None):\n",
- " # Token embedding table\n",
- " embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)\n",
- " logits = embedding_table(idx) # (B,T,C)\n",
- "\n",
- " if targets is None:\n",
- " loss = None\n",
- " else:\n",
- " B, T, C = logits.shape\n",
- " logits = logits.reshape(B*T, C)\n",
- " targets = targets.reshape(B*T)\n",
- " loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()\n",
- "\n",
- " return logits, loss\n",
- "\n",
- " def generate(self, params, key, idx, max_new_tokens):\n",
- " for _ in range(max_new_tokens):\n",
- " logits, _ = self.apply(params, idx)\n",
- " logits = logits[:, -1, :] # (B, C)\n",
- " key, subkey = random.split(key)\n",
- " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n",
- " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n",
- " return idx\n",
- "\n",
- "# Define the optimizer\n",
- "tx = adam(learning_rate)\n",
- "\n",
- "# Initialize model parameters and optimizer state\n",
- "model = BigramLanguageModel(vocab_size)\n",
- "\n",
- "# Initialize model parameters and optimizer\n",
- "params = model.init(prng, jnp.ones((1, 1), jnp.int32))\n",
- "opt_state = tx.init(params)\n",
- "\n",
- "# Loss function (assuming you have a batch of data: xb, yb)\n",
- "def loss_fn(params, xb, yb):\n",
- " logits, loss = model.apply(params, xb, yb)\n",
- " return loss\n",
- "\n",
- "# Update function for a single training step\n",
- "@jax.jit\n",
- "def update_step(params, opt_state, xb, yb):\n",
- " loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)\n",
- " updates, opt_state = tx.update(grads, opt_state, params)\n",
- " new_params = optax.apply_updates(params, updates)\n",
- " return new_params, opt_state, loss\n",
- "\n",
- "# Training loop (example)\n",
- "batch_size = 32\n",
- "t = time.time()\n",
- "for steps in range(max_iters):\n",
- " # every once in a while evaluate the loss on train and val sets\n",
- " if steps == max_iters - 1 or steps % eval_interval == 0:\n",
- " losses = estimate_loss(params, prng)\n",
- " print(f\"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
- " prng, subkey = random.split(prng)\n",
- " xb, yb = get_batch('train', subkey)\n",
- " t2 = time.time()\n",
- " params, opt_state, loss = update_step(params, opt_state, xb, yb)\n",
- " print(f\"Epoch {steps},Loss: {loss}\")\n",
- "print('TIME jax', time.time()-t)"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))"
- ],
- "metadata": {
- "id": "IFG4kdaRF7b1"
- },
- "execution_count": 49,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "def generate(params, flax_apply_jitted, key, idx, max_new_tokens):\n",
- " for _ in range(max_new_tokens):\n",
- " logits, _ = flax_apply_jitted(params, idx, None)\n",
- " logits = logits[:, -1, :] # (B, C)\n",
- " key, subkey = random.split(key)\n",
- " idx_next = random.categorical(subkey, logits)[:, None] # (B, 1)\n",
- " idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)\n",
- " return idx\n",
- "\n",
- "%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Ly4f5Bm-F5t8",
- "outputId": "37bb46de-bd9a-4fd8-97d2-b6488d3cccbd"
- },
- "execution_count": 54,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n",
- "yD.P.e'wn,CZsvq gPrf$f&W3aypokkuSEz?Paw:YCj?M;x\n",
- "pctexMadJMlTZr,CyhaRoYRJUfrsld,bqlwXxclCHIWu'FYEBldJrby;b!HR'Frcr,?&Nui3;woGFdW psZYhosYO!hHPv:F3uMAHbGoslLIWXd;woxmBMTe UpY ZEP\n",
- "tTk?BlWOPrZP.zNS pjFR,OxrO?!$wNDsXCd;il:c!'lal'uOPGCJeDusSf,E.XgunoK-rJP-ky oHsxxwu,:i3UJgZbBMO;s:\n",
- "oPALGSTE-heWO,tcI3$VaeVY JsTPqaqT-ebedAWhoEv:WFiCykXrvetkGbw'3-N!'oW\n",
- "n\n",
- "qi:FgOyU3Xd wrr gVItNvRo,JbtDAvcfHSKDWh.caNKrf CMr IGs?lbiNDerJg'cHB:rRwAuGq&UDUhOdnmc:&jUSZCuG?mF.An--,EMDfCHfITHs ztXPL U3iSE--AsTxeqf??imOUQfArTnZ.Hgv\n",
- "CPU times: user 1.02 s, sys: 7.95 ms, total: 1.03 s\n",
- "Wall time: 1.03 s\n"
- ]
- }
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "gpuType": "T4",
- "provenance": [],
- "toc_visible": true,
- "include_colab_link": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}