add agent training with curriculum
kywch committed May 15, 2024
1 parent e69f29b commit 5f1a2f9
Showing 8 changed files with 552 additions and 74 deletions.
471 changes: 471 additions & 0 deletions curriculum/neurips_curriculum.py

Large diffs are not rendered by default.

Binary file added curriculum/neurips_curriculum_with_embedding.pkl
Binary file not shown.
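
Per the environment.py changes further down, this pickle holds a list of TaskSpec objects, each with an embedding attached by task_encoder.py. A minimal loading sketch, not part of the commit, assuming dill is installed:

import dill

# The new AgentTraining game below re-reads this file on every sampling call
with open("curriculum/neurips_curriculum_with_embedding.pkl", "rb") as f:
    curriculum = dill.load(f)  # a list of TaskSpec
print(f"{len(curriculum)} task specs loaded")
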
35 changes: 32 additions & 3 deletions curriculum/task_encoder.py
@@ -21,6 +21,29 @@ def extract_module_fn(module: ModuleType):
return fn_dict


def pca(X, num_components):
# 1. Standardize the data
# X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

# 2. Compute the covariance matrix
cov_matrix = np.cov(X, rowvar=False)

# 3. Compute the eigenvalues and eigenvectors of the covariance matrix
eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

# 4. Sort the eigenvalues and corresponding eigenvectors
sorted_indices = np.argsort(eigenvalues)[::-1]
sorted_eigenvectors = eigenvectors[:, sorted_indices]

# 5. Select the first k eigenvectors
selected_components = sorted_eigenvectors[:, :num_components]

# 6. Transform the original n-dimensional data points into k dimensions
transformed_data = np.dot(X, selected_components)

return transformed_data


class TaskEncoder:
"""A class for encoding tasks into embeddings using a pretrained model."""

@@ -30,6 +53,7 @@ def __init__(
context: ModuleType,
batch_size=2,
tmp_file_path="tmp_task_encoder.pkl",
reduce_dim=None,
):
"""
Initialize the TaskEncoder.
@@ -58,6 +82,7 @@ def __init__(

blank_embedding = self._get_embedding(["# just to get the embedding size"])
self.embed_dim = len(blank_embedding[0])
self.reduce_dim = reduce_dim

def update_context(self, context: ModuleType):
"""Update the module context, extracting function dictionary."""
@@ -154,6 +179,10 @@ def get_task_embedding(self, task_spec_list: List[ts.TaskSpec], save_to_file: st
]
embeddings = self._get_embedding(prompts)

if isinstance(self.reduce_dim, int) and self.reduce_dim < self.embed_dim:
embeddings = pca(np.stack(embeddings), self.reduce_dim)
embeddings = (embeddings / np.std(embeddings, axis=0)).astype(np.float16)

for single_spec, embedding in zip(task_spec_list, embeddings):
single_spec.embedding = embedding

@@ -179,10 +208,10 @@ def __exit__(self, exc_type, exc_value, traceback):


if __name__ == "__main__":
-import curriculum.manual_curriculum as curriculum
+import curriculum.neurips_curriculum as curriculum

LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct"
CURRICULUM_FILE_PATH = "curriculum_generation/curriculum_with_embedding.pkl"
CURRICULUM_FILE_PATH = "curriculum/neurips_curriculum_with_embedding.pkl"

-with TaskEncoder(LLM_CHECKPOINT, curriculum, batch_size=6) as task_encoder:
+with TaskEncoder(LLM_CHECKPOINT, curriculum, batch_size=6, reduce_dim=16) as task_encoder:
task_encoder.get_task_embedding(curriculum.curriculum, save_to_file=CURRICULUM_FILE_PATH)
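
End to end, the new reduce_dim option stacks the raw embeddings, projects them onto the top principal components, rescales each dimension to unit variance, and casts to float16, which halves the stored size relative to float32. (Note that the centering step inside pca() is commented out, so the projection runs on uncentered data.) A minimal sketch of that path, not part of the commit; the embedding count and source dimension are made-up placeholders:

import numpy as np

# 100 task embeddings of dimension 384, standing in for the LLM output
embeddings = np.random.randn(100, 384)

# pca() as defined above: project onto the top 16 principal axes
reduced = pca(embeddings, num_components=16)

# As in get_task_embedding(): unit variance per dimension, compact dtype
reduced = (reduced / np.std(reduced, axis=0)).astype(np.float16)
assert reduced.shape == (100, 16)
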
61 changes: 0 additions & 61 deletions curriculum/team_curriculum.py

This file was deleted.

Binary file removed curriculum/team_curriculum_with_embedding.pkl
Binary file not shown.
43 changes: 41 additions & 2 deletions reinforcement_learning/environment.py
@@ -1,5 +1,8 @@
from argparse import Namespace

import dill
import numpy as np

import pufferlib
import pufferlib.emulation
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
@@ -57,6 +60,41 @@ def _set_config(self):
self.config.set_for_episode("NPC_N", npc_num)


class AgentTraining(ng.AgentTraining):
def _get_candidate_tasks(self, eval_mode=False):
with open(self.config.CURRICULUM_FILE_PATH, "rb") as f:
# curriculum file may have been changed, so read the file when sampling
curriculum = dill.load(f) # a list of TaskSpec

cand_specs = [spec for spec in curriculum if spec.reward_to == "agent"]
if eval_mode:
cand_specs = [spec for spec in cand_specs if "eval" in spec.tags]
else:
cand_specs = [spec for spec in cand_specs if "eval" not in spec.tags]

assert len(cand_specs) > 0, "There is no agent task to be sampled"
return cand_specs

def _make_agent_tasks(self, cand_specs):
sampling_weights = [spec.sampling_weight for spec in cand_specs]
sampled_spec = self._np_random.choice(
cand_specs, size=self.config.PLAYER_N, p=sampling_weights / np.sum(sampling_weights)
)
return make_task_from_spec(self.config.POSSIBLE_AGENTS, sampled_spec)

def _define_tasks(self):
# NOTE: Some tasks may not be achievable at all when the necessary game system is off
cand_specs = self._get_candidate_tasks(eval_mode=False)
return self._make_agent_tasks(cand_specs)


class AgentTaskEval(AgentTraining):
def _define_tasks(self):
# NOTE: Some tasks may not be achievable at all when the necessary game system is off
cand_specs = self._get_candidate_tasks(eval_mode=True)
return self._make_agent_tasks(cand_specs)


class AmmoTraining(ng.AgentTraining):
def is_compatible(self):
return self.config.are_systems_enabled(["COMBAT", "EQUIPMENT", "PROFESSION"])
@@ -171,8 +209,7 @@ def __init__(self, env_args: Namespace):
self.set("RESOURCE_RESILIENT_POPULATION", env_args.resilient_population)
self.set("COMBAT_SPAWN_IMMUNITY", env_args.spawn_immunity)

-# NOTE: Disabling curriculum file for now
-# self.set("CURRICULUM_FILE_PATH", env_args.curriculum_file_path)
+self.set("CURRICULUM_FILE_PATH", env_args.curriculum_file_path)


class FullGameConfig(
@@ -210,6 +247,8 @@ def make_env_creator(
game_packs = [(TeamBattle, 1), (EasyKingoftheHill, 1), (Sandwich, 1)]
elif train_flag == "tb_ammo":
game_packs = [(TeamBattle, 5), (AmmoTraining, 1)]
elif train_flag == "tb_curr":
game_packs = [(TeamBattle, 1), (AgentTraining, 1)]
else:
raise ValueError(f"Invalid train_flag: {train_flag}")

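At its core, the new AgentTraining game assigns tasks by a weighted draw with replacement: one spec per agent, with probability proportional to each spec's sampling_weight. A standalone sketch of that logic; the dict stand-ins for TaskSpec and the weights are illustrative only:

import numpy as np

rng = np.random.default_rng(seed=0)
# Stand-ins for TaskSpec, keeping only the field the sampler reads
cand_specs = [
    {"name": "forage", "sampling_weight": 1.0},
    {"name": "combat", "sampling_weight": 3.0},
    {"name": "explore", "sampling_weight": 6.0},
]
weights = np.array([spec["sampling_weight"] for spec in cand_specs])
# One task per agent (8 agents here, PLAYER_N in the commit)
sampled = rng.choice(cand_specs, size=8, p=weights / weights.sum())
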
6 changes: 3 additions & 3 deletions reinforcement_learning/stat_wrapper.py
@@ -7,7 +7,7 @@
import nmmo.systems.item as Item
from nmmo.minigames import RacetoCenter, KingoftheHill, Sandwich, RadioRaid

-from reinforcement_learning.environment import TeamBattle
+from reinforcement_learning.environment import TeamBattle, AgentTraining


class BaseStatWrapper(BaseParallelWrapper):
@@ -231,8 +231,8 @@ def _process_stats_and_early_stop(self, agent_id, reward, terminated, truncated,
# 'return' is used for ranking in the eval mode, so put the task progress here
info["return"] = task._max_progress # this is 1 if done

-# Log the below stats ONLY for the team battle
-if isinstance(self.env.game, TeamBattle):
+# Log the below stats ONLY for the team battle & agent training
+if isinstance(self.env.game, TeamBattle) or isinstance(self.env.game, AgentTraining):
# Max combat/harvest level achieved
info["stats"]["achieved/max_combat_level"] = agent.attack_level
info["stats"]["achieved/max_harvest_skill_ammo"] = max(
10 changes: 5 additions & 5 deletions train.py
@@ -13,7 +13,7 @@
from reinforcement_learning import environment
from train_helper import init_wandb, train, sweep, generate_replay

BASELINE_CURRICULUM = "curriculum/team_curriculum_with_embedding.pkl"
BASELINE_CURRICULUM = "curriculum/neurips_curriculum_with_embedding.pkl"


def load_from_config(agent, debug=False):
@@ -120,7 +120,7 @@ def update_args(args, mode=None):
args = pufferlib.namespace(**args)

args.track = not args.no_track
-# args.env.curriculum_file_path = args.curriculum
+args.env.curriculum_file_path = args.curriculum

vec = args.vectorization
if vec == "serial" or args.debug:
@@ -187,9 +187,9 @@
choices="battle race koh sandwich radio".split(),
help="Game to evaluate/replay",
)
-# parser.add_argument(
-# "-c", "--curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file"
-# )
+parser.add_argument(
+    "-c", "--curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file"
+)
# parser.add_argument(
# "-t",
# "--task-to-assign",
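With the flag re-enabled, -c/--curriculum falls back to the new baseline path when omitted. A self-contained sketch of just that default behavior; it mirrors the two lines above, with the rest of train.py's parser omitted:

import argparse

BASELINE_CURRICULUM = "curriculum/neurips_curriculum_with_embedding.pkl"

parser = argparse.ArgumentParser()
parser.add_argument(
    "-c", "--curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file"
)
args = parser.parse_args([])  # no CLI args -> baseline curriculum is used
assert args.curriculum == BASELINE_CURRICULUM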
