-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
95 lines (75 loc) · 2.82 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Functions you need to edit in this script -
- train_naive()
- train_dagger()
- main()
Feel free to define more functions if required.
Usage: train.py [-h] [--mode {naive,dagger}]
optional arguments:
-h, --help show this help message and exit
--mode {naive,dagger}, -m {naive,dagger}
Sets the training mode. Default : naive
"""
import argparse
# Command-line interface: a single optional flag selecting the training pipeline.
parser = argparse.ArgumentParser()
# Only 'naive' and 'dagger' are accepted (enforced through `choices`); 'naive' is the default.
parser.add_argument("--mode", '-m', choices = ['naive', 'dagger'], default='naive', help = "Sets the training mode. Default : naive")
# NOTE: this runs at module level, so the command line is parsed as soon as the file is executed or imported.
args = parser.parse_args()
# Module-level constant that main() compares against 'naive' to pick the training branch.
MODE = args.mode
import numpy as np
import os
import pickle
import gzip
import matplotlib.pyplot as plt
# Load Dependencies #
...
#####################
from model import MODEL_NAME # Change this to whatever name you've given your model class
from model import preprocessing
def load_data(directory = "./data", val_split = 0.1):
    """
    Loads the data saved after expert runs and splits it into training and validation sets.

    The file `<directory>/data.pkl.gzip` is expected to hold a pickled mapping with the
    keys "state" and "action". Both are converted to float32 numpy arrays. The split is
    positional (no shuffling): the last `int(val_split * num_samples)` samples form the
    validation set and everything before them forms the training set.

    Input : directory where data.pkl.gzip is located, val_split (fraction in [0, 1])
    Output : X_train, Y_train, X_val, Y_val (training and validations sets with split determined by `val_split`)
    Raises : ValueError if `val_split` is outside [0, 1].
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

    data_file = os.path.join(directory, 'data.pkl.gzip')
    # The context manager guarantees the gzip handle is closed, even if unpickling fails.
    # NOTE: pickle can execute arbitrary code while loading; only open trusted data files.
    with gzip.open(data_file, 'rb') as file:
        data = pickle.load(file)

    X = np.array(data["state"]).astype('float32')
    y = np.array(data["action"]).astype('float32')

    # split data into training and validation sets
    num_samples = len(data["state"])
    val_len = int(val_split * num_samples)
    # Slice at an explicit non-negative index instead of `-val_len`: when val_len is 0,
    # `X[:-0]` is empty and `X[-0:]` is the whole array, which would swap the two sets.
    split_idx = num_samples - val_len
    X_train, y_train = X[:split_idx], y[:split_idx]
    X_val, y_val = X[split_idx:], y[split_idx:]
    return X_train, y_train, X_val, y_val
def train_naive():  # add arguments as needed
    """
    Placeholder for the naive behavioural cloning training pipeline.

    To be implemented here:
      * the training pipeline itself;
      * saving the trained model under "agents/naive";
      * returning the history of the training and validation metrics.

    Replace the placeholder `return None` below once the pipeline is written.
    """
    return None
def train_dagger():  # add arguments as needed
    """
    Placeholder for the training pipeline used when the mode is 'dagger'.

    (The original text of this docstring said "naive behavioural cloning",
    apparently copied over from train_naive().)

    To be implemented here:
      * the training pipeline itself;
      * saving the trained model under "agents/dagger";
      * returning the history of the training and validation metrics.

    Replace the placeholder `return None` below once the pipeline is written.
    """
    return None
def main():
    """
    Entry point scaffold: load the data, then run the pipeline selected by MODE.

    The `...` assignments and the `pass` branches are placeholders to be filled in.
    """
    # Read the expert data and split it (feel free to experiment with val_split)
    train_obs, train_act, val_obs, val_act = load_data(
        directory="./data",
        val_split=0.1,
    )

    # Preprocess the observations (if no preprocessing is needed, delete the next two lines)
    train_obs = ...
    val_obs = ...

    # Create the model instance
    model = ...

    # Run the training pipeline chosen on the command line
    if MODE != 'naive':
        # Call train_dagger, save its results in results/dagger.
        pass
    else:
        # Call train_naive(), save its results in results/naive.
        pass
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()