Skip to content

Commit

Permalink
Merge pull request #72 from CarperAI/get-src
Browse files Browse the repository at this point in the history
get the source code, signature, etc from Task, Predicate's evaluation function
  • Loading branch information
jsuarez5341 authored Jun 9, 2023
2 parents a670d02 + d69fc92 commit e46152d
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 9 deletions.
63 changes: 57 additions & 6 deletions nmmo/task/predicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,26 @@ def _make_name(self, class_name, args, kwargs) -> str:
def __str__(self):
  """Return this object's ``name`` attribute as its string form."""
  return self.name

@abstractmethod
def get_source_code(self) -> str:
  """Return the source code that shows how the game state/progress is evaluated.

  This is the abstract declaration; concrete subclasses must override it.

  Raises:
    NotImplementedError: always, in this base implementation.
  """
  raise NotImplementedError

@abstractmethod
def get_signature(self) -> List:
  """Return the signature of the game state/progress evaluation function.

  This is the abstract declaration; concrete subclasses must override it.

  Raises:
    NotImplementedError: always, in this base implementation.
  """
  raise NotImplementedError

@property
def args(self):
  """Return the positional arguments stored on this object as ``_args``."""
  return self._args

@property
def kwargs(self):
  """Return the keyword arguments stored on this object as ``_kwargs``."""
  return self._kwargs

@property
def subject(self):
  """Return the subject stored on this object as ``_subject``."""
  return self._subject
Expand Down Expand Up @@ -205,12 +225,11 @@ def __init__(self, *args, **kwargs) -> None:
self._kwargs = kwargs
self.name = self._make_name(fn.__name__, args, kwargs)
def _evaluate(self, gs: GameState) -> float:
# NOTE(review): this span is a rendered diff hunk without +/- markers, so two
# versions of the body appear back to back. As written, the last `return` is
# unreachable because `return result` precedes it.
# Presumably the six lines directly below are the removed (old) body and the
# final `return fn(...)` line is its replacement -- this matches the
# "6 deletions" count and the 12 -> 11 line hunk header shown for this file,
# but confirm against the repository before editing.
# pylint: disable=redefined-builtin, unused-variable
__doc = fn.__doc__
result = fn(gs, *self._args, **self._kwargs)
# If the wrapped function returned a Predicate, evaluate it against `gs`.
if isinstance(result, Predicate):
return result(gs)
return result
# Calls the wrapped function with the stored args/kwargs and returns its result directly.
return fn(gs, *self._args, **self._kwargs)
def get_source_code(self):
  """Return the source text of the wrapped function, stripped of surrounding whitespace."""
  # `fn` is a free variable here: it is not a parameter of this method and
  # must be provided by the scope in which this method is defined.
  return inspect.getsource(fn).strip()
def get_signature(self) -> List:
  """Return the parameter names recorded in ``self._signature``, in declaration order."""
  # Iterating the `parameters` mapping yields the parameter names.
  return list(self._signature.parameters)

return FunctionPredicate

Expand Down Expand Up @@ -245,6 +264,38 @@ def sample(self, config: Config, cls: type[PredicateOperator], **kwargs):
else p(None) for p in self._predicates]
return cls(*predicates, subject=subject)

def get_source_code(self) -> str:
  """Return the source code of every combined predicate, joined by blank lines.

  Only members of ``self._predicates`` that are ``Predicate`` instances
  contribute; any other member is skipped.
  """
  # NOTE: get_source_code() of the combined predicates returns the joined str
  #   of each predicate's source code, which may NOT represent what the actual
  #   predicate is doing
  # TODO: try to generate "the source code" that matches
  #   what the actual instantiated predicate returns,
  #   which perhaps should reflect the actual agent ids, etc...
  sources = [pred.get_source_code()
             for pred in self._predicates
             if isinstance(pred, Predicate)]
  return '\n\n'.join(sources).strip()

def get_signature(self):
  """Return the signature of the combined predicate (currently always an empty list)."""
  # TODO: try to generate the correct signature
  return []

@property
def args(self):
  """Return the positional args of the combined predicate (currently always an empty list)."""
  # TODO: try to generate the correct args
  return []

@property
def kwargs(self):
  """Return the kwargs of the combined predicate (currently always an empty dict)."""
  # NOTE: An empty dict is returned on purpose. Merging every child
  #   predicate's kwargs into a single dict (e.g. via dict.update) would be
  #   incorrect, because duplicated keys would OVERWRITE each other's values.
  # TODO: try to match each eval function with its own kwargs, so they can be
  #   used correctly downstream
  return {}

class OR(PredicateOperator, Predicate):
def __init__(self, *predicates: Predicate, subject: Group=None):
super().__init__(lambda n: n>0, *predicates, subject=subject)
Expand Down
33 changes: 32 additions & 1 deletion nmmo/task/task_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Iterable, Dict, List, Union, Tuple
from types import FunctionType
from abc import ABC
import inspect

from nmmo.task.group import Group
from nmmo.task.predicate_api import Predicate, make_predicate, arg_to_string
Expand All @@ -25,7 +26,6 @@ def __init__(self,
self._progress = 0.0
self._completed = False
self._reward_multiplier = reward_multiplier

self.name = self._make_name(self.__class__.__name__,
eval_fn=eval_fn, assignee=self._assignee)

Expand Down Expand Up @@ -87,6 +87,37 @@ def _make_name(self, class_name, **kwargs) -> str:
def __str__(self):
  """Return this object's ``name`` attribute as its string form."""
  return self.name

@property
def subject(self):
  """Return the agents this task's evaluation is about.

  When the evaluation function is a ``Predicate``, this is the ``agents``
  attribute of that predicate's subject; otherwise it is ``self.assignee``.
  """
  if isinstance(self._eval_fn, Predicate):
    return self._eval_fn.subject.agents
  return self.assignee

def get_source_code(self):
  """Return the source code of this task's evaluation function.

  A ``Predicate`` evaluation function reports its own source code; for any
  other callable the source is read with ``inspect.getsource`` and stripped
  of surrounding whitespace.
  """
  if isinstance(self._eval_fn, Predicate):
    return self._eval_fn.get_source_code()
  return inspect.getsource(self._eval_fn).strip()

def get_signature(self):
  """Return the parameter names of this task's evaluation function.

  A ``Predicate`` evaluation function reports its own signature; for any
  other callable the names are taken from ``inspect.signature``.
  """
  if isinstance(self._eval_fn, Predicate):
    return self._eval_fn.get_signature()
  # Iterating the `parameters` mapping yields the parameter names in order.
  return list(inspect.signature(self._eval_fn).parameters)

@property
def args(self):
  """Return the positional args of the evaluation function.

  Delegates to the predicate's ``args`` when the evaluation function is a
  ``Predicate``; otherwise returns an empty list.
  """
  if isinstance(self._eval_fn, Predicate):
    return self._eval_fn.args
  # the function _eval_fn must only take gs
  return []

@property
def kwargs(self):
  """Return the keyword args of the evaluation function.

  Delegates to the predicate's ``kwargs`` when the evaluation function is a
  ``Predicate``; otherwise returns an empty dict.
  """
  if isinstance(self._eval_fn, Predicate):
    return self._eval_fn.kwargs
  # the function _eval_fn must only take gs
  return {}

class OngoingTask(Task):
def _map_progress_to_reward(self, gs) -> float:
"""Keep returning the progress reward after the task is completed.
Expand Down
62 changes: 60 additions & 2 deletions tests/task/test_task_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,45 @@ def test_predicate_operators(self):
# NOTE: only the instantiated predicate can be used with operators like below
mock_gs = MockGameState()

# get the individual predicate's source code
self.assertEqual(SUCCESS.get_source_code(),
'def Success(gs, subject: Group):\n return True')
self.assertEqual(FAILURE.get_source_code(),
'def Failure(gs, subject: Group):\n return False')

# AND (&), OR (|), NOT (~)
pred1 = SUCCESS & FAILURE
self.assertFalse(pred1(mock_gs))
# NOTE: get_source_code() of the combined predicates returns the joined str
# of each predicate's source code, which may NOT represent what the actual
# predicate is doing
self.assertEqual(pred1.get_source_code(),
'def Success(gs, subject: Group):\n return True\n\n'+
'def Failure(gs, subject: Group):\n return False')

pred2 = SUCCESS | FAILURE | SUCCESS
self.assertTrue(pred2(mock_gs))
self.assertEqual(pred2.get_source_code(),
'def Success(gs, subject: Group):\n return True\n\n'+
'def Failure(gs, subject: Group):\n return False\n\n'+
'def Success(gs, subject: Group):\n return True')

pred3 = SUCCESS & ~ FAILURE & SUCCESS
self.assertTrue(pred3(mock_gs))
# NOTE: demonstrating the above point -- it just returns the functions
# NOT what this predicate actually evaluates.
self.assertEqual(pred2.get_source_code(),
pred3.get_source_code())

# predicate math
pred4 = 0.1 * SUCCESS + 0.3
self.assertEqual(pred4(mock_gs), 0.4)
self.assertEqual(pred4.name,
"(ADD_(MUL_(Success_(0,))_0.1)_0.3)")
# NOTE: demonstrating the above point again -- get_source_code() just returns
# the source of the underlying functions, NOT what this predicate actually evaluates.
self.assertEqual(pred4.get_source_code(),
'def Success(gs, subject: Group):\n return True')

pred5 = 0.3 * SUCCESS - 1
self.assertEqual(pred5(mock_gs), 0.0) # cannot go below 0
Expand Down Expand Up @@ -157,13 +181,27 @@ def test_task_api_with_predicate(self):
fake_pred_cls = make_predicate(Fake)

mock_gs = MockGameState()
predicate = fake_pred_cls(Group(2), 1, Item.Hat, Action.Melee)
group = Group(2)
item = Item.Hat
action = Action.Melee
predicate = fake_pred_cls(group, a=1, b=item, c=action)
self.assertEqual(predicate.get_source_code(),
'def Fake(gs, subject, a,b,c):\n return False')
self.assertEqual(predicate.get_signature(), ['gs', 'subject', 'a', 'b', 'c'])
self.assertEqual(predicate.args, [group])
self.assertDictEqual(predicate.kwargs, {'a': 1, 'b': item, 'c': action})

assignee = [1,2,3] # list of agent ids
task = predicate.create_task(assignee=assignee)
rewards, infos = task.compute_rewards(mock_gs)

self.assertEqual(task.name, # contains predicate name and assignee list
"(Task_eval_fn:(Fake_(2,)_1_Hat_Melee)_assignee:(1,2,3))")
"(Task_eval_fn:(Fake_(2,)_a:1_b:Hat_c:Melee)_assignee:(1,2,3))")
self.assertEqual(task.get_source_code(),
'def Fake(gs, subject, a,b,c):\n return False')
self.assertEqual(task.get_signature(), ['gs', 'subject', 'a', 'b', 'c'])
self.assertEqual(task.args, [group])
self.assertDictEqual(task.kwargs, {'a': 1, 'b': item, 'c': action})
for agent_id in assignee:
self.assertEqual(rewards[agent_id], 0)
self.assertEqual(infos[agent_id]['progress'], 0) # progress (False -> 0)
Expand All @@ -182,6 +220,14 @@ def is_agent_1(gs):

self.assertEqual(task.name, # contains predicate name and assignee list
"(Task_eval_fn:is_agent_1_assignee:(1,2,3))")
self.assertEqual(task.get_source_code(),
'def is_agent_1(gs):\n ' +
'return any(agent_id == 1 for agent_id in subject.agents)')
self.assertEqual(task.get_signature(), ['gs'])
self.assertEqual(task.args, [])
self.assertDictEqual(task.kwargs, {})
self.assertEqual(task.subject, tuple(assignee))
self.assertEqual(task.assignee, tuple(assignee))
for agent_id in assignee:
self.assertEqual(rewards[agent_id], 1)
self.assertEqual(infos[agent_id]['progress'], 1) # progress (True -> 1)
Expand All @@ -206,6 +252,18 @@ def PracticeFormation(gs, subject, dist, num_tick):
env = Env(config)
env.reset(make_task_fn=lambda: make_team_tasks(teams, [task_spec]))

task = env.tasks[0]
self.assertEqual(task.name,
'(Task_eval_fn:(PracticeFormation_(1,2,3)_dist:1_num_tick:10)'+
'_assignee:(1,2,3))')
self.assertEqual(task.get_source_code(),
'def PracticeFormation(gs, subject, dist, num_tick):\n '+
'return AllMembersWithinRange(gs, subject, dist) * '+
'TickGE(gs, subject, num_tick)')
self.assertEqual(task.get_signature(), ['gs', 'subject', 'dist', 'num_tick'])
self.assertEqual(task.subject, tuple(teams[0]))
self.assertEqual(task.kwargs, task_spec[2])
self.assertEqual(task.assignee, tuple(teams[0]))
# move agent 2, 3 to agent 1's pos
for agent_id in [2,3]:
change_spawn_pos(env.realm, agent_id,
Expand Down

0 comments on commit e46152d

Please sign in to comment.