From 3c69ef9e81b824181885a25cb2f607b88eca1a13 Mon Sep 17 00:00:00 2001 From: NewLandTV <86148752+NewLandTV@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:46:17 +0900 Subject: [PATCH] Add Model --- .gitignore | 4 ++++ Main.py | 26 ++++++++++++++++++++++++++ Model.py | 28 ++++++++++++++++++++++++++++ Test.py | 41 +++++++++++++++++++++++++++++++++++++++++ Train.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 147 insertions(+) create mode 100644 Main.py create mode 100644 Model.py create mode 100644 Test.py create mode 100644 Train.py diff --git a/.gitignore b/.gitignore index 68bc17f..4d2492e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Dataset +Dataset/ +*.jpeg + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/Main.py b/Main.py new file mode 100644 index 0000000..fab51f1 --- /dev/null +++ b/Main.py @@ -0,0 +1,26 @@ +from PIL import Image +import torch +import torchvision.transforms as transforms +from Model import CNN + +transform = transforms.Compose([ + transforms.ToTensor() +]) + +classes = ("Circle", "Square", "Triangle") + +# Load image +image = Image.open("./Input.jpeg") +image = transform(image) + +# Load model +model = CNN() + +model.load_state_dict(torch.load("Model.pth")) +model.eval() + +with torch.no_grad(): + outputs = model(image) + _, predicted = torch.max(outputs.data, 1) + +print(classes[predicted[0]]) diff --git a/Model.py b/Model.py new file mode 100644 index 0000000..b17c438 --- /dev/null +++ b/Model.py @@ -0,0 +1,28 @@ +import torch.nn as nn +import torch.nn.functional as F + +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + + self.layer1 = nn.Sequential( + nn.Conv2d(3, 32, 3, 1, 1), + nn.ReLU(True), + nn.MaxPool2d(2, 2) + ) + self.layer2 = nn.Sequential( + nn.Conv2d(32, 64, 3, 1, 1), + nn.ReLU(True), + nn.MaxPool2d(2, 2) + ) + self.func1 = nn.Linear(64 * 7 * 7, 128) + self.func2 = nn.Linear(128, 3) + + def forward(self, x): + out = self.layer1(x) + out = self.layer2(out) + out = out.view(-1, 64 * 7 * 7) + out = F.relu(self.func1(out)) + out = self.func2(out) + + return out diff --git a/Test.py b/Test.py new file mode 100644 index 0000000..576b981 --- /dev/null +++ b/Test.py @@ -0,0 +1,41 @@ +import torch +import torchvision +import torchvision.transforms as transforms +from Model import CNN + +batchSize = 4 +transform = transforms.Compose([ + transforms.ToTensor() +]) + +# Load test dataset +testSet = torchvision.datasets.ImageFolder( + root = "./Dataset/Test/", + transform = transform +) +testLoader = torch.utils.data.DataLoader( + testSet, + batch_size = batchSize, + shuffle = True, + num_workers = 0 +) + +# Load model +model = CNN() + +model.load_state_dict(torch.load("Model.pth")) +model.eval() + +correct = 0 +total = 0 + +with torch.no_grad(): + for inputs, labels in testLoader: + outputs = model(inputs) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + +accuracy = correct / total * 100 + +print(f"Accuracy: {accuracy:.2f} %") diff --git a/Train.py b/Train.py new file mode 100644 index 0000000..66526d3 --- /dev/null +++ b/Train.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from Model import CNN + +batchSize = 4 +transform = transforms.Compose([ + transforms.ToTensor() +]) + +# Load train dataset +trainSet = torchvision.datasets.ImageFolder( + root = "./Dataset/Train/", + transform = transform +) +trainLoader = torch.utils.data.DataLoader( + trainSet, + batch_size = batchSize, + shuffle = True, + num_workers = 0 +) + +# Model training +epochs = 20 +learningRate = 1e-4 +model = CNN() +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(), lr = learningRate) + +model.train() + +for epoch in range(epochs): + for images, labels in trainLoader: + optimizer.zero_grad() + + outputs = model(images) + loss = criterion(outputs, labels) + + loss.backward() + optimizer.step() + + print(f"[Epoch: {epoch + 1:5d}/{epochs}] Loss: {loss.item()}") + +torch.save(model.state_dict(), "Model.pth") + +print("Finished Training!")