forked from zawagner22/cross-entropy-for-combinatorics
-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
31 lines (25 loc) · 1008 Bytes
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import trange, tqdm
def train_network(model, optimizer, train_loader,
                  num_epochs=1, pbar_update_interval=200, print_logs=False):
    """Train a binary classifier in place using binary cross-entropy.

    Each batch yielded by ``train_loader`` is expected to be a 2-D tensor
    whose last column holds the target label and whose remaining columns
    are the model input.

    Args:
        model: Module producing probabilities in [0, 1] (as ``nn.BCELoss``
            requires). Its output must have shape ``(batch, 1)`` because it
            is compared against a ``(batch, 1)`` target.
        optimizer: Optimizer that steps ``model``'s parameters.
        train_loader: Iterable of batch tensors; it is iterated once per
            epoch.
        num_epochs: Number of passes over ``train_loader``.
        pbar_update_interval: When logging, refresh the progress-bar postfix
            on every batch whose index is a multiple of this value.
        print_logs: If true, wrap the epoch loop in a ``tqdm`` progress bar
            that shows the latest logged batch loss and accuracy.

    Returns:
        None. ``model``'s parameters are updated in place.
    """
    criterion = nn.BCELoss()
    pbar = trange(num_epochs) if print_logs else range(num_epochs)
    for _ in pbar:
        for k, batch_data in enumerate(train_loader):
            batch_x = batch_data[:, :-1]
            # Shape targets as (batch, 1) so they line up with the model output
            # for both the loss and the accuracy below.
            batch_y = batch_data[:, -1].unsqueeze(1)
            model.zero_grad()
            y_pred = model(batch_x)
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            if print_logs and k % pbar_update_interval == 0:
                # Both operands are (batch, 1). Comparing against flat (batch,)
                # labels would broadcast to (batch, batch) and give a
                # meaningless accuracy.
                with torch.no_grad():
                    acc = (y_pred.round() == batch_y).float().mean()
                pbar.set_postfix(loss=loss.item(), acc=acc.item())