-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEmbryo-Classification-Code-with-Resnet18.py
148 lines (114 loc) · 4.75 KB
/
Embryo-Classification-Code-with-Resnet18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 11 00:29:35 2023
@author: btouati
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
import os
from PIL import Image
from PIL import ImageFile
import matplotlib.pyplot as plt
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Chemin vers le répertoire racine de votre jeu de données
data_dir = './data-set-embr/'
# Transformation pour redimensionner et normaliser les images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# Chargement des données d'entraînement, de validation et de test avec des transformations
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'valid'), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=transform)
# Création d'un dictionnaire pour mapper les classes et les catégories de qualité
# Création des DataLoader pour itérer sur les ensembles de données
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# Utilisation d'un modèle pré-entraîné ResNet18 (ou tout autre modèle souhaité)
model = models.resnet18(pretrained=True)
# Remplacement de la dernière couche pour la classification en fonction du nombre total de catégories de qualité
num_ftrs = model.fc.in_features
num_quality_categories = 5
model.fc = nn.Linear(num_ftrs, num_quality_categories)
print(model)
criterion = nn.CrossEntropyLoss() # Utilisation de la perte de l'entropie croisée pour la classification
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Entraînement du modèle
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 15
best_val_acc = 0.0
# Initialisation des listes pour les courbes
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
corrects = 0
total = 0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
corrects += (predicted == labels).sum().item()
train_loss = train_loss / len(train_loader.dataset)
train_acc = corrects / total
# Évaluation du modèle sur l'ensemble de validation
model.eval()
val_corrects = 0
val_loss = 0.0
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item() * inputs.size(0)
_, preds = torch.max(outputs, 1)
val_corrects += torch.sum(preds == labels.data)
val_loss = val_loss / len(val_loader.dataset)
val_acc = val_corrects.double() / len(val_loader.dataset)
# Ajout des données pour les courbes
train_losses.append(train_loss)
val_losses.append(val_loss)
train_accuracies.append(train_acc)
val_accuracies.append(val_acc)
print(f'Epoch [{epoch}/{num_epochs - 1}] '
f'Train Loss: {train_loss:.4f} '
f'Validation Loss: {val_loss:.4f} '
f'Train Acc: {train_acc:.4f} '
f'Validation Acc: {val_acc:.4f}')
# Sauvegarde du modèle si la précision de validation est la meilleure jusqu'à présent
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model_Resnet18.pth')
# Affichage des courbes de perte et d'exactitude
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(num_epochs), train_losses, label='Train')
plt.plot(range(num_epochs), val_losses, label='Validation')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(range(num_epochs), train_accuracies, label='Train Accuracy', color='blue')
plt.plot(range(num_epochs), val_accuracies, label='Validation Accuracy', color='green')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()