diff --git a/.gitignore b/.gitignore
index 6320cd2..ce99277 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-data
\ No newline at end of file
+data
+__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index 0cae1c0..44ed436 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ Maintenance
- License: CS454 20F Team GodYou?
+ License: CS454 20F Team GodYou?
diff --git a/analyze.py b/analyze.py
new file mode 100644
index 0000000..4f77af4
--- /dev/null
+++ b/analyze.py
@@ -0,0 +1,45 @@
+import os
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from utils.dataloader import val_loader
+from models.medium import mediumNet
+from models.small import smallNet
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+PATH = './trained_model/medium_74_128px.pth'
+
+# Our Dataset Classes
+classes = ('airplane', 'cat', 'dog', 'motorbike', 'person')
+
+model = mediumNet()
+trained_weight = torch.load(PATH, map_location='cpu')
+model.load_state_dict(trained_weight)
+
+def custom_imshow(imgList, predicted):
+
+    fig = plt.figure()
+
+    rows = 2
+    cols = 2
+
+    for i in range(4):
+        img = imgList[i]
+        temp = fig.add_subplot(rows, cols, i+1)
+        temp.set_title(classes[predicted[i]])
+        temp.imshow(np.transpose(img, (1, 2, 0)))
+        temp.axis('off')
+
+    plt.show()
+
+for idx, data in enumerate(val_loader):
+
+    inputs, labels = data
+    outputs = model(inputs)
+    _, predicted = torch.max(outputs.data, 1)
+
+    custom_imshow(inputs, predicted)
+
+
+
diff --git a/main.py b/main.py
index 467b05e..df6ce41 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
 import torch
-from dataloader import train_loader, val_loader, custom_imshow
+from utils.dataloader import train_loader, val_loader, custom_imshow
 import matplotlib.pyplot as plt
 import numpy as np
 
diff --git a/medium.pth b/medium.pth
deleted file mode 100644
index 7dbd243..0000000
Binary files a/medium.pth and /dev/null differ
diff --git a/models/__pycache__/medium.cpython-37.pyc b/models/__pycache__/medium.cpython-37.pyc
index b2fe728..206a2aa 100644
Binary files a/models/__pycache__/medium.cpython-37.pyc and b/models/__pycache__/medium.cpython-37.pyc differ
diff --git a/models/__pycache__/small.cpython-37.pyc b/models/__pycache__/small.cpython-37.pyc
index c4c2a57..0c85755 100644
Binary files a/models/__pycache__/small.cpython-37.pyc and b/models/__pycache__/small.cpython-37.pyc differ
diff --git a/models/medium.py b/models/medium.py
index 49825ce..4651770 100644
--- a/models/medium.py
+++ b/models/medium.py
@@ -15,13 +15,11 @@ def __init__(self) :
         self.batchConv3 = nn.BatchNorm2d(16)
         self.conv4 = nn.Conv2d(16,32,3)
         self.batchConv4 = nn.BatchNorm2d(32)
-        self.conv5 = nn.Conv2d(32,48,3)
-        self.batchConv5 = nn.BatchNorm2d(48)
 
         self.pool = nn.MaxPool2d(2,2)
 
-        self.fc1 = nn.Linear(768,256)
-        self.fc2 = nn.Linear(256,128)
+        self.fc1 = nn.Linear(1152,512)
+        self.fc2 = nn.Linear(512,128)
         self.fc3 = nn.Linear(128,5)
 
         self.dropout = nn.Dropout(p=0.5)
@@ -42,7 +40,6 @@ def forward(self, x) :
         x = self.pool(F.relu(self.conv2(x)))
         x = self.pool(F.relu(self.conv3(x)))
         x = self.pool(F.relu(self.conv4(x)))
-        x = F.relu(self.conv5(x))
 
         x = x.view(batchSize, -1)
         x = self.dropout(x)
diff --git a/models/small.py b/models/small.py
index 6cf5712..b498afd 100644
--- a/models/small.py
+++ b/models/small.py
@@ -14,7 +14,8 @@ def __init__(self) :
         self.pool = nn.MaxPool2d(2,2)
 
         self.fc1 = nn.Linear(2304, 1024)
-        self.fc2 = nn.Linear(1024,5)
+        self.fc2 = nn.Linear(1024,512)
+        self.fc3 = nn.Linear(512,5)
 
         nn.init.kaiming_normal_(self.fc1.weight)
         nn.init.kaiming_normal_(self.fc2.weight)
@@ -29,6 +30,7 @@ def forward(self, x) :
         x = x.view(batchSize, -1)
 
         x = F.relu(self.fc1(x))
-        x = self.fc2(x)
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
 
         return x
\ No newline at end of file
diff --git a/trained_model/small_65_128px.pth b/trained_model/small_65_128px.pth
new file mode 100644
index 0000000..3fc2195
Binary files /dev/null and b/trained_model/small_65_128px.pth differ
diff --git a/dataloader.py b/utils/dataloader.py
similarity index 80%
rename from dataloader.py
rename to utils/dataloader.py
index 0def30c..8c13d5a 100644
--- a/dataloader.py
+++ b/utils/dataloader.py
@@ -29,26 +29,10 @@
                                          ]))
 
 val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
-                                         batch_size=10,
-                                         shuffle=False
+                                         batch_size=4,
+                                         shuffle=True
                                          )
 
-def custom_imshow(imgList, labelList):
-
-    fig = plt.figure()
-
-    rows = 2
-    cols = 2
-
-    for i in range(4):
-        img = imgList[i]
-        temp = fig.add_subplot(rows, cols, i+1)
-        temp.imshow(np.transpose(img, (1, 2, 0)))
-        temp.axis('off')
-
-
-    plt.show()
-
 if __name__ == "__main__":
     for batch_idx, data in enumerate(train_loader) :
         inputs, labels = data
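Note on the fc1 resize in models/medium.py: with conv5 removed, fc1 goes from nn.Linear(768,256) to nn.Linear(1152,512). A minimal sketch of the spatial arithmetic behind 1152, assuming 128x128 inputs and that conv1/conv2 (not shown in this diff) are also 3x3 convolutions each followed by the 2x2 max pool, as the forward() hunk suggests:

# Assumed-architecture sketch: flattened feature size feeding fc1 in mediumNet
# once conv5 is dropped. Input spatial size assumed to be 128x128.
size = 128
for _ in range(4):        # conv1..conv4, each followed by self.pool in forward()
    size = size - 2       # 3x3 conv, stride 1, no padding: 126, 61, 28, 12
    size = size // 2      # 2x2 max pool:                    63, 30, 14,  6
print(32 * size * size)   # 32 channels * 6 * 6 = 1152, matching nn.Linear(1152,512)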