-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
26 lines (22 loc) · 757 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import datetime
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import config
import image_train
import main
import utils.csv_record as csv_record
def train(helper, start_epoch, local_model, target_model, is_poison, agent_name_keys):
    """Dispatch training to the trainer for the configured dataset type.

    When ``helper.params["type"]`` is one of the image dataset types
    (``config.TYPE_CIFAR``, ``config.TYPE_MNIST``,
    ``config.TYPE_FASHION_MNIST``), all arguments are forwarded unchanged to
    ``image_train.ImageTrain``.

    Args:
        helper: Object exposing a ``params`` mapping that contains a
            ``"type"`` key.
        start_epoch: Forwarded to ``ImageTrain``.
        local_model: Forwarded to ``ImageTrain``.
        target_model: Forwarded to ``ImageTrain``.
        is_poison: Forwarded to ``ImageTrain``.
        agent_name_keys: Forwarded to ``ImageTrain``.

    Returns:
        A ``(epochs_submit_update_dict, num_samples_dict)`` tuple. For image
        dataset types this is the pair returned by ``ImageTrain``; for any
        other type both dicts are empty.
    """
    epochs_submit_update_dict = {}
    num_samples_dict = {}

    # Single lookup plus membership test instead of three repeated
    # `helper.params["type"] == ...` comparisons chained with `or`.
    image_types = (config.TYPE_CIFAR, config.TYPE_MNIST, config.TYPE_FASHION_MNIST)
    if helper.params["type"] in image_types:
        # Unpack into two names so a non-pair result still fails loudly here.
        epochs_submit_update_dict, num_samples_dict = image_train.ImageTrain(
            helper, start_epoch, local_model, target_model, is_poison, agent_name_keys
        )
    # NOTE(review): any other dataset type falls through silently and returns
    # two empty dicts. Confirm this is intended rather than raising.
    return epochs_submit_update_dict, num_samples_dict