First working cartpole
JosephDenman committed Apr 1, 2023
1 parent 0825bd5 commit 2f52d72
Showing 6 changed files with 101 additions and 87 deletions.
15 changes: 9 additions & 6 deletions games/cartpole.py
@@ -46,7 +46,10 @@ def __init__(self):
self.pb_c_base = 19652
self.pb_c_init = 1.25


### Sampling
self.action_shape = [1, 2]
self.sample_size = 3
self.policy_distribution = torch.distributions.Categorical

### Network
self.network = "fullyconnected" # "resnet" / "fullyconnected"
@@ -66,10 +69,10 @@ def __init__(self):
# Fully Connected Network
self.encoding_size = 8
self.fc_representation_layers = [] # Define the hidden layers in the representation network
self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
self.fc_reward_layers = [16] # Define the hidden layers in the reward network
self.fc_value_layers = [16] # Define the hidden layers in the value network
self.fc_policy_layers = [16] # Define the hidden layers in the policy network
self.fc_dynamics_layers = [32] # Define the hidden layers in the dynamics network
self.fc_reward_layers = [32] # Define the hidden layers in the reward network
self.fc_value_layers = [32] # Define the hidden layers in the value network
self.fc_policy_layers = [128] # Define the hidden layers in the policy network



@@ -148,7 +151,7 @@ def step(self, action):
Returns:
The new observation, the reward, and a boolean indicating whether the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, done, _ = self.env.step(action[0])
return numpy.array([[observation]]), reward, done

def legal_actions(self):
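
For context, a minimal standalone sketch of how the new sampling fields interact (values copied from the config above; the dummy logits and shapes are illustrative assumptions, not part of this commit):

import torch

action_shape = [1, 2]   # one categorical dimension with two bins (CartPole: push left / push right)
sample_size = 3         # number of actions sampled from the policy
policy_distribution = torch.distributions.Categorical

# Dummy policy logits shaped [batch, *action_shape], as the prediction head would emit
policy_logits = torch.zeros([1] + action_shape)
dist = policy_distribution(logits=policy_logits)
samples = dist.sample((sample_size,)).squeeze(-1)
print(samples.shape)    # torch.Size([3, 1]) -> sample_size x batch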
42 changes: 26 additions & 16 deletions models.py
@@ -1,6 +1,7 @@
import math
from abc import ABC, abstractmethod

import numpy
import torch


@@ -10,7 +11,9 @@ def __new__(cls, config):
return MuZeroFullyConnectedNetwork(
config.observation_shape,
config.stacked_observations,
len(config.action_space),
config.action_shape,
config.sample_size,
config.policy_distribution,
config.encoding_size,
config.fc_reward_layers,
config.fc_value_layers,
@@ -82,7 +85,9 @@ def __init__(
self,
observation_shape,
stacked_observations,
action_space_size,
action_shape,
sample_size,
policy_distribution,
encoding_size,
fc_reward_layers,
fc_value_layers,
@@ -92,7 +97,10 @@ def __init__(
support_size,
):
super().__init__()
self.action_space_size = action_space_size
self.action_shape = action_shape
self.action_bins = numpy.prod(self.action_shape)
self.sample_size = sample_size
self.policy_distribution = policy_distribution
self.full_support_size = 2 * support_size + 1

self.representation_network = torch.nn.DataParallel(
@@ -106,10 +114,9 @@ def __init__(
encoding_size,
)
)

self.dynamics_encoded_state_network = torch.nn.DataParallel(
mlp(
encoding_size + self.action_space_size,
encoding_size + self.action_bins,
fc_dynamics_layers,
encoding_size,
)
@@ -119,14 +126,15 @@ def __init__(
)

self.prediction_policy_network = torch.nn.DataParallel(
mlp(encoding_size, fc_policy_layers, self.action_space_size)
mlp(encoding_size, fc_policy_layers, self.action_bins)
)
self.prediction_value_network = torch.nn.DataParallel(
mlp(encoding_size, fc_value_layers, self.full_support_size)
)

def prediction(self, encoded_state):
policy_logits = self.prediction_policy_network(encoded_state)
policy_logits = torch.reshape(policy_logits, [encoded_state.shape[0]] + self.action_shape)
value = self.prediction_value_network(encoded_state)
return policy_logits, value

@@ -146,18 +154,12 @@ def representation(self, observation):

def dynamics(self, encoded_state, action):
# Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
action_one_hot = torch.zeros((action.shape[0], *self.action_shape)).to(action.device).float()
action_one_hot = torch.scatter(action_one_hot, 2, action.long().unsqueeze(-1), 1.0)
action_one_hot = torch.reshape(action_one_hot, (encoded_state.shape[0], self.action_bins))
x = torch.cat((encoded_state, action_one_hot), dim=1)

next_encoded_state = self.dynamics_encoded_state_network(x)

reward = self.dynamics_reward_network(next_encoded_state)

# Scale encoded state between [0, 1] (See paper appendix Training)
min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0]
max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0]
@@ -166,9 +168,17 @@ def dynamics(self, encoded_state, action):
next_encoded_state_normalized = (
next_encoded_state - min_next_encoded_state
) / scale_next_encoded_state

return next_encoded_state_normalized, reward

def sample_actions(self, policy_parameters):
return self.policy_distribution(logits=policy_parameters).sample((self.sample_size,)).squeeze(-1)

def log_prob(self, policy_parameters_batch, sampled_actions_batch):
batch_size, *action_size = tuple(policy_parameters_batch.shape)
policy_parameters_batch = policy_parameters_batch.unsqueeze(-2).expand(
torch.Size((batch_size, self.sample_size, *action_size)))
return self.policy_distribution(logits=policy_parameters_batch).log_prob(sampled_actions_batch).squeeze(-1)

def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logits, value = self.prediction(encoded_state)
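
As a shape check on the new sample_actions and log_prob methods, here is a standalone sketch that mirrors the same tensor manipulations (assumed batch of 4, action_shape [1, 2], sample_size 3; the transpose/unsqueeze applied to the sampled actions before scoring is an illustrative assumption, not part of this commit):

import torch

batch, sample_size, action_shape = 4, 3, [1, 2]
policy_logits = torch.randn([batch] + action_shape)                # as returned by prediction()

# sample_actions: draw sample_size actions per batch element
dist = torch.distributions.Categorical(logits=policy_logits)       # batch_shape [4, 1]
sampled = dist.sample((sample_size,)).squeeze(-1)                   # shape [3, 4]

# log_prob: expand the logits so each sampled action is scored under the same policy
expanded = policy_logits.unsqueeze(-2).expand(batch, sample_size, *action_shape)   # [4, 3, 1, 2]
log_probs = (
    torch.distributions.Categorical(logits=expanded)
    .log_prob(sampled.T.unsqueeze(-1))                              # value broadcast to [4, 3, 1]
    .squeeze(-1)                                                    # -> [4, 3]
)
print(sampled.shape, log_probs.shape)                               # torch.Size([3, 4]) torch.Size([4, 3])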
33 changes: 24 additions & 9 deletions replay_buffer.py
@@ -75,16 +75,17 @@ def get_batch(self):
reward_batch,
value_batch,
policy_batch,
sampled_action_batch,
gradient_scale_batch,
) = ([], [], [], [], [], [], [])
) = ([], [], [], [], [], [], [], [])
weight_batch = [] if self.config.PER else None

for game_id, game_history, game_prob in self.sample_n_games(
self.config.batch_size
):
game_pos, pos_prob = self.sample_position(game_history)

values, rewards, policies, actions = self.make_target(
values, rewards, policies, sampled_actions, actions = self.make_target(
game_history, game_pos
)

@@ -93,13 +94,14 @@ def get_batch(self):
game_history.get_stacked_observations(
game_pos,
self.config.stacked_observations,
len(self.config.action_space),
self.config.sample_size
)
)
action_batch.append(actions)
value_batch.append(values)
reward_batch.append(rewards)
policy_batch.append(policies)
sampled_action_batch.append(sampled_actions)
gradient_scale_batch.append(
[
min(
@@ -132,6 +134,7 @@ def get_batch(self):
value_batch,
reward_batch,
policy_batch,
sampled_action_batch,
weight_batch,
gradient_scale_batch,
),
@@ -261,20 +264,28 @@ def compute_target_value(self, game_history, index):

return value

def pad_sampled_actions(self, sampled_actions):
return list(map(lambda inner_sampled_actions: (inner_sampled_actions + self.config.sample_size * [
inner_sampled_actions[0]])[:self.config.sample_size], sampled_actions))

def pad_target_policies(self, target_policies):
return list(map(lambda target_policy: (target_policy + self.config.sample_size * [0])[:self.config.sample_size],
target_policies))

def make_target(self, game_history, state_index):
"""
Generate targets for every unroll step.
"""
target_values, target_rewards, target_policies, actions = [], [], [], []
target_values, target_rewards, target_policies, sampled_actions, actions = [], [], [], [], []
for current_index in range(
state_index, state_index + self.config.num_unroll_steps + 1
):
value = self.compute_target_value(game_history, current_index)

if current_index < len(game_history.root_values):
target_values.append(value)
target_rewards.append(game_history.reward_history[current_index])
target_policies.append(game_history.child_visits[current_index])
sampled_actions.append(game_history.sampled_actions[current_index])
actions.append(game_history.action_history[current_index])
elif current_index == len(game_history.root_values):
target_values.append(0)
@@ -286,6 +297,7 @@ def make_target(self, game_history, state_index):
for _ in range(len(game_history.child_visits[0]))
]
)
sampled_actions.append(game_history.sampled_actions[0])
actions.append(game_history.action_history[current_index])
else:
# States past the end of games are treated as absorbing states
@@ -298,9 +310,12 @@ def make_target(self, game_history, state_index):
for _ in range(len(game_history.child_visits[0]))
]
)
actions.append(numpy.random.choice(self.config.action_space))

return target_values, target_rewards, target_policies, actions
sampled_actions.append(game_history.sampled_actions[0])
actions.append(
game_history.sampled_actions[0][numpy.random.choice(len(game_history.sampled_actions[0]))])
target_policies = self.pad_target_policies(target_policies)
sampled_actions = self.pad_sampled_actions(sampled_actions)
return target_values, target_rewards, target_policies, sampled_actions, actions


@ray.remote
@@ -347,7 +362,7 @@ def reanalyse(self, replay_buffer, shared_storage):
game_history.get_stacked_observations(
i,
self.config.stacked_observations,
len(self.config.action_space),
self.config.sample_size
)
for i in range(len(game_history.root_values))
]
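
The two new helpers pad (or truncate) each per-position list to exactly sample_size entries before the targets are batched. A standalone sketch of the same padding logic (hypothetical inputs, sample_size = 3; not part of this commit):

sample_size = 3

def pad_sampled_actions(sampled_actions):
    # Repeat the first sampled action until each inner list has sample_size entries, then truncate.
    return [(inner + sample_size * [inner[0]])[:sample_size] for inner in sampled_actions]

def pad_target_policies(target_policies):
    # Pad short visit-count policies with zeros, then truncate to sample_size.
    return [(policy + sample_size * [0])[:sample_size] for policy in target_policies]

print(pad_sampled_actions([[0], [1, 0]]))   # [[0, 0, 0], [1, 0, 1]]
print(pad_target_policies([[0.6, 0.4]]))    # [[0.6, 0.4, 0]]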
Binary file removed results/cartpole/model.checkpoint
