Skip to content

Commit

Permalink
Load & Anlayze Model
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoonseo Kim committed Nov 27, 2020
1 parent a30ac57 commit 4987141
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
data
data
__pycache__
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<img alt="Maintenance" src="https://img.shields.io/badge/Maintained%3F-yes-green.svg" />
</a>
<a href="#" target="_blank">
<img alt="License: CS454 20F Team GodYou?" src="https://img.shields.io/badge/License-GodYou?-red.svg" />
<img alt="License: CS454 20F Team GodYou?" src="https://img.shields.io/badge/License-GodYou-red.svg" />
</a>
</p>

Expand Down
45 changes: 45 additions & 0 deletions analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from utils.dataloader import val_loader
from models.medium import mediumNet
from models.small import smallNet

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
PATH = './trained_model/medium_74_128px.pth'

# Our Dataset Classes
classes = ('airplane', 'cat', 'dog', 'motorbike', 'person')

model = mediumNet()
trained_weight = torch.load(PATH, map_location='cpu')
model.load_state_dict(trained_weight)

def custom_imshow(imgList, predicted):

fig = plt.figure()

rows = 2
cols = 2

for i in range(4):
img = imgList[i]
temp = fig.add_subplot(rows, cols, i+1)
temp.set_title(classes[predicted[i]])
temp.imshow(np.transpose(img, (1, 2, 0)))
temp.axis('off')

plt.show()

for idx, data in enumerate(val_loader):

inputs, labels = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)

custom_imshow(inputs, predicted)



2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from dataloader import train_loader, val_loader, custom_imshow
from utils.dataloader import train_loader, val_loader, custom_imshow

import matplotlib.pyplot as plt
import numpy as np
Expand Down
Binary file removed medium.pth
Binary file not shown.
Binary file modified models/__pycache__/medium.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/small.cpython-37.pyc
Binary file not shown.
7 changes: 2 additions & 5 deletions models/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@ def __init__(self) :
self.batchConv3 = nn.BatchNorm2d(16)
self.conv4 = nn.Conv2d(16,32,3)
self.batchConv4 = nn.BatchNorm2d(32)
self.conv5 = nn.Conv2d(32,48,3)
self.batchConv5 = nn.BatchNorm2d(48)

self.pool = nn.MaxPool2d(2,2)

self.fc1 = nn.Linear(768,256)
self.fc2 = nn.Linear(256,128)
self.fc1 = nn.Linear(1152,512)
self.fc2 = nn.Linear(512,128)
self.fc3 = nn.Linear(128,5)

self.dropout = nn.Dropout(p=0.5)
Expand All @@ -42,7 +40,6 @@ def forward(self, x) :
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = self.pool(F.relu(self.conv4(x)))
x = F.relu(self.conv5(x))

x = x.view(batchSize, -1)
x = self.dropout(x)
Expand Down
6 changes: 4 additions & 2 deletions models/small.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def __init__(self) :
self.pool = nn.MaxPool2d(2,2)

self.fc1 = nn.Linear(2304, 1024)
self.fc2 = nn.Linear(1024,5)
self.fc2 = nn.Linear(1024,512)
self.fc3 = nn.Linear(512,5)

nn.init.kaiming_normal_(self.fc1.weight)
nn.init.kaiming_normal_(self.fc2.weight)
Expand All @@ -29,6 +30,7 @@ def forward(self, x) :

x = x.view(batchSize, -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
x = F.relu(self.fc2(x))
x = self.fc3(x)

return x
Binary file added trained_model/small_65_128px.pth
Binary file not shown.
20 changes: 2 additions & 18 deletions dataloader.py → utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,10 @@
]))

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
batch_size=10,
shuffle=False
batch_size=4,
shuffle=True
)

def custom_imshow(imgList, labelList):

fig = plt.figure()

rows = 2
cols = 2

for i in range(4):
img = imgList[i]
temp = fig.add_subplot(rows, cols, i+1)
temp.imshow(np.transpose(img, (1, 2, 0)))
temp.axis('off')


plt.show()

if __name__ == "__main__":
for batch_idx, data in enumerate(train_loader) :
inputs, labels = data
Expand Down

0 comments on commit 4987141

Please sign in to comment.