-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
119 lines (94 loc) · 3.73 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import Callable, List
import torch
import torch.utils.data as data
import numpy as np
import os
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import cv2
import matplotlib.image as mpimg
import torch.nn.functional as F
class BaselineTrainer:
    """Minimal supervised training loop with periodic validation.

    Trains `model` with `loss` and `optimizer`, records per-epoch mean
    training loss, runs validation every `validate_every` epochs, periodically
    checkpoints the model, and plots the loss curves at the end of `fit`.
    """

    def __init__(self, model: torch.nn.Module,
                 loss: Callable,
                 optimizer: torch.optim.Optimizer,
                 validate_every=3,
                 use_cuda=True,
                 ):
        """Store training components and optionally move the model to GPU.

        :param model: network to train.
        :param loss: callable ``loss(out, target)`` returning a scalar tensor.
        :param optimizer: optimizer already bound to ``model``'s parameters.
        :param validate_every: run validation every N epochs.
        :param use_cuda: if True, move model and batches to ``cuda:0``.
        """
        self.loss = loss
        self.validate_every = validate_every
        self.train_losses = []  # mean training loss per epoch
        self.val_losses = []    # validation loss, appended every `validate_every` epochs
        self.use_cuda = use_cuda
        self.optimizer = optimizer
        self.model = model
        if self.use_cuda:
            self.model = model.to(device="cuda:0")

    def fit(self, train_data_loader: data.DataLoader,
            epoch: int,
            val_data_loader: data.DataLoader,
            ):
        """Train for `epoch` epochs and return the mean loss over all batches.

        :param train_data_loader: yields ``(x, y)`` training batches.
        :param epoch: number of epochs to run.
        :param val_data_loader: loader passed to :meth:`evaluate`.
        :return: average training loss over every batch of every epoch.
        """
        avg_loss = 0.
        total_batches = 0  # counts batches across ALL epochs for the final average
        for e in range(epoch):
            print(f"Start epoch {e + 1}/{epoch}")
            # BUGFIX: the original assigned `self.model.training = True`, which
            # does NOT propagate to submodules (Dropout/BatchNorm stay in eval
            # mode after evaluate() calls model.eval()). model.train() does.
            self.model.train()
            epoch_loss = 0.0  # accumulates loss for the current epoch only
            for i, (x, y) in enumerate(train_data_loader):
                # Reset previous gradients
                self.optimizer.zero_grad()
                # Move data to cuda if necessary
                if self.use_cuda:
                    x = x.to(device="cuda:0")
                    y = y.to(device="cuda:0")
                out = self.model(x)
                # squeeze(1) drops the target's channel dim — presumably
                # targets are (B, 1, H, W) class maps; confirm against dataset.
                loss = self.loss(out, y.squeeze(1))
                loss.backward()
                # Adjust learning weights
                self.optimizer.step()
                avg_loss += loss.item()
                epoch_loss += loss.item()
                total_batches += 1
                print(f"\r{i + 1}/{len(train_data_loader)}: loss = {loss.item()}", end='')
            # Store the average loss for the current epoch
            epoch_loss /= len(train_data_loader)
            self.train_losses.append(epoch_loss)
            # Test the model on validation dataset every validate_every epochs
            if (e + 1) % self.validate_every == 0:
                val_loss = self.evaluate(val_data_loader)
                print(f"\nValidation Loss after epoch {e + 1}: {val_loss}")
                self.val_losses.append(val_loss)
            # Checkpoint the model every 5 epochs (and on the first epoch)
            if e % 5 == 0:
                model_filename = 'DeepLabV3Plus_model.pth'
                torch.save(self.model.state_dict(), model_filename)
        # Plot training/validation curves once training is done
        self.plot_training_loss(val_losses=self.val_losses)
        # BUGFIX: the original divided by len(train_data_loader) only, although
        # avg_loss was accumulated over ALL epochs — the returned value was
        # inflated by a factor of `epoch`. Divide by the true batch count.
        if total_batches:
            avg_loss /= total_batches
        return avg_loss

    def plot_training_loss(self, val_losses=None):
        """Plot (and save to 'training_loss_plot.png') the recorded losses."""
        plt.plot(self.train_losses, label='Training Loss')
        if val_losses:
            plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss Over Epochs')
        plt.legend()
        plt.savefig('training_loss_plot.png')
        plt.show()

    def evaluate(self, val_data_loader: data.DataLoader):
        """Return the mean validation loss; puts the model in eval mode.

        :param val_data_loader: yields ``(x, y)`` validation batches.
        :return: average loss over the validation batches (0.0 if empty).
        """
        avg_loss = 0.
        self.model.eval()
        # Directory kept for parity with the original (results may be written
        # here by other tooling); creation is idempotent.
        save_dir = "validation_results"
        os.makedirs(save_dir, exist_ok=True)
        with torch.no_grad():
            for x, y in val_data_loader:
                if self.use_cuda:
                    x = x.to(device="cuda:0")
                    y = y.to(device="cuda:0")
                out = self.model(x)
                # BUGFIX: the original passed the unsqueezed `y` here, while
                # training used y.squeeze(1) — inconsistent target shapes make
                # val losses incomparable (and crash CrossEntropyLoss).
                loss = self.loss(out, y.squeeze(1))
                avg_loss += loss.item()
        # Guard against an empty loader to avoid ZeroDivisionError.
        if len(val_data_loader):
            avg_loss /= len(val_data_loader)
        return avg_loss