forked from salesforce/ALBEF
commit 91de3c1 (0 parents, 198 changed files, 19,767 additions, 0 deletions)
@@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "lovely-budapest",
"metadata": {},
"source": [
"# This is a notebook that shows how to produce Grad-CAM visualizations for ALBEF"
]
},
{
"cell_type": "markdown",
"id": "czech-surprise",
"metadata": {},
"source": [
"# 1. Set the paths for model checkpoint and configuration"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "institutional-sarah",
"metadata": {},
"outputs": [],
"source": [
"model_path = '../VL/Example/refcoco.pth'\n",
"bert_config_path = 'configs/config_bert.json'\n",
"use_cuda = False"
]
},
{
"cell_type": "markdown",
"id": "lovely-passage",
"metadata": {},
"source": [
"# 2. Model defination" | ||
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "documented-symbol",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"from models.vit import VisionTransformer\n",
"from models.xbert import BertConfig, BertModel\n",
"from models.tokenization_bert import BertTokenizer\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torchvision import transforms\n",
"\n",
"import json\n",
"\n",
"class VL_Transformer_ITM(nn.Module):\n",
"    def __init__(self,\n",
"                 text_encoder=None,\n",
"                 config_bert=''\n",
"                 ):\n",
"        super().__init__()\n",
"\n",
"        bert_config = BertConfig.from_json_file(config_bert)\n",
"\n",
"        self.visual_encoder = VisionTransformer(\n",
"            img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,\n",
"            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))\n",
"\n",
"        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)\n",
"\n",
"        self.itm_head = nn.Linear(768, 2)\n",
"\n",
"    def forward(self, image, text):\n",
"        image_embeds = self.visual_encoder(image)\n",
"\n",
"        # attend to every image token\n",
"        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)\n",
"\n",
"        output = self.text_encoder(text.input_ids,\n",
"                                   attention_mask=text.attention_mask,\n",
"                                   encoder_hidden_states=image_embeds,\n",
"                                   encoder_attention_mask=image_atts,\n",
"                                   return_dict=True,\n",
"                                   )\n",
"\n",
"        # the [CLS] embedding feeds the image-text matching head\n",
"        vl_embeddings = output.last_hidden_state[:, 0, :]\n",
"        vl_output = self.itm_head(vl_embeddings)\n",
"        return vl_output"
]
},
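{
"cell_type": "markdown",
"id": "patch-grid-note",
"metadata": {},
"source": [
"A quick sanity check (a minimal illustrative sketch, with made-up variable names): with `img_size=384` and `patch_size=16`, the visual encoder above yields `(384 // 16)**2 = 576` patch tokens plus one `[CLS]` token, so `image_embeds` has shape `(batch, 577, 768)`. The cell below only verifies this arithmetic; it does not instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "patch-grid-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity-check sketch: patch-grid arithmetic for the ViT defined above.\n",
"img_size, patch_size, embed_dim = 384, 16, 768\n",
"num_patches = (img_size // patch_size) ** 2   # 24 * 24 = 576 patch tokens\n",
"seq_len = num_patches + 1                     # plus one [CLS] token\n",
"print(f'image_embeds shape per sample: ({seq_len}, {embed_dim})')\n",
"assert seq_len == 577"
]
},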
{
"cell_type": "markdown",
"id": "renewable-eight",
"metadata": {},
"source": [
"# 3. Text Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "optional-brooklyn",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def pre_caption(caption, max_words=30):\n",
"    # strip punctuation, turn hyphens and slashes into spaces\n",
"    caption = re.sub(\n",
"        r\"([,.'!?\\\"()*#:;~])\",\n",
"        '',\n",
"        caption.lower(),\n",
"    ).replace('-', ' ').replace('/', ' ')\n",
"\n",
"    # collapse repeated whitespace\n",
"    caption = re.sub(\n",
"        r\"\\s{2,}\",\n",
"        ' ',\n",
"        caption,\n",
"    )\n",
"    caption = caption.rstrip('\\n')\n",
"    caption = caption.strip(' ')\n",
"\n",
"    # truncate caption\n",
"    caption_words = caption.split(' ')\n",
"    if len(caption_words) > max_words:\n",
"        caption = ' '.join(caption_words[:max_words])\n",
"    return caption"
]
},
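{
"cell_type": "markdown",
"id": "precaption-demo-note",
"metadata": {},
"source": [
"A minimal usage sketch for `pre_caption` (the sample caption below is made up for illustration): punctuation is stripped, hyphens and slashes become spaces, repeated whitespace collapses, and the result is truncated to `max_words` words. Note that the apostrophe is removed rather than replaced, so `she's` becomes `shes`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "precaption-demo",
"metadata": {},
"outputs": [],
"source": [
"# Minimal usage sketch for pre_caption; the sample caption is illustrative.\n",
"demo = pre_caption(\"The woman -- she's working!? On her computer/desk.\", max_words=5)\n",
"print(demo)   # expected: 'the woman shes working on'"
]
},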
{
"cell_type": "markdown",
"id": "based-roads",
"metadata": {},
"source": [
"# 4. Image Preprocessing and Postpressing" | ||
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "subsequent-flesh",
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"\n",
"from skimage import transform as skimage_transform\n",
"from scipy.ndimage import filters\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def getAttMap(img, attMap, blur=True, overlap=True):\n",
"    # normalize the attention map to [0, 1]\n",
"    attMap -= attMap.min()\n",
"    if attMap.max() > 0:\n",
"        attMap /= attMap.max()\n",
"    # upsample the coarse map to the image resolution\n",
"    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode='constant')\n",
"    if blur:\n",
"        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))\n",
"        attMap -= attMap.min()\n",
"        attMap /= attMap.max()\n",
"    cmap = plt.get_cmap('jet')\n",
"    attMapV = cmap(attMap)\n",
"    attMapV = np.delete(attMapV, 3, 2)   # drop the alpha channel\n",
"    if overlap:\n",
"        # blend heatmap and image, weighted by attention strength\n",
"        attMap = (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV\n",
"    return attMap\n",
"\n",
"\n",
"normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
"\n",
"transform = transforms.Compose([\n",
"    transforms.Resize((384, 384), interpolation=Image.BICUBIC),\n",
"    transforms.ToTensor(),\n",
"    normalize,\n",
"])"
]
},
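{
"cell_type": "markdown",
"id": "attmap-check-note",
"metadata": {},
"source": [
"A minimal shape check for `getAttMap`, using random arrays as stand-ins for a real image and attention map: the coarse 24 x 24 map is upsampled to the image resolution and blended with it, so the output has the same height and width as the input image."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "attmap-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal shape-check sketch for getAttMap on random stand-in data.\n",
"rng = np.random.default_rng(0)\n",
"dummy_img = rng.random((384, 384, 3)).astype(np.float32)\n",
"dummy_att = rng.random((24, 24)).astype(np.float32)\n",
"blended = getAttMap(dummy_img, dummy_att)\n",
"print(blended.shape)   # expected: (384, 384, 3)"
]
},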
{
"cell_type": "markdown",
"id": "occasional-trace",
"metadata": {},
"source": [
"# 5. Load model and tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "qualified-sleep",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.query.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"\n",
"model = VL_Transformer_ITM(text_encoder='bert-base-uncased', config_bert=bert_config_path)\n",
"\n",
"checkpoint = torch.load(model_path, map_location='cpu')\n",
"msg = model.load_state_dict(checkpoint, strict=False)\n",
"model.eval()\n",
"\n",
"block_num = 8\n",
"\n",
"# cache this layer's cross-attention maps during the forward pass\n",
"model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True\n",
"\n",
"if use_cuda:\n",
"    model.cuda()"
]
},
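{
"cell_type": "markdown",
"id": "load-check-note",
"metadata": {},
"source": [
"Two notes on the cell above. First, `load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys`, so `msg` shows how well the checkpoint matched the model; the sketch below only inspects that result. Second, per the warning printed above, cross-attention modules exist only in text-encoder layers 6-11 (ALBEF's multimodal layers), so `block_num = 8` hooks one of those layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "load-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: inspect the result of the non-strict checkpoint load.\n",
"print('missing keys:   ', len(msg.missing_keys))\n",
"print('unexpected keys:', len(msg.unexpected_keys))"
]
},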
{
"cell_type": "markdown",
"id": "apparent-captain",
"metadata": {},
"source": [
"# 6. Load Image and Text"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "finite-angle",
"metadata": {},
"outputs": [],
"source": [
"image_path = 'examples/image0.jpg'\n",
"image_pil = Image.open(image_path).convert('RGB')\n",
"image = transform(image_pil).unsqueeze(0)\n",
"\n",
"caption = 'the woman is working on her computer at the desk'\n",
"text = pre_caption(caption)\n",
"text_input = tokenizer(text, return_tensors=\"pt\")\n",
"\n",
"if use_cuda:\n",
"    image = image.cuda()\n",
"    text_input = text_input.to(image.device)"
]
},
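{
"cell_type": "markdown",
"id": "token-peek-note",
"metadata": {},
"source": [
"A minimal sketch to inspect the tokenized caption: the sequence starts with `[CLS]` and ends with `[SEP]`, which is why the plotting loop in step 8 skips index 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "token-peek",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: decode the caption's token ids.\n",
"tokens = tokenizer.convert_ids_to_tokens(text_input.input_ids[0])\n",
"print(tokens)   # starts with [CLS], ends with [SEP]"
]
},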
{
"cell_type": "markdown",
"id": "gorgeous-matrix",
"metadata": {},
"source": [
"# 7. Compute GradCAM" | ||
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "driven-termination",
"metadata": {},
"outputs": [],
"source": [
"output = model(image, text_input)\n",
"# backpropagate from the 'match' logit of the ITM head\n",
"loss = output[:, 1].sum()\n",
"\n",
"model.zero_grad()\n",
"loss.backward()\n",
"\n",
"with torch.no_grad():\n",
"    mask = text_input.attention_mask.view(text_input.attention_mask.size(0), 1, -1, 1, 1)\n",
"\n",
"    grads = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attn_gradients()\n",
"    cams = model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()\n",
"\n",
"    # drop the image [CLS] column and fold the 576 patch scores into a 24 x 24 grid\n",
"    cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 24, 24) * mask\n",
"    grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 24, 24) * mask\n",
"\n",
"    gradcam = cams * grads\n",
"    gradcam = gradcam[0].mean(0).cpu().detach()   # average over the 12 attention heads"
]
},
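{
"cell_type": "markdown",
"id": "gradcam-shape-note",
"metadata": {},
"source": [
"Why the reshape to `24 x 24`: the cross-attention maps have one column per image token, i.e. one image `[CLS]` plus 576 patches for a 384 x 384 image with 16 x 16 patches; `[:, :, :, 1:]` drops the `[CLS]` column and the remaining 576 scores fold back into the patch grid. A minimal shape check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "gradcam-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal shape-check sketch on the Grad-CAM tensor.\n",
"seq_len = text_input.input_ids.size(1)\n",
"print(gradcam.shape)   # expected: (seq_len, 24, 24)\n",
"assert gradcam.shape == (seq_len, 24, 24)"
]
},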
{
"cell_type": "markdown",
"id": "abroad-northern",
"metadata": {},
"source": [
"# 8. Visualize GradCam for each word" | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fourth-cache",
"metadata": {},
"outputs": [],
"source": [
"num_image = len(text_input.input_ids[0])\n",
"fig, ax = plt.subplots(num_image, 1, figsize=(15, 5 * num_image))\n",
"\n",
"# cv2 loads BGR; reverse the channel order to get RGB\n",
"rgb_image = cv2.imread(image_path)[:, :, ::-1]\n",
"rgb_image = np.float32(rgb_image) / 255\n",
"\n",
"ax[0].imshow(rgb_image)\n",
"ax[0].set_yticks([])\n",
"ax[0].set_xticks([])\n",
"ax[0].set_xlabel(\"Image\")\n",
"\n",
"# skip index 0 ([CLS]) and show one Grad-CAM map per remaining token\n",
"for i, token_id in enumerate(text_input.input_ids[0][1:]):\n",
"    word = tokenizer.decode([token_id])\n",
"    gradcam_image = getAttMap(rgb_image, gradcam[i+1])\n",
"    ax[i+1].imshow(gradcam_image)\n",
"    ax[i+1].set_yticks([])\n",
"    ax[i+1].set_xticks([])\n",
"    ax[i+1].set_xlabel(word)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}