diff --git a/nmmo/task/predicate_api.py b/nmmo/task/predicate_api.py index e71f2cc1..6be1a3d3 100644 --- a/nmmo/task/predicate_api.py +++ b/nmmo/task/predicate_api.py @@ -124,6 +124,26 @@ def _make_name(self, class_name, args, kwargs) -> str: def __str__(self): return self.name + @abstractmethod + def get_source_code(self) -> str: + """ Returns the actual source code how the game state/progress evaluation is done. + """ + raise NotImplementedError + + @abstractmethod + def get_signature(self) -> List: + """ Returns the signature of the game state/progress evaluation function. + """ + raise NotImplementedError + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return self._kwargs + @property def subject(self): return self._subject @@ -205,12 +225,11 @@ def __init__(self, *args, **kwargs) -> None: self._kwargs = kwargs self.name = self._make_name(fn.__name__, args, kwargs) def _evaluate(self, gs: GameState) -> float: - # pylint: disable=redefined-builtin, unused-variable - __doc = fn.__doc__ - result = fn(gs, *self._args, **self._kwargs) - if isinstance(result, Predicate): - return result(gs) - return result + return fn(gs, *self._args, **self._kwargs) + def get_source_code(self): + return inspect.getsource(fn).strip() + def get_signature(self) -> List: + return list(self._signature.parameters) return FunctionPredicate @@ -245,6 +264,38 @@ def sample(self, config: Config, cls: type[PredicateOperator], **kwargs): else p(None) for p in self._predicates] return cls(*predicates, subject=subject) + def get_source_code(self) -> str: + # NOTE: get_source_code() of the combined predicates returns the joined str + # of each predicate's source code, which may NOT represent what the actual + # predicate is doing + # TODO: try to generate "the source code" that matches + # what the actual instantiated predicate returns, + # which perhaps should reflect the actual agent ids, etc... + src_list = [] + for pred in self._predicates: + if isinstance(pred, Predicate): + src_list.append(pred.get_source_code()) + return '\n\n'.join(src_list).strip() + + def get_signature(self): + # TODO: try to generate the correct signature + return [] + + @property + def args(self): + # TODO: try to generate the correct args + return [] + + @property + def kwargs(self): + # NOTE: This is incorrect implementation. kwargs of the combined predicates returns + # all summed kwargs dict, which can OVERWRITE the values of duplicated keys + # TODO: try to match the eval function and kwargs, which can be correctly used downstream + # for pred in self._predicates: + # if isinstance(pred, Predicate): + # kwargs.update(pred.kwargs) + return {} + class OR(PredicateOperator, Predicate): def __init__(self, *predicates: Predicate, subject: Group=None): super().__init__(lambda n: n>0, *predicates, subject=subject) diff --git a/nmmo/task/task_api.py b/nmmo/task/task_api.py index 8bc5d587..557c5ff3 100644 --- a/nmmo/task/task_api.py +++ b/nmmo/task/task_api.py @@ -2,6 +2,7 @@ from typing import Callable, Iterable, Dict, List, Union, Tuple from types import FunctionType from abc import ABC +import inspect from nmmo.task.group import Group from nmmo.task.predicate_api import Predicate, make_predicate, arg_to_string @@ -25,7 +26,6 @@ def __init__(self, self._progress = 0.0 self._completed = False self._reward_multiplier = reward_multiplier - self.name = self._make_name(self.__class__.__name__, eval_fn=eval_fn, assignee=self._assignee) @@ -87,6 +87,37 @@ def _make_name(self, class_name, **kwargs) -> str: def __str__(self): return self.name + @property + def subject(self): + if isinstance(self._eval_fn, Predicate): + return self._eval_fn.subject.agents + return self.assignee + + def get_source_code(self): + if isinstance(self._eval_fn, Predicate): + return self._eval_fn.get_source_code() + return inspect.getsource(self._eval_fn).strip() + + def get_signature(self): + if isinstance(self._eval_fn, Predicate): + return self._eval_fn.get_signature() + signature = inspect.signature(self._eval_fn) + return list(signature.parameters) + + @property + def args(self): + if isinstance(self._eval_fn, Predicate): + return self._eval_fn.args + # the function _eval_fn must only take gs + return [] + + @property + def kwargs(self): + if isinstance(self._eval_fn, Predicate): + return self._eval_fn.kwargs + # the function _eval_fn must only take gs + return {} + class OngoingTask(Task): def _map_progress_to_reward(self, gs) -> float: """Keep returning the progress reward after the task is completed. diff --git a/tests/task/test_task_api.py b/tests/task/test_task_api.py index 76356b2f..916f5cb1 100644 --- a/tests/task/test_task_api.py +++ b/tests/task/test_task_api.py @@ -57,21 +57,45 @@ def test_predicate_operators(self): # NOTE: only the instantiated predicate can be used with operators like below mock_gs = MockGameState() + # get the individual predicate's source code + self.assertEqual(SUCCESS.get_source_code(), + 'def Success(gs, subject: Group):\n return True') + self.assertEqual(FAILURE.get_source_code(), + 'def Failure(gs, subject: Group):\n return False') + # AND (&), OR (|), NOT (~) pred1 = SUCCESS & FAILURE self.assertFalse(pred1(mock_gs)) + # NOTE: get_source_code() of the combined predicates returns the joined str + # of each predicate's source code, which may NOT represent what the actual + # predicate is doing + self.assertEqual(pred1.get_source_code(), + 'def Success(gs, subject: Group):\n return True\n\n'+ + 'def Failure(gs, subject: Group):\n return False') pred2 = SUCCESS | FAILURE | SUCCESS self.assertTrue(pred2(mock_gs)) + self.assertEqual(pred2.get_source_code(), + 'def Success(gs, subject: Group):\n return True\n\n'+ + 'def Failure(gs, subject: Group):\n return False\n\n'+ + 'def Success(gs, subject: Group):\n return True') pred3 = SUCCESS & ~ FAILURE & SUCCESS self.assertTrue(pred3(mock_gs)) + # NOTE: demonstrating the above point -- it just returns the functions + # NOT what this predicate actually evaluates. + self.assertEqual(pred2.get_source_code(), + pred3.get_source_code()) # predicate math pred4 = 0.1 * SUCCESS + 0.3 self.assertEqual(pred4(mock_gs), 0.4) self.assertEqual(pred4.name, "(ADD_(MUL_(Success_(0,))_0.1)_0.3)") + # NOTE: demonstrating the above point again, -- it just returns the functions + # NOT what this predicate actually evaluates. + self.assertEqual(pred4.get_source_code(), + 'def Success(gs, subject: Group):\n return True') pred5 = 0.3 * SUCCESS - 1 self.assertEqual(pred5(mock_gs), 0.0) # cannot go below 0 @@ -157,13 +181,27 @@ def test_task_api_with_predicate(self): fake_pred_cls = make_predicate(Fake) mock_gs = MockGameState() - predicate = fake_pred_cls(Group(2), 1, Item.Hat, Action.Melee) + group = Group(2) + item = Item.Hat + action = Action.Melee + predicate = fake_pred_cls(group, a=1, b=item, c=action) + self.assertEqual(predicate.get_source_code(), + 'def Fake(gs, subject, a,b,c):\n return False') + self.assertEqual(predicate.get_signature(), ['gs', 'subject', 'a', 'b', 'c']) + self.assertEqual(predicate.args, [group]) + self.assertDictEqual(predicate.kwargs, {'a': 1, 'b': item, 'c': action}) + assignee = [1,2,3] # list of agent ids task = predicate.create_task(assignee=assignee) rewards, infos = task.compute_rewards(mock_gs) self.assertEqual(task.name, # contains predicate name and assignee list - "(Task_eval_fn:(Fake_(2,)_1_Hat_Melee)_assignee:(1,2,3))") + "(Task_eval_fn:(Fake_(2,)_a:1_b:Hat_c:Melee)_assignee:(1,2,3))") + self.assertEqual(task.get_source_code(), + 'def Fake(gs, subject, a,b,c):\n return False') + self.assertEqual(task.get_signature(), ['gs', 'subject', 'a', 'b', 'c']) + self.assertEqual(task.args, [group]) + self.assertDictEqual(task.kwargs, {'a': 1, 'b': item, 'c': action}) for agent_id in assignee: self.assertEqual(rewards[agent_id], 0) self.assertEqual(infos[agent_id]['progress'], 0) # progress (False -> 0) @@ -182,6 +220,14 @@ def is_agent_1(gs): self.assertEqual(task.name, # contains predicate name and assignee list "(Task_eval_fn:is_agent_1_assignee:(1,2,3))") + self.assertEqual(task.get_source_code(), + 'def is_agent_1(gs):\n ' + + 'return any(agent_id == 1 for agent_id in subject.agents)') + self.assertEqual(task.get_signature(), ['gs']) + self.assertEqual(task.args, []) + self.assertDictEqual(task.kwargs, {}) + self.assertEqual(task.subject, tuple(assignee)) + self.assertEqual(task.assignee, tuple(assignee)) for agent_id in assignee: self.assertEqual(rewards[agent_id], 1) self.assertEqual(infos[agent_id]['progress'], 1) # progress (True -> 1) @@ -206,6 +252,18 @@ def PracticeFormation(gs, subject, dist, num_tick): env = Env(config) env.reset(make_task_fn=lambda: make_team_tasks(teams, [task_spec])) + task = env.tasks[0] + self.assertEqual(task.name, + '(Task_eval_fn:(PracticeFormation_(1,2,3)_dist:1_num_tick:10)'+ + '_assignee:(1,2,3))') + self.assertEqual(task.get_source_code(), + 'def PracticeFormation(gs, subject, dist, num_tick):\n '+ + 'return AllMembersWithinRange(gs, subject, dist) * '+ + 'TickGE(gs, subject, num_tick)') + self.assertEqual(task.get_signature(), ['gs', 'subject', 'dist', 'num_tick']) + self.assertEqual(task.subject, tuple(teams[0])) + self.assertEqual(task.kwargs, task_spec[2]) + self.assertEqual(task.assignee, tuple(teams[0])) # move agent 2, 3 to agent 1's pos for agent_id in [2,3]: change_spawn_pos(env.realm, agent_id,