Skip to content

Commit

Permalink
Merge pull request #1 from RoboCoachTechnologies/feature/ros2
Browse files Browse the repository at this point in the history
Feature/ros2
  • Loading branch information
RoboCoachian authored Sep 26, 2023
2 parents c9b530e + b21ab74 commit cac8b46
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 81 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ cython_debug/
.idea/

# Generated code
workspace/
catkin_ws/
ros_ws/

# Generated figures
*.gv
*.pdf

# Docker scripts
docker_run
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "roscribe"
version = "0.0.2"
version = "0.0.3"
description = "Translate natural language into robot software."
readme = "README.md"
authors = [{ name = "RoboCoach Technologies", email = "[email protected]" }]
Expand Down
188 changes: 159 additions & 29 deletions roscribe/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,86 @@
import roscribe.ui as ui


def catkin_ws_generator(project_name):
if not os.path.exists('catkin_ws'):
os.mkdir('catkin_ws')
ROS_WS_NAME = 'ros_ws'

if not os.path.exists('catkin_ws/src'):
os.mkdir('catkin_ws/src')
SETUP_PY_TEMPLATE = """
setup.py
```python
from setuptools import setup
os.mkdir(f'catkin_ws/src/{project_name}')
os.mkdir(f'catkin_ws/src/{project_name}/src')
os.mkdir(f'catkin_ws/src/{project_name}/launch')
package_name = '{package_name}'
setup(
name=package_name,
version='0.0.1',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='TODO',
maintainer_email='TODO',
description='TODO: Package description',
license='TODO: License declaration',
tests_require=['pytest'],
entry_points={console_scripts},
)
```
"""

def code_generator(task, node_topic_list, curr_node, summary, project_name, llm, verbose=False):
gen_code_prompt = get_gen_code_prompt()

SETUP_CFG_TEMPLATE = """
setup.cfg
```cfg
[develop]
script_dir=$base/lib/{package_name}
[install]
install_scripts=$base/lib/{package_name}
```
"""


def make_setup_py(node_topic_dict, package_name):
console_scripts = "'console_scripts': ["
for node in node_topic_dict.keys():
console_scripts += f"'{node} = {package_name}.{node}:main', "

console_scripts = console_scripts[:-2] + "]"

setup_py = SETUP_PY_TEMPLATE.format(package_name=package_name, console_scripts=console_scripts)
return setup_py


def make_setup_cfg(package_name):
setup_cfg = SETUP_CFG_TEMPLATE.format(package_name=package_name)
return setup_cfg


def ros_ws_generator(project_name, ros_version):
if not os.path.exists(ROS_WS_NAME):
os.mkdir(ROS_WS_NAME)

if not os.path.exists(f'{ROS_WS_NAME}/src'):
os.mkdir(f'{ROS_WS_NAME}/src')

os.mkdir(f'{ROS_WS_NAME}/src/{project_name}')
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/launch')

if ros_version == 'ros1':
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/src')
elif ros_version == 'ros2':
os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/{project_name}')
open(f'{ROS_WS_NAME}/src/{project_name}/{project_name}/__init__.py', 'x')

os.mkdir(f'{ROS_WS_NAME}/src/{project_name}/resource')
open(f'{ROS_WS_NAME}/src/{project_name}/resource/{project_name}', 'x')


def code_generator(task, node_topic_list, curr_node, summary, project_name, ros_version, llm, verbose=False):
gen_code_prompt = get_gen_code_prompt(ros_version)

gen_code_chain = LLMChain(
llm=llm,
Expand All @@ -33,13 +99,13 @@ def code_generator(task, node_topic_list, curr_node, summary, project_name, llm,
gen_code_output = gen_code_chain.predict(task=task, node_topic_list=node_topic_list,
curr_node=curr_node, summary=summary)

to_files(gen_code_output, project_name, 'src')
to_files(gen_code_output, project_name, 'impl', ros_version)

print(ui.GEN_NODE_CODE_MSG.format(node=curr_node))


def launch_generator(task, node_topic_list, project_name, llm, verbose=False):
gen_launch_prompt = get_gen_launch_prompt()
def launch_generator(task, node_topic_list, project_name, ros_version, llm, verbose=False):
gen_launch_prompt = get_gen_launch_prompt(ros_version)

gen_launch_chain = LLMChain(
llm=llm,
Expand All @@ -54,20 +120,28 @@ def launch_generator(task, node_topic_list, project_name, llm, verbose=False):
print(ui.GEN_LAUNCH_MSG)


def install_generator(task, node_topic_list, project_name, llm, verbose=False):
gen_cmake_prompt = get_gen_cmake_prompt()
def install_generator(task, node_topic_dict, node_topic_list, project_name, ros_version, llm, verbose=False):
if ros_version == 'ros1':
gen_cmake_prompt = get_gen_cmake_prompt()

gen_cmake_chain = LLMChain(
llm=llm,
prompt=gen_cmake_prompt,
verbose=verbose
)
gen_cmake_chain = LLMChain(
llm=llm,
prompt=gen_cmake_prompt,
verbose=verbose
)

gen_cmake_output = gen_cmake_chain.predict(task=task, node_topic_list=node_topic_list, project_name=project_name)
gen_cmake_output = gen_cmake_chain.predict(task=task, node_topic_list=node_topic_list,
project_name=project_name)
to_files(gen_cmake_output, project_name, 'install')

to_files(gen_cmake_output, project_name, 'install')
elif ros_version == 'ros2':
setup_py = make_setup_py(node_topic_dict, project_name)
to_files(setup_py, project_name, 'install')

gen_package_prompt = get_gen_package_prompt()
setup_cfg = make_setup_cfg(project_name)
to_files(setup_cfg, project_name, 'install')

gen_package_prompt = get_gen_package_prompt(ros_version)

gen_package_chain = LLMChain(
llm=llm,
Expand All @@ -83,7 +157,7 @@ def install_generator(task, node_topic_list, project_name, llm, verbose=False):
print(ui.GEN_INSTALL_MSG)


def to_files(chat, project_name, mode):
def to_files(chat, project_name, mode, ros_version='ros1'):
workspace = dict()

files = get_code_from_chat(chat)
Expand All @@ -96,14 +170,70 @@ def to_files(chat, project_name, mode):
for filename in workspace.keys():
code = workspace[filename]

if mode == 'src':
with open(f'catkin_ws/src/{project_name}/src/{filename}', 'w') as file:
file.write(code)
if mode == 'impl':
if ros_version == 'ros1':
with open(f'{ROS_WS_NAME}/src/{project_name}/src/{filename}', 'w') as file:
file.write(code)
elif ros_version == 'ros2':
with open(f'{ROS_WS_NAME}/src/{project_name}/{project_name}/{filename}', 'w') as file:
file.write(code)

elif mode == 'launch':
with open(f'catkin_ws/src/{project_name}/launch/{filename}', 'w') as file:
with open(f'{ROS_WS_NAME}/src/{project_name}/launch/{filename}', 'w') as file:
file.write(code)

elif mode == 'install':
with open(f'catkin_ws/src/{project_name}/{filename}', 'w') as file:
with open(f'{ROS_WS_NAME}/src/{project_name}/{filename}', 'w') as file:
file.write(code)
else:
print('Invalid file storage mode!')


def test_setup_py():
package_name = 'my_package'
node_topic_dict = {'node_A': {'description': 'node_A description',
'published_topics': [('topic_1', 'topic_1_msg_type'),
('topic_2', 'topic_2_msg_type')],
'subscribed_topics': [('topic_3', 'topic_3_msg_type')]},
'node_B': {'description': 'node_B description',
'published_topics': [('topic_3', 'topic_3_msg_type')],
'subscribed_topics': []},
'node_C': {'description': 'node_C description',
'published_topics': [('topic_2', 'topic_2_msg_type')],
'subscribed_topics': [('topic_1', 'topic_1_msg_type')]},
}
print(make_setup_py(node_topic_dict, package_name))


def test_setup_cfg():
package_name = 'my_package'
print(make_setup_cfg(package_name))


def test_dump_setup():
package_name = 'my_package'
node_topic_dict = {'node_A': {'description': 'node_A description',
'published_topics': [('topic_1', 'topic_1_msg_type'),
('topic_2', 'topic_2_msg_type')],
'subscribed_topics': [('topic_3', 'topic_3_msg_type')]},
'node_B': {'description': 'node_B description',
'published_topics': [('topic_3', 'topic_3_msg_type')],
'subscribed_topics': []},
'node_C': {'description': 'node_C description',
'published_topics': [('topic_2', 'topic_2_msg_type')],
'subscribed_topics': [('topic_1', 'topic_1_msg_type')]}
}

ros_ws_generator(package_name, 'ros2')

setup_py = make_setup_py(node_topic_dict, package_name)
to_files(setup_py, package_name, 'install')

setup_cfg = make_setup_cfg(package_name)
to_files(setup_cfg, package_name, 'install')


if __name__ == '__main__':
test_setup_py()
test_setup_cfg()
test_dump_setup()
26 changes: 16 additions & 10 deletions roscribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from roscribe.prompt import get_project_name_prompt, get_task_spec_prompt, get_task_spec_summarize_prompt,\
get_gen_node_prompt, get_gen_topic_prompt, get_node_qa_prompt, get_node_qa_sum_prompt
from roscribe.parser import make_node_list, make_node_topic_dict, make_node_topic_list_str, modify_node_dict
from roscribe.generator import catkin_ws_generator, code_generator, launch_generator, install_generator
from roscribe.generator import ros_ws_generator, code_generator, launch_generator, install_generator
from roscribe.visualization import show_node_graph

import roscribe.ui as ui
Expand All @@ -17,6 +17,10 @@ def main(verbose=False):
print(ui.WELCOME_MSG)

task_message = input("Your Robot Software: ") # User-specified task
ros_version = input("ROS1 or ROS2? ").replace(" ", "").lower() # User-specified ROS version
while ros_version not in ['ros1', 'ros2']:
print(ui.VALID_ROS_VER)
ros_version = input("ROS1 or ROS2? ").replace(" ", "").lower()

project_name_prompt = get_project_name_prompt()
project_name_chain = LLMChain(
Expand All @@ -25,9 +29,9 @@ def main(verbose=False):
verbose=verbose)
project_name = project_name_chain.predict(task=task_message)

catkin_ws_generator(project_name)
ros_ws_generator(project_name, ros_version)

task_spec_prompt, task_spec_end_str = get_task_spec_prompt(task_message)
task_spec_prompt, task_spec_end_str = get_task_spec_prompt(task_message, ros_version)
task_spec_memory = ConversationBufferMemory()
task_spec_chain = ConversationChain(
llm=llm,
Expand Down Expand Up @@ -57,7 +61,7 @@ def main(verbose=False):
task_spec_memory.return_messages = True
task_spec_sum_output = task_spec_summary_chain.predict(input=task_spec_memory.load_memory_variables({}))

node_gen_prompt, node_gen_parser = get_gen_node_prompt()
node_gen_prompt, node_gen_parser = get_gen_node_prompt(ros_version)
node_gen_chain = LLMChain(
llm=llm,
prompt=node_gen_prompt,
Expand All @@ -67,7 +71,7 @@ def main(verbose=False):
node_gen_list = node_gen_parser.parse(node_gen_output).ros_nodes
node_list_str = make_node_list(node_gen_list)

topic_gen_prompt, topic_gen_parser = get_gen_topic_prompt()
topic_gen_prompt, topic_gen_parser = get_gen_topic_prompt(ros_version)
topic_gen_chain = LLMChain(
llm=llm,
prompt=topic_gen_prompt,
Expand Down Expand Up @@ -134,8 +138,10 @@ def main(verbose=False):
print(ui.QA_MSG_INIT)

for node in node_topic_dict.keys():
node_spec_prompt, node_spec_end_str = get_node_qa_prompt(
task=task_message, node_topic_list=node_topic_list_str, curr_node=node)
node_spec_prompt, node_spec_end_str = get_node_qa_prompt(task=task_message,
node_topic_list=node_topic_list_str,
curr_node=node,
ros_version=ros_version)
node_spec_memory = ConversationBufferMemory()
node_spec_chain = ConversationChain(
llm=llm,
Expand Down Expand Up @@ -165,13 +171,13 @@ def main(verbose=False):
node_spec_memory.return_messages = True
sum_output = summary_chain.predict(input=node_spec_memory.load_memory_variables({}))

code_generator(task_message, node_topic_list_str, node, sum_output, project_name, llm, verbose)
code_generator(task_message, node_topic_list_str, node, sum_output, project_name, ros_version, llm, verbose)

print(ui.LAUNCH_INSTALL_MSG)

launch_generator(task_message, node_topic_list_str, project_name, llm)
launch_generator(task_message, node_topic_list_str, project_name, ros_version, llm, verbose)

install_generator(task_message, node_topic_list_str, project_name, llm)
install_generator(task_message, node_topic_dict, node_topic_list_str, project_name, ros_version, llm, verbose)

print(ui.FAREWELL_MSG)

Expand Down
6 changes: 3 additions & 3 deletions roscribe/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from langchain.output_parsers import PydanticOutputParser


MOD_INPUT_SCHEMA = Schema((str, str, [(str, str)], [(str, str)]))


# Data structure for the output of ROS nodes
class NodeList(BaseModel):
ros_nodes: Dict[str, str] = Field(description="dictionary containing ROS node names as keys and ROS node descriptions as values")
Expand Down Expand Up @@ -79,9 +82,6 @@ def make_node_topic_list_str(node_topic_dict):
return node_topic_list_str


MOD_INPUT_SCHEMA = Schema((str, str, [(str, str)], [(str, str)]))


def modify_node_dict(mod_input, node_topic_dict):
try:
mod_tuple = ast.literal_eval(mod_input)
Expand Down
Loading

0 comments on commit cac8b46

Please sign in to comment.