First commit
LiJunnan1992 committed Jul 15, 2021
0 parents commit 91de3c1
Showing 198 changed files with 19,767 additions and 0 deletions.
343 changes: 343 additions & 0 deletions .ipynb_checkpoints/visualization-checkpoint.ipynb
@@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "lovely-budapest",
"metadata": {},
"source": [
"# This is a notebook that shows how to produce Grad-CAM visualizations for ALBEF"
]
},
{
"cell_type": "markdown",
"id": "czech-surprise",
"metadata": {},
"source": [
"# 1. Set the paths for model checkpoint and configuration"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "institutional-sarah",
"metadata": {},
"outputs": [],
"source": [
"model_path = '../VL/Example/refcoco.pth'\n",
"bert_config_path = 'configs/config_bert.json'\n",
"use_cuda = False"
]
},
{
"cell_type": "markdown",
"id": "lovely-passage",
"metadata": {},
"source": [
"# 2. Model defination"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "documented-symbol",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"from models.vit import VisionTransformer\n",
"from models.xbert import BertConfig, BertModel\n",
"from models.tokenization_bert import BertTokenizer\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torchvision import transforms\n",
"\n",
"import json\n",
"\n",
"class VL_Transformer_ITM(nn.Module):\n",
" def __init__(self, \n",
" text_encoder = None,\n",
" config_bert = ''\n",
" ):\n",
" super().__init__()\n",
" \n",
" bert_config = BertConfig.from_json_file(config_bert)\n",
"\n",
" self.visual_encoder = VisionTransformer(\n",
" img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, \n",
" mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) \n",
"\n",
" self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) \n",
" \n",
" self.itm_head = nn.Linear(768, 2) \n",
"\n",
" \n",
" def forward(self, image, text):\n",
" image_embeds = self.visual_encoder(image) \n",
"\n",
" image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)\n",
"\n",
" output = self.text_encoder(text.input_ids, \n",
" attention_mask = text.attention_mask,\n",
" encoder_hidden_states = image_embeds,\n",
" encoder_attention_mask = image_atts, \n",
" return_dict = True,\n",
" ) \n",
" \n",
" vl_embeddings = output.last_hidden_state[:,0,:]\n",
" vl_output = self.itm_head(vl_embeddings) \n",
" return vl_output"
]
},
{
"cell_type": "markdown",
"id": "renewable-eight",
"metadata": {},
"source": [
"# 3. Text Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "optional-brooklyn",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def pre_caption(caption,max_words=30):\n",
" caption = re.sub(\n",
" r\"([,.'!?\\\"()*#:;~])\",\n",
" '',\n",
" caption.lower(),\n",
" ).replace('-', ' ').replace('/', ' ')\n",
"\n",
" caption = re.sub(\n",
" r\"\\s{2,}\",\n",
" ' ',\n",
" caption,\n",
" )\n",
" caption = caption.rstrip('\\n') \n",
" caption = caption.strip(' ')\n",
"\n",
" #truncate caption\n",
" caption_words = caption.split(' ')\n",
" if len(caption_words)>max_words:\n",
" caption = ' '.join(caption_words[:max_words]) \n",
" return caption"
]
},
{
"cell_type": "markdown",
"id": "based-roads",
"metadata": {},
"source": [
"# 4. Image Preprocessing and Postpressing"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "subsequent-flesh",
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"\n",
"from skimage import transform as skimage_transform\n",
"from scipy.ndimage import filters\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def getAttMap(img, attMap, blur = True, overlap = True):\n",
" attMap -= attMap.min()\n",
" if attMap.max() > 0:\n",
" attMap /= attMap.max()\n",
" attMap = skimage_transform.resize(attMap, (img.shape[:2]), order = 3, mode = 'constant')\n",
" if blur:\n",
" attMap = filters.gaussian_filter(attMap, 0.02*max(img.shape[:2]))\n",
" attMap -= attMap.min()\n",
" attMap /= attMap.max()\n",
" cmap = plt.get_cmap('jet')\n",
" attMapV = cmap(attMap)\n",
" attMapV = np.delete(attMapV, 3, 2)\n",
" if overlap:\n",
" attMap = 1*(1-attMap**0.7).reshape(attMap.shape + (1,))*img + (attMap**0.7).reshape(attMap.shape+(1,)) * attMapV\n",
" return attMap\n",
"\n",
"\n",
"normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((384,384),interpolation=Image.BICUBIC),\n",
" transforms.ToTensor(),\n",
" normalize,\n",
"]) "
]
},
{
"cell_type": "markdown",
"id": "occasional-trace",
"metadata": {},
"source": [
"# 5. Load model and tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "qualified-sleep",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 
'bert.encoder.layer.10.crossattention.self.query.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"\n",
"model = VL_Transformer_ITM(text_encoder='bert-base-uncased', config_bert=bert_config_path)\n",
"\n",
"checkpoint = torch.load(model_path, map_location='cpu') \n",
"msg = model.load_state_dict(checkpoint,strict=False)\n",
"model.eval()\n",
"\n",
"block_num = 8\n",
"\n",
"model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True\n",
"\n",
"if use_cuda:\n",
" model.cuda() "
]
},
{
"cell_type": "markdown",
"id": "apparent-captain",
"metadata": {},
"source": [
"# 6. Load Image and Text"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "finite-angle",
"metadata": {},
"outputs": [],
"source": [
"image_path = 'examples/image0.jpg'\n",
"image_pil = Image.open(image_path).convert('RGB') \n",
"image = transform(image_pil).unsqueeze(0) \n",
"\n",
"caption = 'the woman is working on her computer at the desk'\n",
"text = pre_caption(caption)\n",
"text_input = tokenizer(text, return_tensors=\"pt\")\n",
"\n",
"if use_cuda:\n",
" image = image.cuda()\n",
" text_input = text_input.to(image.device)"
]
},
{
"cell_type": "markdown",
"id": "gorgeous-matrix",
"metadata": {},
"source": [
"# 7. Compute GradCAM"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "driven-termination",
"metadata": {},
"outputs": [],
"source": [
"output = model(image, text_input)\n",
"loss = output[:,1].sum()\n",
"\n",
"model.zero_grad()\n",
"loss.backward() \n",
"\n",
"with torch.no_grad():\n",
" mask = text_input.attention_mask.view(text_input.attention_mask.size(0),1,-1,1,1)\n",
"\n",
" grads=model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attn_gradients()\n",
" cams=model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()\n",
"\n",
" cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 24, 24) * mask\n",
" grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 24, 24) * mask\n",
"\n",
" gradcam = cams * grads\n",
" gradcam = gradcam[0].mean(0).cpu().detach()"
]
},
{
"cell_type": "markdown",
"id": "abroad-northern",
"metadata": {},
"source": [
"# 8. Visualize GradCam for each word"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fourth-cache",
"metadata": {},
"outputs": [],
"source": [
"num_image = len(text_input.input_ids[0]) \n",
"fig, ax = plt.subplots(num_image, 1, figsize=(15,5*num_image))\n",
"\n",
"rgb_image = cv2.imread(image_path)[:, :, ::-1]\n",
"rgb_image = np.float32(rgb_image) / 255\n",
"\n",
"ax[0].imshow(rgb_image)\n",
"ax[0].set_yticks([])\n",
"ax[0].set_xticks([])\n",
"ax[0].set_xlabel(\"Image\")\n",
" \n",
"for i,token_id in enumerate(text_input.input_ids[0][1:]):\n",
" word = tokenizer.decode([token_id])\n",
" gradcam_image = getAttMap(rgb_image, gradcam[i+1])\n",
" ax[i+1].imshow(gradcam_image)\n",
" ax[i+1].set_yticks([])\n",
" ax[i+1].set_xticks([])\n",
" ax[i+1].set_xlabel(word)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}