Commit

lab 5
sergeyk committed Mar 2, 2021
1 parent c3c8424 commit 16dea21
Showing 34 changed files with 4,677 additions and 0 deletions.
356 changes: 356 additions & 0 deletions lab5/notebooks/01-look-at-emnist.ipynb

Large diffs are not rendered by default.

620 changes: 620 additions & 0 deletions lab5/notebooks/02-look-at-emnist-lines.ipynb

Large diffs are not rendered by default.

621 changes: 621 additions & 0 deletions lab5/notebooks/02b-look-at-emnist-lines2.ipynb

Large diffs are not rendered by default.

389 changes: 389 additions & 0 deletions lab5/notebooks/03-look-at-iam-lines.ipynb

Large diffs are not rendered by default.

105 changes: 105 additions & 0 deletions lab5/readme.md
@@ -0,0 +1,105 @@
# Lab 5: Experiment Management

## Goals of this lab

- Introduce IAMLines handwriting dataset
- Make EMNISTLines look more like IAMLines with additional data augmentations
- Introduce Weights & Biases
- Run some `LineCNNTransformer` experiments on EMNISTLines, writing notes in W&B
- Start a hyper-parameter sweep

## Follow along

```
git pull
cd lab5
```

## IAMLines Dataset

- Look at `notebooks/03-look-at-iam-lines.ipynb`.
- Code is in `text_recognizer/data/iam_lines.py`, which depends on `text_recognizer/data/iam.py`. A quick way to inspect the data module is sketched below.
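
A minimal sketch of inspecting the data module interactively (assuming the default constructor arguments; this mirrors what `load_and_print_info` in `text_recognizer/data/base_data_module.py` does):

```python
from text_recognizer.data import IAMLines

data = IAMLines()
data.prepare_data()  # download/prepare the raw IAM data (may take a while the first time)
data.setup()
print(data)  # print dataset info, as load_and_print_info does

x, y = next(iter(data.train_dataloader()))
print(x.shape, y.shape)  # one batch of line images and their label sequences
```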

## Make EMNISTLines more like IAMLines

To make our synthetic data look more like the real data we ultimately want to handle, we need to introduce data augmentations (an illustrative sketch follows below).

- Look at `notebooks/02b-look-at-emnist-lines2.ipynb`
- Code is in `text_recognizer/data/emnist_lines2.py`
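
Not the exact transforms used in `emnist_lines2.py` -- just an illustration of the kind of augmentations that make clean synthetic lines look more like scanned handwriting:

```python
from torchvision import transforms

# Illustrative augmentation pipeline: small geometric jitter plus brightness
# changes, applied to PIL images of synthetic text lines.
augment = transforms.Compose([
    transforms.ColorJitter(brightness=(0.8, 1.6)),  # vary ink/background contrast
    transforms.RandomAffine(
        degrees=3,               # slight rotation
        translate=(0.02, 0.05),  # small horizontal/vertical shifts
        scale=(0.9, 1.1),        # mild zoom in/out
        shear=(-10, 10),         # slanted handwriting
    ),
    transforms.ToTensor(),
])
```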

## Weights & Biases

Weights & Biases is an experiment tracking tool that ensures you never lose track of your progress.

- Keep track of all experiments in one place
- Easily compare runs
- Create reports to document your progress
- Look at results from the whole team

### Set up W&B

```
wandb init
```

You should see something like:

```
? Which team should we use? (Use arrow keys)
> your_username
Manual Entry
```

Select your username.

```
Which project should we use?
> Create New
```

Select `fsdl-text-recognizer-project`.

How do we implement W&B in the training code?

Look at `training/run_experiment.py`.
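
The sketch below is not the actual contents of `run_experiment.py`; it only illustrates the core pattern, assuming a PyTorch Lightning setup: when `--wandb` is passed, construct a `WandbLogger` and hand it to the `Trainer` so metrics and hyper-parameters get logged to W&B.

```python
import argparse

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb", action="store_true", default=False)
    parser.add_argument("--max_epochs", type=int, default=10)
    args = parser.parse_args()

    logger = TensorBoardLogger("training/logs")  # default local logger
    if args.wandb:
        logger = WandbLogger()               # send metrics to your W&B project instead
        logger.log_hyperparams(vars(args))   # record the full config so runs are comparable

    trainer = pl.Trainer(max_epochs=args.max_epochs, logger=logger)
    # trainer.fit(model, datamodule=data)  # model and data module construction omitted here


if __name__ == "__main__":
    main()
```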

### Your first W&B experiment

```
wandb login
python training/run_experiment.py --max_epochs=100 --gpus='0,' --num_workers=20 --model_class=LineCNNTransformer --data_class=EMNISTLines --window_stride=8 --loss=transformer --wandb
```

You should see a W&B experiment link in the console output.
Click the link to see the progress of your training run.

## Configuring sweeps

Sweeps enable automated trials of hyper-parameters.
W&B provides built-in support for running [sweeps](https://docs.wandb.com/library/sweeps).

We've set up an initial configuration file for sweeps in `training/emnist_lines_line_cnn_transformer_sweep.yaml`.
It performs a basic random search across 3 parameters.

There are lots of different [configuration options](https://docs.wandb.com/library/sweeps/configuration) for defining more complex sweeps.
Any time you modify this configuration, you'll need to create a new sweep in W&B by running:

```bash
wandb sweep training/emnist_lines_line_cnn_transformer_sweep.yaml
```

The `wandb sweep` command prints a sweep ID; start one or more agents with it to pull hyper-parameter settings from the sweep and run experiments:

```bash
wandb agent $SWEEP_ID
```
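
Equivalently, a sweep can be defined and launched from Python. The sketch below is illustrative -- the parameter names and metric are assumptions, not the exact contents of `emnist_lines_line_cnn_transformer_sweep.yaml`, and `train_fn` is a placeholder for a function that runs one training run:

```python
import wandb

# Random search over a few hyper-parameters; "grid" would enumerate all combinations.
sweep_config = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},  # assumed metric name
    "parameters": {
        "lr": {"values": [1e-4, 3e-4, 1e-3]},
        "window_stride": {"values": [8, 12, 16]},
        "fc_dim": {"values": [256, 512, 1024]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="fsdl-text-recognizer-project")
# wandb.agent(sweep_id, function=train_fn)  # or run `wandb agent $SWEEP_ID` from the CLI
```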

### Stopping a sweep

If you choose the **random** sweep strategy, the agent will run forever. Our **grid** search strategy will stop once all options have been tried. You can stop a sweep from the W&B UI, or directly from the terminal. Hitting CTRL-C once will prevent the agent from running a new experiment but allow the current experiment to finish. Hitting CTRL-C again will kill the current running experiment.

## Things to try

- Try to find a setting of hyper-parameters for `LineCNNTransformer` (don't forget -- it includes the `LineCNN` hyper-parameters) that trains fastest while reaching a low CER
- Perhaps do that by running a sweep!
- Try some experiments with `LineCNNLSTM` if you want
- You can also experiment with
17 changes: 17 additions & 0 deletions lab5/text_recognizer/data/__init__.py
@@ -0,0 +1,17 @@
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST

# Hide lines below until Lab 2
from .emnist import EMNIST
from .emnist_lines import EMNISTLines

# Hide lines above until Lab 2

# Hide lines below until Lab 5
from .emnist_lines2 import EMNISTLines2
from .iam_lines import IAMLines

# Hide lines above until Lab 5


100 changes: 100 additions & 0 deletions lab5/text_recognizer/data/base_data_module.py
@@ -0,0 +1,100 @@
"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

from text_recognizer import util


def load_and_print_info(data_module_class: type) -> None:
    """Load EMNISTLines and print info."""
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    args = parser.parse_args()
    dataset = data_module_class(args)
    dataset.prepare_data()
    dataset.setup()
    print(dataset)


def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
    dl_dirname.mkdir(parents=True, exist_ok=True)
    filename = dl_dirname / metadata["filename"]
    if filename.exists():
        return filename
    print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    util.download_url(metadata["url"], filename)
    print("Computing SHA-256...")
    sha256 = util.compute_sha256(filename)
    if sha256 != metadata["sha256"]:
        raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
    return filename


BATCH_SIZE = 128
NUM_WORKERS = 0


class BaseDataModule(pl.LightningDataModule):
    """
    Base DataModule.
    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", NUM_WORKERS)

        # Make sure to set the variables below in subclasses
        self.dims = None
        self.output_dims = None
        self.mapping = None

    @classmethod
    def data_dirname(cls):
        return Path(__file__).resolve().parents[3] / "data"

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size", type=int, default=BATCH_SIZE, help="Number of examples to operate on per forward step."
        )
        parser.add_argument(
            "--num_workers", type=int, default=NUM_WORKERS, help="Number of additional processes to load data."
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self):
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)