forked from joonspk-research/generative_agents
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[joonspk-research#83, joonspk-research#84] Moved each InferenceStrate…
…gy out into its separate module
- Loading branch information
Showing
8 changed files
with
519 additions
and
417 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
417 changes: 0 additions & 417 deletions
417
reverie/backend_server/persona/prompt_template/run_gpt_prompt.py
Large diffs are not rendered by default.
Oops, something went wrong.
45 changes: 45 additions & 0 deletions
45
reverie/backend_server/persona/prompts/run_gpt_prompt_act_obj_desc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Dict, Optional | ||
|
||
from persona.prompt_template.InferenceStrategySK import JSONType, OutputType, functor, InferenceStrategySK | ||
|
||
@functor | ||
class run_gpt_prompt_act_obj_desc(InferenceStrategySK): | ||
output_type = OutputType.JSON | ||
config = { | ||
"max_tokens": 50, | ||
"temperature": 0, | ||
"top_p": 1, | ||
} | ||
prompt = """ | ||
We want to write an object description and to understand the state of an object that is being used by someone. For example, if Jack is fixing the generator, the description would state: | ||
{"object":"generator","user":"Jack","state":"being fixed"} | ||
Now, let's consider {{$object_name}}. {{$firstname}} is currently performing the task "{{$action_description}}", interacting with the {{$object_name}}. Describe the interaction in the same form as above. | ||
""" | ||
|
||
def prepare_context(self, act_game_object: str, act_desp: str, persona) -> Dict[str, str]: | ||
return { | ||
"object_name": act_game_object, | ||
"action_description": act_desp, | ||
"firstname": persona.scratch.get_str_firstname(), | ||
} | ||
|
||
def validate_json(self, json: JSONType) -> Optional[str]: | ||
# Check for the required fields in the JSON object | ||
required_fields = ["object", "user", "state"] | ||
for field in required_fields: | ||
if field not in json: | ||
return f"Missing field: {field}" | ||
# Check if the "object" field matches the lowercased object_name property | ||
if json["object"].lower() != self.context_variables['object_name'].lower(): | ||
return "Object name mismatch" | ||
# Check if the "object" field matches the lowercased object_name property | ||
if json["user"] != self.context_variables['firstname']: | ||
return "Object name mismatch" | ||
|
||
def extract_json(self, json: JSONType) -> str: | ||
return json['state'] | ||
|
||
def fallback(self, act_game_object: str, act_desp: str, persona) -> str: | ||
return f'being used by {persona.scratch.get_str_firstname()}' |
111 changes: 111 additions & 0 deletions
111
reverie/backend_server/persona/prompts/run_gpt_prompt_act_obj_event_triple.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from typing import Optional | ||
|
||
from persona.prompt_template.InferenceStrategySK import JSONType, ReturnType, OutputType, functor, InferenceStrategySK | ||
|
||
@functor | ||
class run_gpt_prompt_act_obj_event_triple(InferenceStrategySK): | ||
output_type = OutputType.JSON | ||
config = { | ||
"max_tokens": 50, | ||
"temperature": 0.8, | ||
"top_p": 0.95, | ||
"top_k": 40, # not supported by SK | ||
"min_p": 0.05, # not supported by SK | ||
|
||
} | ||
prompt = """ | ||
Transform natural language descriptions into structured JSON, focusing on the object, predicate, and specific status. The 'status' should reflect the primary action being performed with the object, described in a passive form, and should not include additional details unrelated to the action itself. Here are examples: | ||
Name: Sam | ||
Action description: Sam Johnson is eating breakfast. | ||
Object: table | ||
Object state: clear with a plate of food and a cup of coffee | ||
Output: { | ||
"object": "table", | ||
"predicate": "is", | ||
"interaction": "being eaten on" | ||
} | ||
--- | ||
Name: Joon | ||
Action description: Joon Park is brewing coffee. | ||
Object: coffee maker | ||
Object state: simmering | ||
Output: { | ||
"object": "coffee maker", | ||
"predicate": "is", | ||
"interaction": "brewing coffee" | ||
} | ||
--- | ||
Name: Jane | ||
Action description: Jane Cook is sleeping. | ||
Object: bed | ||
Object state: supported Jane during her sleep | ||
Output: { | ||
"object": "bed", | ||
"predicate": "is", | ||
"interaction": "being slept in" | ||
} | ||
--- | ||
Name: Michael | ||
Action description: Michael Bernstein is writing email on a computer. | ||
Object: computer | ||
Object state: in use | ||
Output: { | ||
"object": "computer", | ||
"predicate": "is", | ||
"interaction": "being used to write email" | ||
} | ||
--- | ||
Name: Percy | ||
Action description: Percy Liang is teaching students in a classroom. | ||
Object: classroom | ||
Object state: filled with students learning | ||
Output: { | ||
"object": "classroom", | ||
"predicate": "is", | ||
"interaction": "being used for teaching" | ||
} | ||
--- | ||
Name: Merrie | ||
Action description: Merrie Morris is running on a treadmill. | ||
Object: treadmill | ||
Object state: in use | ||
Output: { | ||
"object": "treadmill", | ||
"predicate": "is", | ||
"interaction": "being run on" | ||
} | ||
Now, for a new case: | ||
Name: {{$firstname}} | ||
Action description: {{$action_description}} | ||
Object: {{$object_name}} | ||
Object state: {{$object_state}} | ||
Based on this description, provide a single JSON object in the format shown above. The "object" field must contain object name. Do not make the "status" a generic action, such as "being used", but find a more specific word clarifying how the {{$object_name}} is being used. In addition, exclude any extraneous details not directly related to this action. No intro nor Markdown, respond just with the JSON object. | ||
""" | ||
|
||
def prepare_context(self, persona, task, act_obj_desc, object_name): | ||
return { | ||
"object_name": object_name, | ||
"action_description": task, | ||
"object_state": act_obj_desc, | ||
"firstname": persona.scratch.get_str_firstname(), | ||
} | ||
|
||
def validate_json(self, json: JSONType) -> Optional[str]: | ||
# Check for the required fields in the JSON object | ||
required_fields = ["object", "predicate", "interaction"] | ||
for field in required_fields: | ||
if field not in json: | ||
return f"Missing field: {field}" | ||
# Check if the "object" field matches the lowercased object_name property | ||
if json["object"].lower() != self.context_variables['object_name'].lower(): | ||
return "Object name mismatch" | ||
|
||
def extract_json(self, json: JSONType) -> ReturnType: | ||
return (json["object"], json["predicate"], json["interaction"]) | ||
|
||
def fallback(self, persona, task, act_obj_desc, object_name): | ||
return (object_name, "is", "idle") |
82 changes: 82 additions & 0 deletions
82
reverie/backend_server/persona/prompts/run_gpt_prompt_action_sector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
|
||
from persona.prompt_template.InferenceStrategySK import JSONType, OutputType, functor, InferenceStrategySK | ||
|
||
@functor | ||
class run_gpt_prompt_action_sector(InferenceStrategySK): | ||
output_type = OutputType.JSON | ||
config = { | ||
"temperature": 0.3, | ||
} | ||
prompt = """ | ||
We need to choose an appropriate Sector for the task at hand. | ||
* Stay in the current sector if the activity can be done there. Only go out if the activity needs to take place in another place. | ||
* Must be one of the sectors from "All Sectors," verbatim. It must be a Sector, and not an Arena. | ||
* If none of those fit very well, we must still choose the one that's the closest fit. | ||
* Return the answer as a JSON object with a single key "area". The value is the chosen area name. | ||
Sam Kim lives in the "Sam Kim's house" Sector that has the following Arenas: ["Sam Kim's room", "bathroom", "kitchen"] | ||
Sam Kim is currently in the "Sam Kim's house" Sector that has the following Arenas: ["Sam Kim's room", "bathroom", "kitchen"] | ||
All Sectors: ["Sam Kim's house", "The Rose and Crown Pub", "Hobbs Cafe", "Oak Hill College", "Johnson Park", "Harvey Oak Supply Store", "The Willows Market and Pharmacy"]. | ||
For performing the "taking a walk" Action, Sam Kim should go to the following Sector: | ||
{"area": "Johnson Park"} | ||
--- | ||
Jane Anderson lives in the "Oak Hill College Student Dormitory" Sector that has the following Arenas: ["Jane Anderson's room"] | ||
Jane Anderson is currently in the "Oak Hill College" Sector that has the following Arenas: ["classroom", "library"] | ||
All Sectors: ["Oak Hill College Student Dormitory", "The Rose and Crown Pub", "Hobbs Cafe", "Oak Hill College", "Johnson Park", "Harvey Oak Supply Store", "The Willows Market and Pharmacy"]. | ||
For performing the "eating dinner" Action, Jane Anderson should go to the following Sector: | ||
{"area": "Hobbs Cafe"} | ||
--- | ||
{{$name}} lives in the {{$living_sector}} Sector that has the following Arenas: {{$living_sector_arenas}}. | ||
{{$name}} is currently in the {{$current_sector}} Sector that has the following Arenas: {{$current_sector_arenas}}. | ||
All Sectors: {{$all_sectors}}. | ||
Pick the Sector for performing {{$name}}'s current activity. | ||
* Stay in the current sector if the activity can be done there. Only go out if the activity needs to take place in another place. | ||
* Must be one of the sectors from "All Sectors," verbatim. It must be a Sector, and not an Arena. | ||
* If none of those fit very well, we must still choose the one that's the closest fit. | ||
* Return the answer as a JSON object with a single key "area". The value is the chosen area name. | ||
For performing the {{$action_description}} Action, {{$name}} should go to the following Sector: | ||
""" | ||
|
||
def prepare_context(self, action_description, persona, maze): | ||
self.persona = persona | ||
world_area = maze.access_tile(persona.scratch.curr_tile)['world'] | ||
self.path_to_living_sector = persona.scratch.living_area.split(":")[:2] | ||
self.path_to_current_sector = [ | ||
world_area, | ||
maze.access_tile(persona.scratch.curr_tile)['sector'], | ||
] | ||
self.living_sector_arenas = persona.s_mem.get_array_accessible_sector_arenas( | ||
":".join(self.path_to_living_sector) | ||
) | ||
self.current_sector_arenas = persona.s_mem.get_array_accessible_sector_arenas( | ||
":".join(self.path_to_current_sector) | ||
) | ||
known_sectors = persona.s_mem.get_str_accessible_sectors(world_area).split(", ") | ||
self.all_sectors = [sector for sector in known_sectors if "'s house" not in sector or persona.scratch.last_name in sector] | ||
|
||
return { | ||
"name": persona.scratch.get_str_name(), | ||
"action_description": json.dumps(action_description), | ||
"living_sector": json.dumps(self.path_to_living_sector[1]), | ||
"living_sector_arenas": json.dumps(self.living_sector_arenas), | ||
"current_sector": json.dumps(self.path_to_current_sector[1]), | ||
"current_sector_arenas": json.dumps(self.current_sector_arenas), | ||
"all_sectors": json.dumps(self.all_sectors), | ||
} | ||
|
||
def validate_json(self, json: JSONType): | ||
if "area" not in json: | ||
return "Missing area name" | ||
if json["area"] not in self.all_sectors: | ||
if json["area"] in self.living_sector_arenas or json["area"] in self.current_sector_arenas: | ||
return "Arena name was returned instead of the Sector name" | ||
else: | ||
return f"Specified Sector doesn't exist or isn't available to {self.persona.scratch.get_str_firstname()}" | ||
|
||
def extract_json(self, json: JSONType): | ||
return json["area"] | ||
|
||
def fallback(self, action_description, persona, maze): | ||
return maze.access_tile(persona.scratch.curr_tile)['sector'] |
97 changes: 97 additions & 0 deletions
97
reverie/backend_server/persona/prompts/run_gpt_prompt_daily_plan.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from random import Random | ||
|
||
from persona.common import is_valid_time, string_to_time | ||
from persona.prompt_template.InferenceStrategySK import JSONType, OutputType, functor, InferenceStrategySK | ||
|
||
""" | ||
Basically the long term planning that spans a day. Returns a list of actions | ||
that the persona will take today. Usually comes in the following form: | ||
'wake up and complete the morning routine at 6:00 am', | ||
'eat breakfast at 7:00 am',.. | ||
Note that the actions come without a period. | ||
INPUT: | ||
persona: The Persona class instance | ||
OUTPUT: | ||
a list of daily actions in broad strokes. | ||
""" | ||
@functor | ||
class run_gpt_prompt_daily_plan(InferenceStrategySK): | ||
# semantic_function = skill["daily_planning_v6"] | ||
output_type = OutputType.JSON | ||
config = { | ||
"max_tokens": 1000, | ||
"temperature": 1, | ||
"top_p": 0.8, | ||
} | ||
prompt = """ | ||
Let's consider {{$firstname}}: | ||
{{$commonset}} | ||
We need to draft a daily plan for {{$firstname}} in broad-strokes (with the time of the day. e.g., have a lunch at 12:00 pm, watch TV from 7 to 8 pm). The plan must be formatted as a single JSON array of objects, each object containing the following fields: | ||
* start: start time with am/pm | ||
* end: end time with am/pm | ||
* activity: the activity {{$firstname}} is performing, in plain text | ||
The entries must be in the correct order and must not intersect. The plan starts with waking up at {{$wake_up_hour}} and completing the morning routine, and it ends with going to sleep. What would be other items in the {{$firstname}}'s daily plan? | ||
""" | ||
|
||
def prepare_context(self, persona, wake_up_hour): | ||
return { | ||
"commonset": persona.scratch.get_str_iss(), | ||
"date": persona.scratch.get_str_curr_date_str(), | ||
"firstname": persona.scratch.get_str_firstname(), | ||
"wake_up_hour": f"{str(wake_up_hour)}:00 am" | ||
} | ||
|
||
def validate_json(self, json: JSONType): | ||
if not isinstance(json, list): | ||
return "Invalid JSON format (expected a JSON array)" | ||
if not all(isinstance(item, dict) and 'start' in item and 'end' in item and 'activity' in item for item in json): | ||
return "Invalid JSON format (expected an array of objects with 'start', 'end' and 'activity' fields)" | ||
wake_up_time = string_to_time(json[0]["start"]) | ||
prev_time = None | ||
prev_task = None | ||
for item in json: | ||
for field in ["start", "end"]: | ||
if not is_valid_time(item[field]): | ||
return f'Invalid {field} time format: "{item[field]}". Example time format: "6:00 am".' | ||
time = string_to_time(item["start"]) | ||
# For night owls, activities may continue past midnight and resume before the "wake-up" time. | ||
# This condition allows for time entries after midnight but before the first entry's time, | ||
# accommodating a schedule that doesn't strictly follow chronological order across days. | ||
is_past_midnight = time < wake_up_time and prev_time > wake_up_time | ||
if prev_time and time < prev_time and not is_past_midnight: | ||
raise ValueError(f'Tasks are not in chronological order. "{prev_task}" intersects with "{item["activity"]}"') | ||
prev_time = string_to_time(item["end"]) | ||
prev_task = item["activity"] | ||
|
||
def extract_json(self, json: JSONType): | ||
rng = Random(str(json)) | ||
activities = ["Relax", "Rest", "Chill", "Procrastinate"] | ||
result = [] | ||
for i, item in enumerate(json): | ||
if i != 0: | ||
start = item['start'] | ||
prev_end = json[i-1]['end'] | ||
if string_to_time(start) != string_to_time(prev_end): | ||
random_activity = rng.choice(activities) | ||
result.append(f"{prev_end} - {random_activity}") | ||
result.append(f"{item['start']} - {item['activity']}") | ||
return result | ||
# return [line for line in output.split('\n') if line.strip() and line[0].isdigit()] | ||
|
||
def fallback(self, persona, wake_up_hour): | ||
return [ | ||
'6:00 am - wake up and complete the morning routine', | ||
'7:00 am - eat breakfast', | ||
'8:00 am - read a book', | ||
'12:00 pm - have lunch', | ||
'1:00 pm - take a nap', | ||
'4:00 pm - relax', | ||
'7:00 pm - watch TV', | ||
'8:00 pm - relax', | ||
'11:00 pm - go to bed', | ||
] |
Oops, something went wrong.