
Commit

Adding integration tests.
madlag committed Sep 9, 2020
1 parent 807a864 commit 0985083
Showing 8 changed files with 52,851 additions and 171 deletions.
162 changes: 16 additions & 146 deletions doc/notebooks/01_how_to_train_sparse/01_how_to_train_sparse.ipynb
@@ -312,7 +312,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -332,7 +332,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -349,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -359,25 +359,14 @@
"id": "E3Ye27nchfzq",
"outputId": "b9812ed2-1ecd-4e1b-d9bd-7de581955e70"
},
"outputs": [
{
"data": {
"text/plain": [
"Encoding(num_tokens=7, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"tokenizer.encode(\"Mi estas Julien.\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -387,18 +376,7 @@
"id": "X8ya5_7rhjKS",
"outputId": "e9e08ded-1081-4823-dd81-9d6be1255385"
},
"outputs": [
{
"data": {
"text/plain": [
"['<s>', 'Mi', 'Ġestas', 'ĠJuli', 'en', '.', '</s>']"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"tokenizer.encode(\"Mi estas Julien.\").tokens"
]
@@ -421,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -431,35 +409,7 @@
"id": "kD140sFjh0LQ",
"outputId": "0bab1f9e-bf7a-4f13-82d3-07fe5866ce78"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wed Sep 2 09:55:50 2020 \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| NVIDIA-SMI 440.100 Driver Version: 440.100 CUDA Version: 10.2 |\r\n",
"|-------------------------------+----------------------+----------------------+\r\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n",
"|===============================+======================+======================|\r\n",
"| 0 GeForce RTX 208... Off | 00000000:01:00.0 Off | N/A |\r\n",
"| 27% 29C P8 21W / 250W | 98MiB / 11019MiB | 0% Default |\r\n",
"+-------------------------------+----------------------+----------------------+\r\n",
"| 1 GeForce RTX 208... Off | 00000000:03:00.0 Off | N/A |\r\n",
"| 27% 27C P8 19W / 250W | 1MiB / 11019MiB | 0% Default |\r\n",
"+-------------------------------+----------------------+----------------------+\r\n",
" \r\n",
"+-----------------------------------------------------------------------------+\r\n",
"| Processes: GPU Memory |\r\n",
"| GPU PID Type Process name Usage |\r\n",
"|=============================================================================|\r\n",
"| 0 1486 G /usr/lib/xorg/Xorg 39MiB |\r\n",
"| 0 1573 G /usr/bin/gnome-shell 57MiB |\r\n",
"+-----------------------------------------------------------------------------+\r\n"
]
}
],
"outputs": [],
"source": [
"# Check that we have a GPU (you should have only one, as there is currently an issue with Pytorch and Long tensors for multi-gpu)\n",
"# If you have several GPUs, make sure to launch jupyter using CUDA_VISIBLE_DEVICES=0 jupyter notebook .\n",
@@ -468,7 +418,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -478,18 +428,7 @@
"id": "VNZZs-r6iKAV",
"outputId": "c8404d6c-7662-4240-c8da-ee89edfaf51b"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Check that PyTorch sees it\n",
"import torch\n",
@@ -508,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -539,7 +478,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -568,42 +507,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BzMqR-dzF4Ro"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Patching 'roberta.encoder.layer.0.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.0.output.dense' with density=0.25, in=3072, out=768,bias=True \n",
"Patching 'roberta.encoder.layer.1.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.1.output.dense' with density=0.25, in=3072, out=768,bias=True \n",
"Patching 'roberta.encoder.layer.2.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.2.output.dense' with density=0.25, in=3072, out=768,bias=True \n",
"Patching 'roberta.encoder.layer.3.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.3.output.dense' with density=0.25, in=3072, out=768,bias=True \n",
"Patching 'roberta.encoder.layer.4.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.4.output.dense' with density=0.25, in=3072, out=768,bias=True \n",
"Patching 'roberta.encoder.layer.5.intermediate.dense' with density=0.25, in=768, out=3072,bias=True \n",
"Patching 'roberta.encoder.layer.5.output.dense' with density=0.25, in=3072, out=768,bias=True \n"
]
},
{
"data": {
"text/plain": [
"62861344"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from pytorch_block_sparse import BlockSparseModelPatcher\n",
"\n",
@@ -636,7 +546,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -646,52 +556,12 @@
"id": "GlvP_A-THEEl",
"outputId": "e0510a33-7937-4a04-fa1c-d4e20b758bb2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3min 48s, sys: 7.51 s, total: 3min 56s\n",
"Wall time: 45.3 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"\n",
"from transformers import LineByLineTextDataset\n",
"\n",
"from torch.utils.data.dataset import Dataset\n",
"from transformers.tokenization_utils import PreTrainedTokenizer\n",
"import os\n",
"\n",
"class LineByLineTextDataset2(Dataset):\n",
" \"\"\"\n",
" This will be superseded by a framework-agnostic approach\n",
" soon.\n",
" \"\"\"\n",
"\n",
" def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):\n",
" assert os.path.isfile(file_path), f\"Input file path {file_path} not found\"\n",
" # Here, we do not cache the features, operating under the assumption\n",
" # that we will soon use fast multithreaded tokenizers from the\n",
" # `tokenizers` repo everywhere =)\n",
"# logger.info(\"Creating features from dataset file at %s\", file_path)\n",
"\n",
" with open(file_path, encoding=\"utf-8\") as f:\n",
" lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]\n",
"\n",
" batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)\n",
" self.examples = batch_encoding[\"input_ids\"]\n",
"\n",
" def __len__(self):\n",
" return len(self.examples)\n",
"\n",
" def __getitem__(self, i) -> torch.Tensor:\n",
" ret = torch.tensor(self.examples[i], dtype=torch.long)\n",
"# print(ret.shape)\n",
" return ret\n",
"\n",
"dataset = LineByLineTextDataset(\n",
" tokenizer=tokenizer,\n",
" file_path=\"./oscar.eo.txt\",\n",
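Note on the notebook change: the hunks above clear execution counts and stale outputs (presumably so the new integration tests can run the notebook from a clean state) and drop the unused LineByLineTextDataset2 class in favor of the stock transformers one. The cleared patching log recorded the cell that swaps RoBERTa's feed-forward linears for block-sparse ones. A minimal sketch of that step for reference; the add_pattern/patch_model calls follow the library's documented API, while the model construction (6-layer RoBERTa, 52k vocabulary) is an assumption based on the notebook's setup:

# Sketch only: the patching step whose log was cleared above. Patterns
# and density=0.25 come from the removed log lines; the model config is
# assumed from the notebook's setup.
from transformers import RobertaConfig, RobertaForMaskedLM
from pytorch_block_sparse import BlockSparseModelPatcher

config = RobertaConfig(vocab_size=52_000, num_hidden_layers=6)
model = RobertaForMaskedLM(config=config)

mp = BlockSparseModelPatcher()
# Keep 25% of the weights in every encoder layer's feed-forward linears,
# matching the "Patching '...dense' with density=0.25" lines removed above.
mp.add_pattern(r"roberta\.encoder\.layer\.[0-9]+\.intermediate\.dense", {"density": 0.25})
mp.add_pattern(r"roberta\.encoder\.layer\.[0-9]+\.output\.dense", {"density": 0.25})
mp.patch_model(model)

print(model.num_parameters())  # the cleared output showed 62861344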
18 changes: 17 additions & 1 deletion pytorch_block_sparse/sparse_optimizer.py
@@ -2,6 +2,11 @@
import torch.optim as optim
from pytorch_block_sparse import BlockSparseMatrix

+try:
+    import transformers.optimization as transformers_optim
+except ImportError:
+    transformers_optim = None
+
class SparseOptimizerStrategy:
    def run(self, block_sparse_matrix):
        raise NotImplementedError()
@@ -108,6 +113,15 @@ def update_state(self, state_keep_mask):
        return found

class AdamOptimizerStateUpdater(OptimizerStateUpdater):
+    @staticmethod
+    def is_compatible(optimizer):
+        if isinstance(optimizer, optim.Adam):
+            return True
+
+        if transformers_optim is not None:
+            if isinstance(optimizer, transformers_optim.AdamW):
+                return True
+
    def update_state_data(self, param, state_keep_mask):
        opt = self.optimizer

@@ -218,9 +232,11 @@ def clean(self, p, method, clean_ratio, new_coefficients_scale, new_coefficients
        if len(self.attached_optimizers) != 0:
            found = False
            for optimizer in self.attached_optimizers:
-                if isinstance(optimizer, optim.Adam):
+                if AdamOptimizerStateUpdater.is_compatible(optimizer):
                    updater = AdamOptimizerStateUpdater(optimizer, p)
                    found = found or updater.update_state(state_keep_mask)
+                else:
+                    raise Exception(f"unsupported optimizer {optimizer.__class__}")

            if not found:
                raise Exception(f"Could not find sparse object {p} in optimizers {self.attached_optimizers}")
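Note on the sparse_optimizer.py change: state cleanup previously accepted only torch.optim.Adam; the new is_compatible hook also recognizes the AdamW implementation from transformers when that package is importable, and clean() now raises explicitly for unsupported optimizers. One nitpick: is_compatible falls through without a return for unsupported optimizers, yielding None rather than False; the falsy value keeps the else branch working, but an explicit return False would be clearer. A minimal sketch exercising the new check, assuming transformers is installed (the module path comes from this diff):

# Sketch only: exercises the new compatibility dispatch with both
# supported optimizer flavors.
import torch
import torch.optim as optim
from transformers.optimization import AdamW  # requires transformers

from pytorch_block_sparse.sparse_optimizer import AdamOptimizerStateUpdater

param = torch.nn.Parameter(torch.randn(8, 8))

# Both the stock Adam and the transformers AdamW now pass the check.
for opt in (optim.Adam([param], lr=1e-3), AdamW([param], lr=1e-3)):
    assert AdamOptimizerStateUpdater.is_compatible(opt)

# An unhandled optimizer falls through (implicitly returning None, which
# is falsy), so clean() raises "unsupported optimizer" for it.
assert not AdamOptimizerStateUpdater.is_compatible(optim.SGD([param], lr=1e-3))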
