diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..a4eac83 --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +class CNN(nn.Module): + """ + Convolutional Neural Network (CNN) class. + """ + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.fc1 = nn.Linear(7 * 7 * 64, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = x.view(-1, 7 * 7 * 64) + x = self.fc1(x) + x = self.fc2(x) + return x + +def train_cnn(model, dataloader, epochs): + """ + Function to train the CNN model. + """ + optimizer = optim.SGD(model.parameters(), lr=0.01) + criterion = nn.CrossEntropyLoss() + + for epoch in range(epochs): + for images, labels in dataloader: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step()