Skip to content

Commit

Permalink
feat: Added CNN class and training function in cnn
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 4d62685 commit d4ed50f
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.optim as optim

class CNN(nn.Module):
"""
Convolutional Neural Network (CNN) class.
"""
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(7 * 7 * 64, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = x.view(-1, 7 * 7 * 64)
x = self.fc1(x)
x = self.fc2(x)
return x

def train_cnn(model, dataloader, epochs):
"""
Function to train the CNN model.
"""
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
for images, labels in dataloader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()

0 comments on commit d4ed50f

Please sign in to comment.