forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.js
158 lines (141 loc) · 5.56 KB
/
agent.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
import {createDeepQNetwork} from './dqn';
import {getRandomAction, SnakeGame, NUM_ACTIONS, ALL_ACTIONS, getStateTensor} from './snake_game';
import {ReplayMemory} from './replay_memory';
import { assertPositiveInteger } from './utils';
export class SnakeGameAgent {
  /**
   * Constructor of SnakeGameAgent.
   *
   * @param {SnakeGame} game A game object.
   * @param {object} config The configuration object with the following keys:
   *   - `replayBufferSize` {number} Size of the replay memory. Must be a
   *     positive integer.
   *   - `epsilonInit` {number} Initial value of epsilon (for the epsilon-
   *     greedy algorithm). Must be >= 0 and <= 1.
   *   - `epsilonFinal` {number} The final value of epsilon. Must be >= 0 and
   *     <= 1.
   *   - `epsilonDecayFrames` {number} The # of frames over which the value of
   *     `epsilon` decreases from `epsilonInit` to `epsilonFinal`, via a linear
   *     schedule. Checked with `assertPositiveInteger`.
   *   - `learningRate` {number} Learning rate passed to `tf.train.adam` to
   *     create `this.optimizer`, the default optimizer used by
   *     `trainOnReplayBatch()`.
   */
  constructor(game, config) {
    assertPositiveInteger(config.epsilonDecayFrames);

    this.game = game;

    this.epsilonInit = config.epsilonInit;
    this.epsilonFinal = config.epsilonFinal;
    this.epsilonDecayFrames = config.epsilonDecayFrames;
    this.epsilonIncrement_ = (this.epsilonFinal - this.epsilonInit) /
        this.epsilonDecayFrames;
    // Give `epsilon` a meaningful value before the first `playStep()` call,
    // so code that reads it early does not see `undefined`.
    this.epsilon = this.epsilonInit;

    this.onlineNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    this.targetNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    // Freeze the target network: its weights are not trained directly and are
    // meant to be updated only by copying from the online network.
    this.targetNetwork.trainable = false;

    this.optimizer = tf.train.adam(config.learningRate);

    this.replayBufferSize = config.replayBufferSize;
    this.replayMemory = new ReplayMemory(config.replayBufferSize);
    this.frameCount = 0;
    this.reset();
  }

  /**
   * Reset the per-episode statistics (cumulative reward and fruits eaten) and
   * reset the underlying game. `frameCount` is not reset, so the epsilon
   * schedule keeps progressing across episodes.
   */
  reset() {
    this.cumulativeReward_ = 0;
    this.fruitsEaten_ = 0;
    this.game.reset();
  }

  /**
   * Play one step of the game with the epsilon-greedy policy, append the
   * resulting transition to the replay memory, and reset the agent if the
   * episode ended.
   *
   * @returns {{action: *, cumulativeReward: number, done: boolean,
   *     fruitsEaten: number}} The action taken in this step, the cumulative
   *     reward and the number of fruits eaten so far in the current episode
   *     (including this step), and whether this step ended the episode. When
   *     `done` is true the agent has already been reset for the next episode.
   */
  playStep() {
    // Linearly anneal epsilon from `epsilonInit` to `epsilonFinal` over the
    // first `epsilonDecayFrames` frames, then hold it at `epsilonFinal`.
    this.epsilon = this.frameCount >= this.epsilonDecayFrames ?
        this.epsilonFinal :
        this.epsilonInit + this.epsilonIncrement_ * this.frameCount;
    this.frameCount++;

    // The epsilon-greedy algorithm.
    let action;
    const state = this.game.getState();
    if (Math.random() < this.epsilon) {
      // Explore: pick an action at random.
      action = getRandomAction();
    } else {
      // Exploit: pick the action whose predicted Q-value from the online DQN
      // is highest. `tf.tidy` disposes the intermediate tensors.
      tf.tidy(() => {
        const stateTensor =
            getStateTensor(state, this.game.height, this.game.width);
        action = ALL_ACTIONS[
            this.onlineNetwork.predict(stateTensor).argMax(-1).dataSync()[0]];
      });
    }

    const {state: nextState, reward, done, fruitEaten} = this.game.step(action);

    // Transition layout is [state, action, reward, done, nextState];
    // `trainOnReplayBatch()` indexes into it by position.
    this.replayMemory.append([state, action, reward, done, nextState]);

    this.cumulativeReward_ += reward;
    if (fruitEaten) {
      this.fruitsEaten_++;
    }
    // Snapshot the episode statistics before a possible reset clears them.
    const output = {
      action,
      cumulativeReward: this.cumulativeReward_,
      done,
      fruitsEaten: this.fruitsEaten_,
    };
    if (done) {
      this.reset();
    }
    return output;
  }

  /**
   * Perform one training step on a batch sampled from the replay buffer.
   *
   * The online network's Q-value for each taken action is regressed (mean
   * squared error) onto `reward + gamma * max(targetNetwork(nextState))`,
   * with the bootstrap term zeroed out for transitions where `done` is true.
   *
   * @param {number} batchSize Batch size.
   * @param {number} gamma Reward discount rate. Must be >= 0 and <= 1.
   * @param {tf.train.Optimizer} [optimizer=this.optimizer] The optimizer used
   *     to update the weights of the online network. Defaults to the Adam
   *     optimizer created in the constructor.
   */
  trainOnReplayBatch(batchSize, gamma, optimizer = this.optimizer) {
    // Get a batch of examples from the replay buffer.
    const batch = this.replayMemory.sample(batchSize);
    const lossFunction = () => tf.tidy(() => {
      const stateTensor = getStateTensor(
          batch.map(example => example[0]), this.game.height, this.game.width);
      const actionTensor = tf.tensor1d(
          batch.map(example => example[1]), 'int32');
      // Select the online network's Q-value for the action actually taken.
      const qs = this.onlineNetwork.apply(stateTensor, {training: true})
          .mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);

      const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
      const nextStateTensor = getStateTensor(
          batch.map(example => example[4]), this.game.height, this.game.width);
      const nextMaxQTensor =
          this.targetNetwork.predict(nextStateTensor).max(-1);
      // 1 for non-terminal transitions, 0 for terminal ones.
      const doneMask = tf.scalar(1).sub(
          tf.tensor1d(batch.map(example => example[3])).asType('float32'));
      const targetQs =
          rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma));
      return tf.losses.meanSquaredError(targetQs, qs);
    });

    // Gradients are computed with respect to the trainable variables only,
    // so the frozen target network is not updated here.
    const grads = tf.variableGrads(lossFunction);
    optimizer.applyGradients(grads.grads);
    tf.dispose(grads);
    // TODO(cais): Return the loss value here?
  }
}