"""
Run the MLP training and evaluation pipeline.
"""
from model_factory import create_model
# MNIST:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
# PyTorch:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
# Other:
from typing import Tuple
from tqdm import tqdm
# The transform list is a set of operations that we apply to the data
# before we use it. In this case, we convert the data to a tensor and
# flatten it. (Thought-exercise: Why do we need to flatten the data?)
_transform_list = [
ToTensor(),
lambda x: x.view(-1),
]
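
# A minimal illustration (a sketch, not part of the original pipeline) of what
# the transform produces: ToTensor() converts a 28x28 MNIST image into a float
# tensor of shape (1, 28, 28), and view(-1) flattens it into a vector of 784
# values, which is what a fully-connected (MLP) input layer expects.
#
#     from PIL import Image
#     import numpy as np
#     sample = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))
#     flat = Compose(_transform_list)(sample)
#     flat.shape  # torch.Size([784])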


def get_mnist_data() -> Tuple[DataLoader, DataLoader]:
    """
    Get the MNIST data from torchvision.

    Arguments:
        None

    Returns:
        train_loader (DataLoader): The training data loader.
        test_loader (DataLoader): The test data loader.
    """
    # Get the training data:
    train_data = MNIST(
        root="data", train=True, download=True, transform=Compose(_transform_list)
    )
    # Create a data loader for the training data:
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    # Get the test data:
    test_data = MNIST(
        root="data", train=False, download=True, transform=Compose(_transform_list)
    )
    # Create a data loader for the test data:
    test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
    # Return the data loaders:
    return train_loader, test_loader
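
# Illustrative shape check (a sketch, assuming the batch_size of 64 used above):
#
#     train_loader, test_loader = get_mnist_data()
#     data, labels = next(iter(train_loader))
#     data.shape    # torch.Size([64, 784]) -- a batch of flattened 28x28 images
#     labels.shape  # torch.Size([64])      -- integer class labels 0-9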


def train(
    model: torch.nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    device: torch.device,
) -> None:
    """
    Train a model on the MNIST data.

    Arguments:
        model (torch.nn.Module): The model to train.
        train_loader (DataLoader): The training data loader.
        test_loader (DataLoader): The test data loader.
        num_epochs (int): The number of epochs to train for.
        learning_rate (float): The learning rate to use.
        device (torch.device): The device to use for training.

    Returns:
        None
    """
    # Create an optimizer:
    optimizer = Adam(model.parameters(), lr=learning_rate)
    # Create a loss function:
    criterion = CrossEntropyLoss()
    # Move the model to the device:
    model.to(device)
    # Create a progress bar:
    progress_bar = tqdm(range(num_epochs))
    # Train the model:
    for epoch in progress_bar:
        # Set the model to training mode:
        model.train()
        # Iterate over the training data:
        for batch in train_loader:
            # Get the data and labels:
            data, labels = batch
            # Move the data and labels to the device:
            data = data.to(device)
            labels = labels.to(device)
            # Zero the gradients:
            optimizer.zero_grad()
            # Forward pass:
            outputs = model(data)
            # Calculate the loss:
            loss = criterion(outputs, labels)
            # Backward pass:
            loss.backward()
            # Update the parameters:
            optimizer.step()
        # Set the model to evaluation mode:
        model.eval()
        # Calculate the accuracy on the test data:
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in test_loader:
                # Get the data and labels:
                data, labels = batch
                # Move the data and labels to the device:
                data = data.to(device)
                labels = labels.to(device)
                # Forward pass:
                outputs = model(data)
                # Get the predictions:
                _, predictions = torch.max(outputs.data, 1)
                # Update the total and correct counts:
                total += labels.size(0)
                correct += (predictions == labels).sum().item()
        # Calculate the accuracy:
        accuracy = correct / total
        # Update the progress bar:
        progress_bar.set_description(f"Epoch: {epoch}, Accuracy: {accuracy:.4f}")


def main():
    # Get the data:
    train_loader, test_loader = get_mnist_data()
    # Create the model:
    model = create_model(784, 10)
    # Train the model:
    train(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=10,
        learning_rate=0.001,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )


if __name__ == "__main__":
    main()