diff --git a/Code/1_data_prepare/1_5_compute_mean.py b/Code/1_data_prepare/1_5_compute_mean.py index 28cb90e9..6648e3ee 100644 --- a/Code/1_data_prepare/1_5_compute_mean.py +++ b/Code/1_data_prepare/1_5_compute_mean.py @@ -32,6 +32,7 @@ imgs = np.concatenate((imgs, img), axis=3) print(i) +imgs = np.delete(imgs, 0, axis=3) imgs = imgs.astype(np.float32)/255. diff --git a/Code/4_viewer/2_1_visual_weights.py b/Code/4_viewer/2_1_visual_weights.py new file mode 100644 index 00000000..60c7dde1 --- /dev/null +++ b/Code/4_viewer/2_1_visual_weights.py @@ -0,0 +1,51 @@ +# coding: utf-8 +import torch +import torchvision.utils as vutils +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool1(F.relu(self.conv1(x))) + x = self.pool2(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + # 定义权值初始化 + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.xavier_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + torch.nn.init.normal_(m.weight.data, 0, 0.01) + m.bias.data.zero_() + + +net = Net() # 创建一个网络 + +kernels = net.conv1.weight.detach().clone() +kernels = kernels.view(-1, 1, 5, 5) +img = vutils.make_grid(kernels, normalize=True, scale_each=True, nrow=3) +plt.imshow(img.permute(1, 2, 0)) +plt.axis('off') +plt.show() diff --git a/Code/4_viewer/3_1_visual_featuremaps.py b/Code/4_viewer/3_1_visual_featuremaps.py new file mode 100644 index 00000000..615354ea --- /dev/null +++ b/Code/4_viewer/3_1_visual_featuremaps.py @@ -0,0 +1,56 @@ +# coding: utf-8 +import torch +import torchvision.utils as vutils +import numpy as np +from tensorboardX import SummaryWriter +import torch.nn.functional as F +import torchvision.transforms as transforms +import sys +sys.path.append("..") +from utils.utils import MyDataset, Net, normalize_invert +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt + +vis_layer = 'conv1' +log_dir = '../../Result/visual_featuremaps' +txt_path = '../../Data/visual.txt' +pretrained_path = '../../Data/net_params_72p.pkl' + +net = Net() +pretrained_dict = torch.load(pretrained_path) +net.load_state_dict(pretrained_dict) + +# 数据预处理 +normMean = [0.49139968, 0.48215827, 0.44653124] +normStd = [0.24703233, 0.24348505, 0.26158768] +normTransform = transforms.Normalize(normMean, normStd) +testTransform = transforms.Compose([ + transforms.Resize((32, 32)), + transforms.ToTensor(), + normTransform +]) +# 载入数据 +test_data = MyDataset(txt_path=txt_path, transform=testTransform) +test_loader = DataLoader(dataset=test_data, batch_size=1) +img, label = iter(test_loader).next() + +x = img + + +# Visualize feature maps +features_dict = {} +def get_features(name): + def hook(model, input, output): + features_dict[name] = output.detach() + return hook + + +net.conv1.register_forward_hook(get_features('ext_conv1')) +output = net(x) + +features = features_dict['ext_conv1'].view(-1, 1, 28, 28) + +img = vutils.make_grid(features, normalize=True, scale_each=True, nrow=3) +plt.imshow(img.permute(1, 2, 0)) +plt.axis('off') +plt.show()