diff --git a/games/cartpole.py b/games/cartpole.py
index fa1e8bbf..0fb7059f 100644
--- a/games/cartpole.py
+++ b/games/cartpole.py
@@ -46,7 +46,10 @@ def __init__(self):
         self.pb_c_base = 19652
         self.pb_c_init = 1.25
 
-
+        ### Sampling
+        self.action_shape = [1, 2]
+        self.sample_size = 3
+        self.policy_distribution = torch.distributions.Categorical
 
         ### Network
         self.network = "fullyconnected"  # "resnet" / "fullyconnected"
@@ -66,10 +69,10 @@ def __init__(self):
         # Fully Connected Network
         self.encoding_size = 8
         self.fc_representation_layers = []  # Define the hidden layers in the representation network
-        self.fc_dynamics_layers = [16]  # Define the hidden layers in the dynamics network
-        self.fc_reward_layers = [16]  # Define the hidden layers in the reward network
-        self.fc_value_layers = [16]  # Define the hidden layers in the value network
-        self.fc_policy_layers = [16]  # Define the hidden layers in the policy network
+        self.fc_dynamics_layers = [32]  # Define the hidden layers in the dynamics network
+        self.fc_reward_layers = [32]  # Define the hidden layers in the reward network
+        self.fc_value_layers = [32]  # Define the hidden layers in the value network
+        self.fc_policy_layers = [128]  # Define the hidden layers in the policy network
 
 
 
@@ -148,7 +151,7 @@ def step(self, action):
         Returns:
             The new observation, the reward and a boolean if the game has ended.
         """
-        observation, reward, done, _ = self.env.step(action)
+        observation, reward, done, _ = self.env.step(action[0])
         return numpy.array([[observation]]), reward, done
 
     def legal_actions(self):
diff --git a/models.py b/models.py
index be847fef..79ccf478 100644
--- a/models.py
+++ b/models.py
@@ -1,6 +1,7 @@
 import math
 from abc import ABC, abstractmethod
 
+import numpy
 import torch
 
 
@@ -10,7 +11,9 @@ def __new__(cls, config):
             return MuZeroFullyConnectedNetwork(
                 config.observation_shape,
                 config.stacked_observations,
-                len(config.action_space),
+                config.action_shape,
+                config.sample_size,
+                config.policy_distribution,
                 config.encoding_size,
                 config.fc_reward_layers,
                 config.fc_value_layers,
@@ -82,7 +85,9 @@ def __init__(
         self,
         observation_shape,
         stacked_observations,
-        action_space_size,
+        action_shape,
+        sample_size,
+        policy_distribution,
         encoding_size,
         fc_reward_layers,
         fc_value_layers,
@@ -92,7 +97,10 @@ def __init__(
         support_size,
     ):
         super().__init__()
-        self.action_space_size = action_space_size
+        self.action_shape = action_shape
+        self.action_bins = numpy.prod(self.action_shape)
+        self.sample_size = sample_size
+        self.policy_distribution = policy_distribution
         self.full_support_size = 2 * support_size + 1
 
         self.representation_network = torch.nn.DataParallel(
@@ -106,10 +114,9 @@ def __init__(
                 encoding_size,
             )
         )
-
         self.dynamics_encoded_state_network = torch.nn.DataParallel(
             mlp(
-                encoding_size + self.action_space_size,
+                encoding_size + self.action_bins,
                 fc_dynamics_layers,
                 encoding_size,
             )
@@ -119,7 +126,7 @@
         )
 
         self.prediction_policy_network = torch.nn.DataParallel(
-            mlp(encoding_size, fc_policy_layers, self.action_space_size)
+            mlp(encoding_size, fc_policy_layers, self.action_bins)
         )
         self.prediction_value_network = torch.nn.DataParallel(
             mlp(encoding_size, fc_value_layers, self.full_support_size)
@@ -127,6 +134,7 @@ def prediction(self, encoded_state):
         policy_logits = self.prediction_policy_network(encoded_state)
+        policy_logits = torch.reshape(policy_logits, [encoded_state.shape[0]] + self.action_shape)
         value = self.prediction_value_network(encoded_state)
         return policy_logits, value
 
@@ -146,18 +154,12 @@ def representation(self, observation):
 
     def dynamics(self, encoded_state, action):
         # Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
-        action_one_hot = (
-            torch.zeros((action.shape[0], self.action_space_size))
-            .to(action.device)
-            .float()
-        )
-        action_one_hot.scatter_(1, action.long(), 1.0)
+        action_one_hot = torch.zeros((action.shape[0], *self.action_shape)).to(action.device).float()
+        action_one_hot = torch.scatter(action_one_hot, 2, action.long().unsqueeze(-1), 1.0)
+        action_one_hot = torch.reshape(action_one_hot, (encoded_state.shape[0], self.action_bins))
         x = torch.cat((encoded_state, action_one_hot), dim=1)
-
         next_encoded_state = self.dynamics_encoded_state_network(x)
-
         reward = self.dynamics_reward_network(next_encoded_state)
-
         # Scale encoded state between [0, 1] (See paper appendix Training)
         min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0]
         max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0]
@@ -166,9 +168,17 @@ def dynamics(self, encoded_state, action):
         next_encoded_state_normalized = (
             next_encoded_state - min_next_encoded_state
         ) / scale_next_encoded_state
-
         return next_encoded_state_normalized, reward
 
+    def sample_actions(self, policy_parameters):
+        return self.policy_distribution(logits=policy_parameters).sample((self.sample_size,)).squeeze(-1)
+
+    def log_prob(self, policy_parameters_batch, sampled_actions_batch):
+        batch_size, *action_size = tuple(policy_parameters_batch.shape)
+        policy_parameters_batch = policy_parameters_batch.unsqueeze(-2).expand(
+            torch.Size((batch_size, self.sample_size, *action_size)))
+        return self.policy_distribution(logits=policy_parameters_batch).log_prob(sampled_actions_batch).squeeze(-1)
+
     def initial_inference(self, observation):
         encoded_state = self.representation(observation)
         policy_logits, value = self.prediction(encoded_state)
diff --git a/replay_buffer.py b/replay_buffer.py
index 81bc813e..7c802f27 100644
--- a/replay_buffer.py
+++ b/replay_buffer.py
@@ -75,8 +75,9 @@ def get_batch(self):
             reward_batch,
             value_batch,
             policy_batch,
+            sampled_action_batch,
             gradient_scale_batch,
-        ) = ([], [], [], [], [], [], [])
+        ) = ([], [], [], [], [], [], [], [])
         weight_batch = [] if self.config.PER else None
 
         for game_id, game_history, game_prob in self.sample_n_games(
@@ -84,7 +85,7 @@ def get_batch(self):
         ):
             game_pos, pos_prob = self.sample_position(game_history)
 
-            values, rewards, policies, actions = self.make_target(
+            values, rewards, policies, sampled_actions, actions = self.make_target(
                 game_history, game_pos
             )
 
@@ -93,13 +94,14 @@ def get_batch(self):
                 game_history.get_stacked_observations(
                     game_pos,
                     self.config.stacked_observations,
-                    len(self.config.action_space),
+                    self.config.sample_size
                 )
             )
             action_batch.append(actions)
             value_batch.append(values)
             reward_batch.append(rewards)
             policy_batch.append(policies)
+            sampled_action_batch.append(sampled_actions)
             gradient_scale_batch.append(
                 [
                     min(
@@ -132,6 +134,7 @@ def get_batch(self):
                 value_batch,
                 reward_batch,
                 policy_batch,
+                sampled_action_batch,
                 weight_batch,
                 gradient_scale_batch,
             ),
@@ -261,20 +264,28 @@ def compute_target_value(self, game_history, index):
 
         return value
 
+    def pad_sampled_actions(self, sampled_actions):
+        return list(map(lambda inner_sampled_actions: (inner_sampled_actions + self.config.sample_size * [
+            inner_sampled_actions[0]])[:self.config.sample_size], sampled_actions))
+
+    def pad_target_policies(self, target_policies):
+        return list(map(lambda target_policy: (target_policy + self.config.sample_size * [0])[:self.config.sample_size],
+                        target_policies))
+
     def make_target(self, game_history, state_index):
         """
        Generate targets for every unroll steps.
         """
-        target_values, target_rewards, target_policies, actions = [], [], [], []
+        target_values, target_rewards, target_policies, sampled_actions, actions = [], [], [], [], []
         for current_index in range(
             state_index, state_index + self.config.num_unroll_steps + 1
         ):
             value = self.compute_target_value(game_history, current_index)
-
             if current_index < len(game_history.root_values):
                 target_values.append(value)
                 target_rewards.append(game_history.reward_history[current_index])
                 target_policies.append(game_history.child_visits[current_index])
+                sampled_actions.append(game_history.sampled_actions[current_index])
                 actions.append(game_history.action_history[current_index])
             elif current_index == len(game_history.root_values):
                 target_values.append(0)
@@ -286,6 +297,7 @@ def make_target(self, game_history, state_index):
                         for _ in range(len(game_history.child_visits[0]))
                     ]
                 )
+                sampled_actions.append(game_history.sampled_actions[0])
                 actions.append(game_history.action_history[current_index])
             else:
                 # States past the end of games are treated as absorbing states
@@ -298,9 +310,12 @@ def make_target(self, game_history, state_index):
                         for _ in range(len(game_history.child_visits[0]))
                     ]
                 )
-                actions.append(numpy.random.choice(self.config.action_space))
-
-        return target_values, target_rewards, target_policies, actions
+                sampled_actions.append(game_history.sampled_actions[0])
+                actions.append(
+                    game_history.sampled_actions[0][numpy.random.choice(len(game_history.sampled_actions[0]))])
+        target_policies = self.pad_target_policies(target_policies)
+        sampled_actions = self.pad_sampled_actions(sampled_actions)
+        return target_values, target_rewards, target_policies, sampled_actions, actions
 
 
 @ray.remote
@@ -347,7 +362,7 @@ def reanalyse(self, replay_buffer, shared_storage):
                     game_history.get_stacked_observations(
                         i,
                         self.config.stacked_observations,
-                        len(self.config.action_space),
+                        self.config.sample_size
                     )
                     for i in range(len(game_history.root_values))
                 ]
diff --git a/results/cartpole/model.checkpoint b/results/cartpole/model.checkpoint
deleted file mode 100644
index 35f4b41c..00000000
Binary files a/results/cartpole/model.checkpoint and /dev/null differ
diff --git a/self_play.py b/self_play.py
index d90fe5db..aab38528 100644
--- a/self_play.py
+++ b/self_play.py
@@ -115,7 +115,7 @@ def play_game(
         """
         game_history = GameHistory()
         observation = self.game.reset()
-        game_history.action_history.append(0)
+        game_history.action_history.append([0] * self.config.action_shape[0])
        game_history.observation_history.append(observation)
         game_history.reward_history.append(0)
         game_history.to_play_history.append(self.game.to_play())
@@ -136,7 +136,7 @@ def play_game(
                     numpy.array(observation).shape == self.config.observation_shape
                 ), f"Observation should match the observation_shape defined in MuZeroConfig. Expected {self.config.observation_shape} but got {numpy.array(observation).shape}."
                 stacked_observations = game_history.get_stacked_observations(
-                    -1, self.config.stacked_observations, len(self.config.action_space)
+                    -1, self.config.stacked_observations, self.config.sample_size
                 )
 
                 # Choose the action
@@ -172,7 +172,7 @@ def play_game(
                     print(f"Played action: {self.game.action_to_string(action)}")
                     self.game.render()
 
-                game_history.store_search_statistics(root, self.config.action_space)
+                game_history.store_search_statistics(root)
 
                 # Next batch
                 game_history.action_history.append(action)
@@ -229,19 +229,18 @@ def select_action(node, temperature):
         visit_counts = numpy.array(
             [child.visit_count for child in node.children.values()], dtype="int32"
         )
-        actions = [action for action in node.children.keys()]
+        actions = [list(action) for action in node.children.keys()]
         if temperature == 0:
             action = actions[numpy.argmax(visit_counts)]
         elif temperature == float("inf"):
-            action = numpy.random.choice(actions)
+            action = actions[numpy.random.choice(len(actions))]
         else:
             # See paper appendix Data Generation
             visit_count_distribution = visit_counts ** (1 / temperature)
             visit_count_distribution = visit_count_distribution / sum(
                 visit_count_distribution
             )
-            action = numpy.random.choice(actions, p=visit_count_distribution)
-
+            action = actions[numpy.random.choice(len(actions), p=visit_count_distribution)]
         return action
 
 
@@ -289,6 +288,7 @@ def run(
                 policy_logits,
                 hidden_state,
             ) = model.initial_inference(observation)
+            sampled_actions = model.sample_actions(policy_logits)
             root_predicted_value = models.support_to_scalar(
                 root_predicted_value, self.config.support_size
             ).item()
@@ -296,14 +296,10 @@ def run(
             assert (
                 legal_actions
             ), f"Legal actions should not be an empty array. Got {legal_actions}."
-            assert set(legal_actions).issubset(
-                set(self.config.action_space)
-            ), "Legal actions should be a subset of the action space."
             root.expand(
-                legal_actions,
+                sampled_actions,
                 to_play,
                 reward,
-                policy_logits,
                 hidden_state,
             )
 
@@ -338,15 +334,15 @@ def run(
             parent = search_path[-2]
             value, reward, policy_logits, hidden_state = model.recurrent_inference(
                 parent.hidden_state,
-                torch.tensor([[action]]).to(parent.hidden_state.device),
+                torch.tensor([action]).to(parent.hidden_state.device),
             )
+            sampled_actions = model.sample_actions(policy_logits)
             value = models.support_to_scalar(value, self.config.support_size).item()
             reward = models.support_to_scalar(reward, self.config.support_size).item()
             node.expand(
-                self.config.action_space,
+                sampled_actions,
                 virtual_to_play,
                 reward,
-                policy_logits,
                 hidden_state,
             )
 
@@ -368,13 +364,9 @@ def select_child(self, node, min_max_stats):
             self.ucb_score(node, child, min_max_stats)
             for action, child in node.children.items()
         )
-        action = numpy.random.choice(
-            [
-                action
-                for action, child in node.children.items()
-                if self.ucb_score(node, child, min_max_stats) == max_ucb
-            ]
-        )
+        actions = [action for action, child in node.children.items() if
+                   self.ucb_score(node, child, min_max_stats) == max_ucb]
+        action = actions[numpy.random.choice(len(actions))]
         return action, node.children[action]
 
     def ucb_score(self, parent, child, min_max_stats):
@@ -437,6 +429,7 @@ def __init__(self, prior):
         self.prior = prior
         self.value_sum = 0
         self.children = {}
+        self.sampled_actions = []
         self.hidden_state = None
         self.reward = 0
 
@@ -448,7 +441,7 @@ def value(self):
             return 0
         return self.value_sum / self.visit_count
 
-    def expand(self, actions, to_play, reward, policy_logits, hidden_state):
+    def expand(self, sampled_actions, to_play, reward, hidden_state):
         """
         We expand a node using the value, reward and policy prediction
         obtained from the neural network.
@@ -456,13 +449,11 @@ def expand(self, actions, to_play, reward, policy_logits, hidden_state):
         self.to_play = to_play
         self.reward = reward
         self.hidden_state = hidden_state
-
-        policy_values = torch.softmax(
-            torch.tensor([policy_logits[0][a] for a in actions]), dim=0
-        ).tolist()
-        policy = {a: policy_values[i] for i, a in enumerate(actions)}
-        for action, p in policy.items():
-            self.children[action] = Node(p)
+        uniques, counts = torch.unique(sampled_actions, dim=0, return_counts=True, sorted=True)
+        self.sampled_actions = uniques.tolist()
+        empirical_probabilities = counts / counts.sum()
+        for action, p in zip(self.sampled_actions, empirical_probabilities):
+            self.children[tuple(action)] = Node(p.item())
 
     def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
         """
@@ -487,31 +478,25 @@ def __init__(self):
         self.reward_history = []
         self.to_play_history = []
         self.child_visits = []
+        self.sampled_actions = []
         self.root_values = []
         self.reanalysed_predicted_root_values = None
         # For PER
         self.priorities = None
         self.game_priority = None
 
-    def store_search_statistics(self, root, action_space):
+    def store_search_statistics(self, root):
         # Turn visit count from root into a policy
         if root is not None:
             sum_visits = sum(child.visit_count for child in root.children.values())
-            self.child_visits.append(
-                [
-                    root.children[a].visit_count / sum_visits
-                    if a in root.children
-                    else 0
-                    for a in action_space
-                ]
-            )
-
+            self.child_visits.append([root.children[a].visit_count / sum_visits for a in root.children])
+            self.sampled_actions.append([list(a) for a in root.children])
             self.root_values.append(root.value())
         else:
             self.root_values.append(None)
 
     def get_stacked_observations(
-        self, index, num_stacked_observations, action_space_size
+        self, index, num_stacked_observations, sample_size
     ):
         """
         Generate a new observation with the observation at the index position
@@ -531,7 +516,7 @@ def get_stacked_observations(
                         [
                             numpy.ones_like(stacked_observations[0])
                             * self.action_history[past_observation_index + 1]
-                            / action_space_size
+                            / sample_size
                         ],
                     )
                 )
diff --git a/trainer.py b/trainer.py
index faa5f941..b7a962d6 100644
--- a/trainer.py
+++ b/trainer.py
@@ -1,4 +1,5 @@
 import copy
+import math
 import time
 
 import numpy
@@ -132,6 +133,7 @@ def update_weights(self, batch):
             target_value,
             target_reward,
             target_policy,
+            sampled_actions,
             weight_batch,
             gradient_scale_batch,
         ) = batch
@@ -146,10 +148,11 @@ def update_weights(self, batch):
         observation_batch = (
             torch.tensor(numpy.array(observation_batch)).float().to(device)
         )
-        action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(-1)
+        action_batch = torch.tensor(action_batch).long().to(device)
         target_value = torch.tensor(target_value).float().to(device)
         target_reward = torch.tensor(target_reward).float().to(device)
         target_policy = torch.tensor(target_policy).float().to(device)
+        sampled_actions = torch.tensor(sampled_actions).long().to(device)
         gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
         # observation_batch: batch, channels, height, width
         # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
@@ -169,24 +172,26 @@ def update_weights(self, batch):
         value, reward, policy_logits, hidden_state = self.model.initial_inference(
             observation_batch
         )
-        predictions = [(value, reward, policy_logits)]
+        input_policy = self.model.log_prob(policy_logits, sampled_actions[:, 0])
+        predictions = [(value, reward, input_policy)]
         for i in range(1, action_batch.shape[1]):
             value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                 hidden_state, action_batch[:, i]
             )
+            input_policy = self.model.log_prob(policy_logits, sampled_actions[:, i])
             # Scale the gradient at the start of the dynamics function (See paper appendix Training)
             hidden_state.register_hook(lambda grad: grad * 0.5)
-            predictions.append((value, reward, policy_logits))
+            predictions.append((value, reward, input_policy))
         # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)
 
         ## Compute losses
         value_loss, reward_loss, policy_loss = (0, 0, 0)
-        value, reward, policy_logits = predictions[0]
+        value, reward, input_policy = predictions[0]
         # Ignore reward loss for the first batch step
         current_value_loss, _, current_policy_loss = self.loss_function(
             value.squeeze(-1),
             reward.squeeze(-1),
-            policy_logits,
+            input_policy,
             target_value[:, 0],
             target_reward[:, 0],
             target_policy[:, 0],
@@ -207,7 +212,7 @@ def update_weights(self, batch):
         )
 
         for i in range(1, len(predictions)):
-            value, reward, policy_logits = predictions[i]
+            value, reward, input_policy = predictions[i]
             (
                 current_value_loss,
                 current_reward_loss,
@@ -215,7 +220,7 @@ def update_weights(self, batch):
             ) = self.loss_function(
                 value.squeeze(-1),
                 reward.squeeze(-1),
-                policy_logits,
+                input_policy,
                 target_value[:, i],
                 target_reward[:, i],
                 target_policy[:, i],
@@ -276,9 +281,7 @@ def update_lr(self):
         """
         Update learning rate
         """
-        lr = self.config.lr_init * self.config.lr_decay_rate ** (
-            self.training_step / self.config.lr_decay_steps
-        )
+        lr = self.config.lr_init * 0.5 * (1 + math.cos(math.pi * (self.training_step / self.config.lr_decay_steps)))
         for param_group in self.optimizer.param_groups:
             param_group["lr"] = lr
 
@@ -286,7 +289,7 @@ def loss_function(
         value,
         reward,
-        policy_logits,
+        input_policy,
         target_value,
         target_reward,
         target_policy,
@@ -294,7 +297,5 @@
         # Cross-entropy seems to have a better convergence than MSE
         value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
         reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
-        policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(
-            1
-        )
+        policy_loss = (-target_policy * input_policy).sum(1)
        return value_loss, reward_loss, policy_loss
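Reviewer note (not part of the patch): the pieces above fit together as follows. `prediction` reshapes the policy head to `[batch] + action_shape`, `sample_actions` draws `sample_size` actions from the configured distribution, `Node.expand` turns the unique samples into children whose priors are the empirical sample frequencies, and `loss_function` replaces the log-softmax over the full action space with the `log_prob` of the stored samples. The sketch below re-derives the tensor shapes with assumed toy values (`batch_size = 4`, `sample_size = 3`, `action_shape = [1, 2]`); it is illustrative only and does not reuse code from the repository.

```python
import torch

# Assumed toy sizes mirroring the cartpole config in this patch.
batch_size, sample_size, action_shape = 4, 3, [1, 2]

# prediction(): policy head output reshaped to [batch] + action_shape.
policy_logits = torch.randn([batch_size] + action_shape)               # (4, 1, 2)

# sample_actions(): draw sample_size actions per state from a Categorical over the last dim.
dist = torch.distributions.Categorical(logits=policy_logits)
sampled = dist.sample((sample_size,)).squeeze(-1)                      # (3, 4)

# Node.expand(): for one state, unique samples become children, frequencies become priors.
uniques, counts = torch.unique(sampled[:, 0], dim=0, return_counts=True, sorted=True)
priors = counts / counts.sum()                                         # sums to 1

# log_prob(): expand the logits over the sample dimension and score the stored samples.
sampled_batch = sampled.t().unsqueeze(-1)                              # (4, 3, 1)
expanded = policy_logits.unsqueeze(-2).expand(batch_size, sample_size, *action_shape)
input_policy = torch.distributions.Categorical(logits=expanded).log_prob(sampled_batch).squeeze(-1)  # (4, 3)

# loss_function(): cross-entropy between padded visit-count targets and sampled log-probs.
target_policy = torch.softmax(torch.randn(batch_size, sample_size), dim=1)
policy_loss = (-target_policy * input_policy).sum(1)                   # (4,)
print(priors, policy_loss.shape)
```

The empirical-frequency prior in `Node.expand` is what keeps the number of children independent of the true action-space size: duplicates among the samples simply raise the corresponding child's prior instead of adding nodes.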