[WIP] Sampled Muzero #216

Open · wants to merge 1 commit into base: master
15 changes: 9 additions & 6 deletions games/cartpole.py
@@ -46,7 +46,10 @@ def __init__(self):
self.pb_c_base = 19652
self.pb_c_init = 1.25


### Sampling
self.action_shape = [1, 2]
self.sample_size = 3
self.policy_distribution = torch.distributions.Categorical

### Network
self.network = "fullyconnected" # "resnet" / "fullyconnected"
@@ -66,10 +69,10 @@ def __init__(self):
# Fully Connected Network
self.encoding_size = 8
self.fc_representation_layers = [] # Define the hidden layers in the representation network
self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
self.fc_reward_layers = [16] # Define the hidden layers in the reward network
self.fc_value_layers = [16] # Define the hidden layers in the value network
self.fc_policy_layers = [16] # Define the hidden layers in the policy network
self.fc_dynamics_layers = [32] # Define the hidden layers in the dynamics network
self.fc_reward_layers = [32] # Define the hidden layers in the reward network
self.fc_value_layers = [32] # Define the hidden layers in the value network
self.fc_policy_layers = [128] # Define the hidden layers in the policy network



@@ -148,7 +151,7 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, done, _ = self.env.step(action[0])
return numpy.array([[observation]]), reward, done

def legal_actions(self):
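A quick sketch of how the new sampling settings above fit together, assuming the CartPole config from this diff (action_shape=[1, 2], sample_size=3, a Categorical policy head); the snippet is illustrative and not part of the PR:

import torch

action_shape = [1, 2]   # one categorical action dimension with 2 choices (CartPole: push left / push right)
sample_size = 3         # number of candidate actions sampled per state
policy_distribution = torch.distributions.Categorical

policy_logits = torch.zeros(action_shape)                    # shape [1, 2], uniform policy for the sketch
sampled = policy_distribution(logits=policy_logits).sample((sample_size,)).squeeze(-1)  # shape [3]: sampled action indices
action = [int(sampled[0])]                                   # actions become lists, hence env.step(action[0]) in the updated step()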
42 changes: 26 additions & 16 deletions models.py
@@ -1,6 +1,7 @@
import math
from abc import ABC, abstractmethod

import numpy
import torch


@@ -10,7 +11,9 @@ def __new__(cls, config):
return MuZeroFullyConnectedNetwork(
config.observation_shape,
config.stacked_observations,
len(config.action_space),
config.action_shape,
config.sample_size,
config.policy_distribution,
config.encoding_size,
config.fc_reward_layers,
config.fc_value_layers,
@@ -82,7 +85,9 @@ def __init__(
self,
observation_shape,
stacked_observations,
action_space_size,
action_shape,
sample_size,
policy_distribution,
encoding_size,
fc_reward_layers,
fc_value_layers,
@@ -92,7 +97,10 @@ def __init__(
support_size,
):
super().__init__()
self.action_space_size = action_space_size
self.action_shape = action_shape
self.action_bins = numpy.prod(self.action_shape)
self.sample_size = sample_size
self.policy_distribution = policy_distribution
self.full_support_size = 2 * support_size + 1

self.representation_network = torch.nn.DataParallel(
@@ -106,10 +114,9 @@ def __init__(
encoding_size,
)
)

self.dynamics_encoded_state_network = torch.nn.DataParallel(
mlp(
encoding_size + self.action_space_size,
encoding_size + self.action_bins,
fc_dynamics_layers,
encoding_size,
)
@@ -119,14 +126,15 @@
)

self.prediction_policy_network = torch.nn.DataParallel(
mlp(encoding_size, fc_policy_layers, self.action_space_size)
mlp(encoding_size, fc_policy_layers, self.action_bins)
)
self.prediction_value_network = torch.nn.DataParallel(
mlp(encoding_size, fc_value_layers, self.full_support_size)
)

def prediction(self, encoded_state):
policy_logits = self.prediction_policy_network(encoded_state)
policy_logits = torch.reshape(policy_logits, [encoded_state.shape[0]] + self.action_shape)
value = self.prediction_value_network(encoded_state)
return policy_logits, value

@@ -146,18 +154,12 @@ def representation(self, observation):

def dynamics(self, encoded_state, action):
# Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
action_one_hot = (
torch.zeros((action.shape[0], self.action_space_size))
.to(action.device)
.float()
)
action_one_hot.scatter_(1, action.long(), 1.0)
action_one_hot = torch.zeros((action.shape[0], *self.action_shape)).to(action.device).float()
action_one_hot = torch.scatter(action_one_hot, 2, action.long().unsqueeze(-1), 1.0)
action_one_hot = torch.reshape(action_one_hot, (encoded_state.shape[0], self.action_bins))
x = torch.cat((encoded_state, action_one_hot), dim=1)

next_encoded_state = self.dynamics_encoded_state_network(x)

reward = self.dynamics_reward_network(next_encoded_state)

# Scale encoded state between [0, 1] (See paper appendix Training)
min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0]
max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0]
@@ -166,9 +168,17 @@ def dynamics(self, encoded_state, action):
next_encoded_state_normalized = (
next_encoded_state - min_next_encoded_state
) / scale_next_encoded_state

return next_encoded_state_normalized, reward

def sample_actions(self, policy_parameters):
return self.policy_distribution(logits=policy_parameters).sample((self.sample_size,)).squeeze(-1)

def log_prob(self, policy_parameters_batch, sampled_actions_batch):
batch_size, *action_size = tuple(policy_parameters_batch.shape)
policy_parameters_batch = policy_parameters_batch.unsqueeze(-2).expand(
torch.Size((batch_size, self.sample_size, *action_size)))
return self.policy_distribution(logits=policy_parameters_batch).log_prob(sampled_actions_batch).squeeze(-1)

def initial_inference(self, observation):
encoded_state = self.representation(observation)
policy_logits, value = self.prediction(encoded_state)
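To make the new prediction/sampling path in models.py easier to follow, here is a standalone shape walk-through using the CartPole settings from this PR (action_shape=[1, 2], sample_size=3); the tensors are made up and the code mirrors the method bodies rather than calling the network, and the batching of the stored sampled actions is an assumption:

import torch

batch_size, sample_size, action_shape = 4, 3, [1, 2]
dist_cls = torch.distributions.Categorical

# prediction() now reshapes the policy head output to [batch_size, *action_shape]
policy_logits = torch.randn(batch_size, *action_shape)                        # [4, 1, 2]

# sample_actions(): sample_size draws for each batch element
sampled = dist_cls(logits=policy_logits).sample((sample_size,)).squeeze(-1)
assert sampled.shape == (sample_size, batch_size)                             # [3, 4]

# log_prob(): expand the parameters so each of the sample_size stored actions gets its own copy
sampled_actions_batch = torch.randint(0, 2, (batch_size, sample_size, 1))     # [4, 3, 1]; assumed layout of the batched sampled actions
expanded = policy_logits.unsqueeze(-2).expand(batch_size, sample_size, *action_shape)  # [4, 3, 1, 2]
log_probs = dist_cls(logits=expanded).log_prob(sampled_actions_batch).squeeze(-1)
assert log_probs.shape == (batch_size, sample_size)                           # [4, 3]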
33 changes: 24 additions & 9 deletions replay_buffer.py
@@ -75,16 +75,17 @@ def get_batch(self):
reward_batch,
value_batch,
policy_batch,
sampled_action_batch,
gradient_scale_batch,
) = ([], [], [], [], [], [], [])
) = ([], [], [], [], [], [], [], [])
weight_batch = [] if self.config.PER else None

for game_id, game_history, game_prob in self.sample_n_games(
self.config.batch_size
):
game_pos, pos_prob = self.sample_position(game_history)

values, rewards, policies, actions = self.make_target(
values, rewards, policies, sampled_actions, actions = self.make_target(
game_history, game_pos
)

@@ -93,13 +94,14 @@
game_history.get_stacked_observations(
game_pos,
self.config.stacked_observations,
len(self.config.action_space),
self.config.sample_size
)
)
action_batch.append(actions)
value_batch.append(values)
reward_batch.append(rewards)
policy_batch.append(policies)
sampled_action_batch.append(sampled_actions)
gradient_scale_batch.append(
[
min(
@@ -132,6 +134,7 @@ def get_batch(self):
value_batch,
reward_batch,
policy_batch,
sampled_action_batch,
weight_batch,
gradient_scale_batch,
),
@@ -261,20 +264,28 @@ def compute_target_value(self, game_history, index):

return value

def pad_sampled_actions(self, sampled_actions):
return list(map(lambda inner_sampled_actions: (inner_sampled_actions + self.config.sample_size * [
inner_sampled_actions[0]])[:self.config.sample_size], sampled_actions))

def pad_target_policies(self, target_policies):
return list(map(lambda target_policy: (target_policy + self.config.sample_size * [0])[:self.config.sample_size],
target_policies))

def make_target(self, game_history, state_index):
"""
Generate targets for every unroll steps.
"""
target_values, target_rewards, target_policies, actions = [], [], [], []
target_values, target_rewards, target_policies, sampled_actions, actions = [], [], [], [], []
for current_index in range(
state_index, state_index + self.config.num_unroll_steps + 1
):
value = self.compute_target_value(game_history, current_index)

if current_index < len(game_history.root_values):
target_values.append(value)
target_rewards.append(game_history.reward_history[current_index])
target_policies.append(game_history.child_visits[current_index])
sampled_actions.append(game_history.sampled_actions[current_index])
actions.append(game_history.action_history[current_index])
elif current_index == len(game_history.root_values):
target_values.append(0)
@@ -286,6 +297,7 @@ def make_target(self, game_history, state_index):
for _ in range(len(game_history.child_visits[0]))
]
)
sampled_actions.append(game_history.sampled_actions[0])
actions.append(game_history.action_history[current_index])
else:
# States past the end of games are treated as absorbing states
@@ -298,9 +310,12 @@
for _ in range(len(game_history.child_visits[0]))
]
)
actions.append(numpy.random.choice(self.config.action_space))

return target_values, target_rewards, target_policies, actions
sampled_actions.append(game_history.sampled_actions[0])
actions.append(
game_history.sampled_actions[0][numpy.random.choice(len(game_history.sampled_actions[0]))])
target_policies = self.pad_target_policies(target_policies)
sampled_actions = self.pad_sampled_actions(sampled_actions)
return target_values, target_rewards, target_policies, sampled_actions, actions


@ray.remote
@@ -347,7 +362,7 @@ def reanalyse(self, replay_buffer, shared_storage):
game_history.get_stacked_observations(
i,
self.config.stacked_observations,
len(self.config.action_space),
self.config.sample_size
)
for i in range(len(game_history.root_values))
]
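For reference, a standalone sketch of what the two new padding helpers in replay_buffer.py do with sample_size=3 (written as free functions here; in the diff they are ReplayBuffer methods reading self.config.sample_size):

sample_size = 3

def pad_sampled_actions(sampled_actions):
    # Repeat the first sampled action until every inner list has exactly sample_size entries.
    return [(inner + sample_size * [inner[0]])[:sample_size] for inner in sampled_actions]

def pad_target_policies(target_policies):
    # Pad each visit-count policy with zeros and truncate, so all targets share one length.
    return [(policy + sample_size * [0])[:sample_size] for policy in target_policies]

print(pad_sampled_actions([[[0]], [[1], [0]]]))   # [[[0], [0], [0]], [[1], [0], [1]]]
print(pad_target_policies([[0.7, 0.3], [1.0]]))   # [[0.7, 0.3, 0], [1.0, 0, 0]]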
Binary file removed results/cartpole/model.checkpoint