Skip to content

Commit

Permalink
Merge pull request #66 from CarperAI/optim
Browse files Browse the repository at this point in the history
minor changes to predicate api and caching
  • Loading branch information
kywch authored Jun 4, 2023
2 parents 1d2d46e + 25bab85 commit 20db562
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions nmmo/task/predicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ def __init__(self,
**kwargs):
self.name = self._make_name(self.__class__.__name__, args, kwargs)

def is_group(x):
return isinstance(x, Group)
self._groups: List[Group] = list(filter(is_group, args))
self._groups = self._groups + list(filter(is_group, kwargs.values()))
self._groups: List[Group] = [x for x in list(args) + list(kwargs.values())
if isinstance(x, Group)]

self._groups.append(subject)

self._args = args
Expand All @@ -54,12 +53,12 @@ def __call__(self, gs: GameState) -> float:
for group in self._groups:
group.update(gs)
# Calculate score
cache = gs.cache_result
if self.name in cache:
progress = cache[self.name]
# cache = gs.cache_result
if self.name in gs.cache_result:
progress = gs.cache_result[self.name]
else:
progress = max(min(self._evaluate(gs)*1.0,1.0),0.0)
cache[self.name] = progress
gs.cache_result[self.name] = progress
return progress

def _reset(self, config: Config):
Expand Down

0 comments on commit 20db562

Please sign in to comment.