diff --git a/.gitignore b/.gitignore index 68bc17f..bbd80e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,22 @@ +train/logs/* +train/datasets/* +train/vint_train/data/data_splits/* +train/wandb/* +train/gnm_dataset/* + +*.png +*.jpg +*.pth +*.mp4 +*.gif + +deployment/model_weights/* +deployment/topomaps/* + +.vscode/* +*/.vscode/* + + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -20,6 +39,7 @@ parts/ sdist/ var/ wheels/ +pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg @@ -49,7 +69,6 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ -cover/ # Translations *.mo @@ -72,7 +91,6 @@ instance/ docs/_build/ # PyBuilder -.pybuilder/ target/ # Jupyter Notebook @@ -83,9 +101,7 @@ profile_default/ ipython_config.py # pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version +.python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -94,22 +110,7 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +# PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff @@ -145,16 +146,3 @@ dmypy.json # Pyre type checker .pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/LICENSE b/LICENSE index e204b0b..53b69c2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Dhruv Shah +Copyright (c) 2022 Dhruv Shah, Ajay Sridhar, Noriaki Hirose, Sergey Levine Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 1401a5c..4d1d15e 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,252 @@ _Berkeley AI Research_ - -[Project Page](https://visualnav-transformer.github.io) | [arXiV](https://arxiv.org/abs/2306.xxxx) | [Summary Video](https://www.youtube.com/watch?v=xxxx) +[Project Page](https://visualnav-transformer.github.io) | [arXiV](https://arxiv.org/abs/2306.14846) | [Summary Video](https://www.youtube.com/watch?v=6kNex5dJ5sQ) --- +## Overview +This repository contains code for training a ViNT with your own data, pre-trained model checkpoints, as well as example code to deploy it on a TurtleBot2/LoCoBot robot. 
The repository follows the organization from [GNM](https://github.com/PrieureDeSion/drive-any-robot). + +- `./train/train.py`: training script to train or fine-tune the ViNT model on your custom data. +- `./train/vint_train/models/`: contains model files for GNM, ViNT, and some baselines. +- `./train/process_*.py`: scripts to process rosbags or other formats of robot trajectories into training data. +- `./deployment/src/record_bag.sh`: script to collect a demo trajectory in the target environment on the robot. This trajectory is subsampled to generate a topological graph of the environment. +- `./deployment/src/navigate.sh`: script that deploys a trained ViNT model on the robot to navigate to a desired goal in the generated topological graph. Please see relevant sections below for configuration settings. + +## Train + +This subfolder contains code for processing datasets and training a ViNT from your own data. + +### Pre-requisites + +The codebase assumes access to a workstation running Ubuntu (tested on 18.04 and 20.04), Python 3.7+, and a GPU with CUDA 10+. It also assumes access to conda, but you can modify it to work with other virtual environment packages, or a native setup. +### Setup +Run the commands below inside the `vint_release/` (topmost) directory: +1. Set up the conda environment: + ```bash + conda env create -f train/train_environment.yml + ``` +2. Source the conda environment: + ``` + conda activate vint_train + ``` +3. Install the vint_train packages: + ```bash + pip install -e train/ + ``` + +### Data-Wrangling +In the [ViNT paper](https://sites.google.com/view/drive-any-robot), we train on a combination of publicly available and unreleased datasets. Below is a list of publicly available datasets used for training; please contact the respective authors for access to the unreleased data. +- [RECON](https://sites.google.com/view/recon-robot/dataset) +- [TartanDrive](https://github.com/castacks/tartan_drive) +- [SCAND](https://www.cs.utexas.edu/~xiao/SCAND/SCAND.html#Links) +- [GoStanford2 (Modified)](https://drive.google.com/drive/folders/1xrNvMl5q92oWed99noOt_UhqQnceJYV0?usp=share_link) +- [SACSoN/HuRoN](https://sites.google.com/view/sacson-review/huron-dataset) + +We recommend you to download these (and any other datasets you may want to train on) and run the processing steps below. + +#### Data Processing + +We provide some sample scripts to process these datasets, either directly from a rosbag or from a custom format like HDF5s: +1. Run `process_bags.py` with the relevant args, or `process_recon.py` for processing RECON HDF5s. You can also manually add your own dataset by following our structure below (if you are adding a custom dataset, please checkout the [Custom Datasets](#custom-datasets) section). +2. Run `data_split.py` on your dataset folder with the relevant args. + +After step 1 of data processing, the processed dataset should have the following structure: + +``` +├── +│ ├── +│ │ ├── 0.jpg +│ │ ├── 1.jpg +│ │ ├── ... +│ │ ├── T_1.jpg +│ │ └── traj_data.pkl +│ ├── +│ │ ├── 0.jpg +│ │ ├── 1.jpg +│ │ ├── ... +│ │ ├── T_2.jpg +│ │ └── traj_data.pkl +│ ... +└── └── + ├── 0.jpg + ├── 1.jpg + ├── ... + ├── T_N.jpg + └── traj_data.pkl +``` + +Each `*.jpg` file contains an forward-facing RGB observation from the robot, and they are temporally labeled. The `traj_data.pkl` file is the odometry data for the trajectory. It’s a pickled dictionary with the keys: +- `"position"`: An np.ndarray [T, 2] of the xy-coordinates of the robot at each image observation. 
+- `"yaw"`: An np.ndarray [T,] of the yaws of the robot at each image observation. + + +After step 2 of data processing, the processed data-split should the following structure inside `vint_release/train/vint_train/data/data_splits/`: + +``` +├── +│ ├── train +| | └── traj_names.txt +└── └── test + └── traj_names.txt +``` + +### Training your ViNT +Run this inside the `vint_release/train` directory: +```bash +python train.py -c +``` +The premade config yaml files are in the `train/config` directory. + +#### Custom Config Files +You can use one of the premade yaml files as a starting point and change the values as you need. `config/vint.yaml` is good choice since it has commented arguments. `config/defaults.yaml` contains the default config values (don't directly train with this config file since it does not specify any datasets for training). + +#### Custom Datasets +Make sure your dataset and data-split directory follows the structures provided in the [Data Processing](#data-processing) section. Locate `train/vint_train/data/data_config.yaml` and append the following: + +``` +: + metric_waypoints_distance: +``` + +Locate your training config file and add the following text under the `datasets` argument (feel free to change the values of `end_slack`, `goals_per_obs`, and `negative_mining`): +``` +: + data_folder: + train: data/data_splits//train/ + test: data/data_splits//test/ + end_slack: 0 # how many timesteps to cut off from the end of each trajectory (in case many trajectories end in collisions) + goals_per_obs: 1 # how many goals are sampled per observation + negative_mining: True # negative mining from the ViNG paper (Shah et al.) +``` + +#### Training your ViNT from a checkpoint +Instead of training from scratch, you can also load an existing checkpoint from the published results. +Add `load_run: /`to your .yaml config file in `vint_release/train/config/`. The `*.pth` of the file you are loading to be saved in this file structure and renamed to “latest”: `vint_release/train/logs///latest.pth`. This makes it easy to train from the checkpoint of a previous run since logs are saved this way by default. Note: if you are loading a checkpoint from a previous run, check for the name the run in the `vint_release/train/logs//`, since the code appends a string of the date to each run_name specified in the config yaml file of the run to avoid duplicate run names. + + +If you want to use our checkpoints, you can download the `*.pth` files from [this link](https://drive.google.com/drive/folders/1a9yWR2iooXFAqjQHetz263--4_2FFggg?usp=sharing). + + +## Deployment +This subfolder contains code to load a pre-trained ViNT and deploy it on the open-source [LoCoBot indoor robot platform](http://www.locobot.org/) with a [NVIDIA Jetson Orin Nano](https://www.amazon.com/NVIDIA-Jetson-Orin-Nano-Developer/dp/B0BZJTQ5YP/ref=asc_df_B0BZJTQ5YP/?tag=hyprod-20&linkCode=df0&hvadid=652427572954&hvpos=&hvnetw=g&hvrand=12520404772764575478&hvpone=&hvptwo=&hvqmt=&hvdev=c&hvdvcmdl=&hvlocint=&hvlocphy=1013585&hvtargid=pla-2112361227514&psc=1&gclid=CjwKCAjw4P6oBhBsEiwAKYVkq7dqJEwEPz0K-H33oN7MzjO0hnGcAJDkx2RdT43XZHdSWLWHKDrODhoCmnoQAvD_BwE). It can be easily adapted to be run on alternate robots, and researchers have been able to independently deploy it on the following robots – Clearpath Jackal, DJI Tello, Unitree A1, TurtleBot2, Vizbot – and in simulated environments like CARLA. + +### LoCoBot Setup + +This software was tested on a LoCoBot running Ubuntu 20.04. + + +#### Software Installation (in this order) +1. 
ROS: [ros-noetic](https://wiki.ros.org/noetic/Installation/Ubuntu)
+2. ROS packages:
+    ```bash
+    sudo apt-get install ros-noetic-usb-cam ros-noetic-joy
+    ```
+3. [kobuki](http://wiki.ros.org/kobuki/Tutorials/Installation)
+4. Conda
+    - Install anaconda/miniconda/etc. for managing environments
+    - Make the conda env with `deployment_environment.yml` (run this inside the `vint_release/` directory):
+    ```bash
+    conda env create -f deployment/deployment_environment.yml
+    ```
+    - Source the environment:
+    ```bash
+    conda activate vint_deployment
+    ```
+    - (Recommended) add it to `~/.bashrc`:
+    ```bash
+    echo "conda activate vint_deployment" >> ~/.bashrc
+    ```
+5. Install the `vint_train` packages (run this inside the `vint_release/` directory):
+    ```bash
+    pip install -e train/
+    ```
+6. (Recommended) Install [tmux](https://github.com/tmux/tmux/wiki/Installing) if not present.
+   Many of the bash scripts rely on tmux to launch multiple screens with different commands. This is useful for debugging because you can see the output of each screen.
+
+#### Hardware Requirements
+- LoCoBot: http://locobot.org (just the navigation stack)
+- A wide-angle RGB camera: [Example](https://www.amazon.com/ELP-170degree-Fisheye-640x480-Resolution/dp/B00VTHD17W). The `vint_locobot.launch` file uses camera parameters that work with cameras like the ELP fisheye wide angle; feel free to modify them for your own camera. Adjust the camera parameters in `vint_release/deployment/config/camera.yaml` to match your camera (used for visualization).
+- [Joystick](https://www.amazon.com/Logitech-Wireless-Nano-Receiver-Controller-Vibration/dp/B0041RR0TW)/[keyboard teleop](http://wiki.ros.org/teleop_twist_keyboard) that works with Linux. Add the index mapping for the _deadman_switch_ on the joystick to `vint_release/deployment/config/joystick.yaml`. You can find the mapping from buttons to indices for common joysticks in the [wiki](https://wiki.ros.org/joy).
+
+
+### Loading the model weights
+
+Save the model weights `*.pth` file in the `vint_release/deployment/model_weights` folder. Our models' weights are available at [this link](https://drive.google.com/drive/folders/1a9yWR2iooXFAqjQHetz263--4_2FFggg?usp=sharing).
+
+### Collecting a Topological Map
+
+_Make sure to run these scripts inside the `vint_release/deployment/src/` directory._
+
+
+This section discusses a simple way to create a topological map of the target environment for deployment. For simplicity, we will use the robot in "path-following" mode, i.e., given a single trajectory in an environment, the task is to follow the same trajectory to the goal. The environment may have new/dynamic obstacles, lighting variations, etc.
+
+#### Record the rosbag:
+```bash
+./record_bag.sh <bag_name>
+```
+
+Run this command to teleoperate the robot with the joystick and camera. This command opens up three windows:
+1. `roslaunch vint_locobot.launch`: This launch file opens the `usb_cam` node for the camera, the joy node for the joystick, and nodes for the robot's mobile base.
+2. `python joy_teleop.py`: This python script starts a node that reads inputs from the joy topic and outputs them on topics that teleoperate the robot's base.
+3. `rosbag record /usb_cam/image_raw -o <bag_name>`: This command is not run immediately (you have to hit Enter to start it). It will be run in the `vint_release/deployment/topomaps/bags` directory, where we recommend you store your rosbags.
+
+Once you are ready to record the bag, run the `rosbag record` command and teleoperate the robot along the route you want it to follow.
When you are finished recording the path, kill the `rosbag record` command, and then kill the tmux session.
+
+#### Make the topological map:
+```bash
+./create_topomap.sh <topomap_dir> <bag_filename>
+```
+
+This command opens up 3 windows:
+1. `roscore`
+2. `python create_topomap.py --dt 1 --dir <topomap_dir>`: This command creates a directory in `vint_release/deployment/topomaps/images` and saves an image as a node in the map every second the bag is played.
+3. `rosbag play -r 1.5 <bag_filename>`: This command plays the rosbag at 1.5x speed, so the python script is actually recording nodes 1.5 seconds apart (in bag time). The `<bag_filename>` should be the entire bag name with the .bag extension. You can change this playback rate in the `create_topomap.sh` file. The command does not run until you hit Enter, which you should only do once the python script gives its waiting message. Once you play the bag, move to the screen where the python script is running so you can kill it when the rosbag stops playing.
+
+When the bag stops playing, kill the tmux session.
+
+
+### Running the model
+_Make sure to run this script inside the `vint_release/deployment/src/` directory._
+
+```bash
+./navigate.sh "--model <model_name> --dir <topomap_dir>"
+```
+
+To deploy one of the models from the published results, we are releasing model checkpoints that you can download from [this link](TODO).
+
+
+The `<model_name>` is the name of the model in the `vint_release/deployment/config/models.yaml` file. In this file, you specify the following parameters for each model:
+- `config_path` (str): path of the *.yaml file in `vint_release/train/config/` used to train the model
+- `ckpt_path` (str): path of the *.pth file in `vint_release/deployment/model_weights/`
+
+
+Make sure these configurations match what you used to train the model. The configurations for the models whose weights we provide are already included in this yaml file for your reference.
+
+The `<topomap_dir>` is the name of the directory in `vint_release/deployment/topomaps/images` that has the images corresponding to the nodes in the topological map. The images are ordered by name from 0 to N.
+
+This command opens up 4 windows:
+
+1. `roslaunch vint_locobot.launch`: This launch file opens the usb_cam node for the camera, the joy node for the joystick, and several nodes for the robot's mobile base.
+2. `python navigate.py --model <model_name> --dir <topomap_dir>`: This python script starts a node that reads in image observations from the `/usb_cam/image_raw` topic, inputs the observations and the map into the model, and publishes actions to the `/waypoint` topic.
+3. `python joy_teleop.py`: This python script starts a node that reads inputs from the joy topic and outputs them on topics that teleoperate the robot's base.
+4. `python pd_controller.py`: This python script starts a node that reads messages from the `/waypoint` topic (waypoints from the model) and outputs velocities to navigate the robot's base.
+
+When the robot has finished navigating, kill the `pd_controller.py` script, and then kill the tmux session. If you want to take control of the robot while it is navigating, the `joy_teleop.py` script allows you to do so with the joystick.
+
+### Adapting this code to different robots
+
+We hope that this codebase is general enough to allow you to deploy it to your favorite ROS-based robots. You can change the robot configuration parameters in `vint_release/deployment/config/robot.yaml`, like the max angular and linear velocities of the robot and the topics to publish to teleop and control the robot. Please feel free to create a GitHub issue or reach out to the authors at shah@cs.berkeley.edu.
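+
+### Offline sanity check (optional)
+
+If you want to make sure a downloaded checkpoint and a saved topological map work together before launching ROS, the short sketch below mirrors the model-loading and inference logic of `navigate.py`. It is only an illustrative script: the filename `offline_check.py`, the `vint` model key, and the `topomap` directory name are assumptions, not part of this release. Run it from `vint_release/deployment/src/` with the deployment conda environment active and the `vint_train` package installed.
+
+```python
+# offline_check.py (hypothetical): smoke-test a checkpoint against a saved topomap without ROS.
+import os
+import yaml
+import torch
+from PIL import Image as PILImage
+
+from utils import load_model, transform_images, to_numpy
+
+MODEL_NAME = "vint"                         # key in ../config/models.yaml (assumed)
+TOPOMAP_DIR = "../topomaps/images/topomap"  # directory created by create_topomap.sh (assumed)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Resolve the training config and checkpoint for the chosen model, as navigate.py does.
+with open("../config/models.yaml", "r") as f:
+    model_paths = yaml.safe_load(f)
+with open(model_paths[MODEL_NAME]["config_path"], "r") as f:
+    model_params = yaml.safe_load(f)
+model = load_model(model_paths[MODEL_NAME]["ckpt_path"], model_params, device)
+model.eval()
+
+# Fake an observation context by repeating the first topomap image context_size + 1 times,
+# and use the last topomap image as the goal.
+filenames = sorted(os.listdir(TOPOMAP_DIR), key=lambda x: int(x.split(".")[0]))
+obs_imgs = [PILImage.open(os.path.join(TOPOMAP_DIR, filenames[0]))] * (model_params["context_size"] + 1)
+goal_img = PILImage.open(os.path.join(TOPOMAP_DIR, filenames[-1]))
+
+obs_tensor = transform_images(obs_imgs, model_params["image_size"]).to(device)
+goal_tensor = transform_images(goal_img, model_params["image_size"]).to(device)
+
+with torch.no_grad():
+    dists, waypoints = model(obs_tensor, goal_tensor)
+print("predicted temporal distance to goal:", to_numpy(dists))
+print("predicted waypoints shape:", to_numpy(waypoints).shape)
+```
+
+If this prints a distance and a waypoint tensor without errors, the checkpoint, its config, and the topomap are consistent, and `navigate.sh` should be able to use them.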
-Code and model checkpoints coming soon! +## Citing +``` +@article{shah2023vint, + author = {Dhruv Shah and Ajay Sridhar and Nitish Dashora + and Kyle Stachowicz and Kevin Black and Noriaki Hirose and Sergey Levine}, + title = {{ViNT: A Foundation Model for Visual Navigation}}, + journal = {arXiv pre-print}, + year = {2023}, + url = {https://arxiv.org/abs/2306.14846}, +} +``` diff --git a/deployment/config/camera_front.yaml b/deployment/config/camera_front.yaml new file mode 100644 index 0000000..174693b --- /dev/null +++ b/deployment/config/camera_front.yaml @@ -0,0 +1,8 @@ +# camera parameters for src/gnm_locobot.launch +video_device: "/dev/video0" # change this to your video device path +image_width: 160 #640 +image_height: 120 #480 +pixel_format: yuyv +camera_frame_id: "usb_cam" +io_method: "mmap" +framerate: 9 \ No newline at end of file diff --git a/deployment/config/camera_reverse.yaml b/deployment/config/camera_reverse.yaml new file mode 100644 index 0000000..929967d --- /dev/null +++ b/deployment/config/camera_reverse.yaml @@ -0,0 +1,8 @@ +# camera parameters for src/gnm_locobot.launch +video_device: "/dev/video2" # change this to your video device path +image_width: 160 +image_height: 120 +pixel_format: yuyv +camera_frame_id: "usb_cam" +io_method: "mmap" +framerate: 9 \ No newline at end of file diff --git a/deployment/config/cmd_vel_mux.yaml b/deployment/config/cmd_vel_mux.yaml new file mode 100644 index 0000000..6852ac7 --- /dev/null +++ b/deployment/config/cmd_vel_mux.yaml @@ -0,0 +1,38 @@ +# # Velocity command sources +# vel_sources: +# # Teleoperation commands +# teleop: +# topic: /cmd_vel_mux/input/teleop +# timeout: 0.6 +# priority: 100 + +# # Move base commands +# move_base: +# topic: /cmd_vel_mux/input/navi +# timeout: 0.6 +# priority: 90 + +# # Mux parameters +# yaml_cfg_file: "" +# allow_unsafe_topics: false +# cmd_vel_timeout: 0.25 +# publish_topic: /mobile_base/commands/velocity + + +subscribers: + - name: "gnm vels" + topic: "/cmd_vel_mux/input/navi" + timeout: 0.1 + priority: 0 + short_desc: "The default cmd_vel, controllers unaware that we are multiplexing cmd_vel should come here" + - name: "teleop" + topic: "/cmd_vel_mux/input/teleop" + timeout: 0.5 + priority: 2 + short_desc: "Navigation stack controller" + - name: "gnm recovery" + topic: "/cmd_vel_mux/input/recovery" + timeout: 0.1 + priority: 1 +publisher: "/mobile_base/commands/velocity" + \ No newline at end of file diff --git a/deployment/config/joystick.yaml b/deployment/config/joystick.yaml new file mode 100644 index 0000000..6c420e7 --- /dev/null +++ b/deployment/config/joystick.yaml @@ -0,0 +1,7 @@ +# joystick parameters for src/gnm_locobot.launch +dev: "/dev/input/js0" # change this to your joystick device path + +# joystick parameters for src/joy_teleop.py +deadman_switch: 5 # button index +lin_vel_button: 4 +ang_vel_button: 0 diff --git a/deployment/config/models.yaml b/deployment/config/models.yaml new file mode 100644 index 0000000..cb22033 --- /dev/null +++ b/deployment/config/models.yaml @@ -0,0 +1,18 @@ +vint: + config_path: "../../train/config/vint.yaml" + ckpt_path: "../model_weights/vint5c_29.pth" + +# gnmv9: +# config_path: "../../train/config/gnmv9.yaml" +# ckpt_path: "../model_weights/gnmv9_29.pth" + +late_fusion: + config_path: "../../train/config/late_fusion.yaml" + ckpt_path: "../model_weights/vint_late_fusion.pth" + +gnm: + config_path: "../../train/config/gnm.yaml" + ckpt_path: "../model_weights/gnm_large.pth" + + +# add your own model configs here after saving the *.pth file to 
../model_weight \ No newline at end of file diff --git a/deployment/config/robot.yaml b/deployment/config/robot.yaml new file mode 100644 index 0000000..ea538c1 --- /dev/null +++ b/deployment/config/robot.yaml @@ -0,0 +1,13 @@ +# linear and angular speed limits for the robot +max_v: 0.2 #0.4 # m/s +max_w: 0.4 #0.8 # rad/s +# observation rate fo the robot +frame_rate: 4 # Hz +graph_rate: 0.3333 # Hz + +# topic names (modify for different robots/nodes) +vel_teleop_topic: /cmd_vel_mux/input/teleop +vel_navi_topic: /cmd_vel_mux/input/navi +vel_recovery_topic: /cmd_vel_mux/input/recovery + + diff --git a/deployment/deployment_environment.yml b/deployment/deployment_environment.yml new file mode 100644 index 0000000..abdf78f --- /dev/null +++ b/deployment/deployment_environment.yml @@ -0,0 +1,17 @@ +name: vint_deployment +channels: +- pytorch +- conda-forge +dependencies: +- python=3.8.5 +- cudatoolkit=11. +- torchvision +- numpy +- matplotlib +- pyyaml +- rospkg +- pip: + - torch + - torchvision + - efficientnet_pytorch + - warmup_scheduler \ No newline at end of file diff --git a/deployment/src/create_topomap.py b/deployment/src/create_topomap.py new file mode 100755 index 0000000..c10087c --- /dev/null +++ b/deployment/src/create_topomap.py @@ -0,0 +1,93 @@ +import argparse +import os +from utils import msg_to_pil +import time + +# ROS +import rospy +from sensor_msgs.msg import Image +from sensor_msgs.msg import Joy + +IMAGE_TOPIC = "/usb_cam/image_raw" +TOPOMAP_IMAGES_DIR = "../topomaps/images" +obs_img = None + + +def remove_files_in_dir(dir_path: str): + for f in os.listdir(dir_path): + file_path = os.path.join(dir_path, f) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + + +def callback_obs(msg: Image): + global obs_img + obs_img = msg_to_pil(msg) + + +def callback_joy(msg: Joy): + if msg.buttons[0]: + rospy.signal_shutdown("shutdown") + + +def main(args: argparse.Namespace): + global obs_img + rospy.init_node("CREATE_TOPOMAP", anonymous=False) + image_curr_msg = rospy.Subscriber( + IMAGE_TOPIC, Image, callback_obs, queue_size=1) + subgoals_pub = rospy.Publisher( + "/subgoals", Image, queue_size=1) + joy_sub = rospy.Subscriber("joy", Joy, callback_joy) + + topomap_name_dir = os.path.join(TOPOMAP_IMAGES_DIR, args.dir) + if not os.path.isdir(topomap_name_dir): + os.makedirs(topomap_name_dir) + else: + print(f"{topomap_name_dir} already exists. Removing previous images...") + remove_files_in_dir(topomap_name_dir) + + + assert args.dt > 0, "dt must be positive" + rate = rospy.Rate(1/args.dt) + print("Registered with master node. Waiting for images...") + i = 0 + start_time = float("inf") + while not rospy.is_shutdown(): + if obs_img is not None: + obs_img.save(os.path.join(topomap_name_dir, f"{i}.png")) + print("published image", i) + i += 1 + rate.sleep() + start_time = time.time() + obs_img = None + if time.time() - start_time > 2 * args.dt: + print(f"Topic {IMAGE_TOPIC} not publishing anymore. 
Shutting down...") + rospy.signal_shutdown("shutdown") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=f"Code to generate topomaps from the {IMAGE_TOPIC} topic" + ) + parser.add_argument( + "--dir", + "-d", + default="topomap", + type=str, + help="path to topological map images in ../topomaps/images directory (default: topomap)", + ) + parser.add_argument( + "--dt", + "-t", + default=1., + type=float, + help=f"time between images sampled from the {IMAGE_TOPIC} topic (default: 3.0)", + ) + args = parser.parse_args() + + main(args) diff --git a/deployment/src/create_topomap.sh b/deployment/src/create_topomap.sh new file mode 100755 index 0000000..c0b848c --- /dev/null +++ b/deployment/src/create_topomap.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Create a new tmux session +session_name="gnm_locobot_$(date +%s)" +tmux new-session -d -s $session_name + +# Split the window into three panes +tmux selectp -t 0 # select the first (0) pane +tmux splitw -v -p 50 # split it into two halves +tmux selectp -t 0 # go back to the first pane +tmux splitw -h -p 50 # split it into two halves + +# Run roscore in the first pane +tmux select-pane -t 0 +tmux send-keys "roscore" Enter + +# Run the create_topoplan.py script with command line args in the second pane +tmux select-pane -t 1 +tmux send-keys "conda activate gnm_deployment" Enter +tmux send-keys "python create_topomap.py --dt 1 --dir $1" Enter + +# Change the directory to ../topomaps/bags and run the rosbag play command in the third pane +tmux select-pane -t 2 +tmux send-keys "mkdir -p ../topomaps/bags" Enter +tmux send-keys "cd ../topomaps/bags" Enter +tmux send-keys "rosbag play -r 1.5 $2" # feel free to change the playback rate to change the edge length in the graph + +# Attach to the tmux session +tmux -2 attach-session -t $session_name diff --git a/deployment/src/joy_teleop.py b/deployment/src/joy_teleop.py new file mode 100755 index 0000000..bb1f4ab --- /dev/null +++ b/deployment/src/joy_teleop.py @@ -0,0 +1,68 @@ +import yaml + +# ROS +import rospy +from geometry_msgs.msg import Twist +from sensor_msgs.msg import Joy +from std_msgs.msg import Bool + +from topic_names import JOY_BUMPER_TOPIC + +vel_msg = Twist() +CONFIG_PATH = "../config/robot.yaml" +with open(CONFIG_PATH, "r") as f: + robot_config = yaml.safe_load(f) +MAX_V = 0.4#robot_config["max_v"] +MAX_W = 0.8#robot_config["max_w"] +VEL_TOPIC = robot_config["vel_teleop_topic"] +JOY_CONFIG_PATH = "../config/joystick.yaml" +with open(JOY_CONFIG_PATH, "r") as f: + joy_config = yaml.safe_load(f) +DEADMAN_SWITCH = joy_config["deadman_switch"] # button index +LIN_VEL_BUTTON = joy_config["lin_vel_button"] +ANG_VEL_BUTTON = joy_config["ang_vel_button"] +RATE = 9 +vel_pub = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1) +bumper_pub = rospy.Publisher(JOY_BUMPER_TOPIC, Bool, queue_size=1) +button = None +bumper = False + + +def callback_joy(data: Joy): + """Callback function for the joystick subscriber""" + global vel_msg, button, bumper + button = data.buttons[DEADMAN_SWITCH] + bumper_button = data.buttons[DEADMAN_SWITCH - 1] + if button is not None: # hold down the dead-man switch to teleop the robot + vel_msg.linear.x = MAX_V * data.axes[LIN_VEL_BUTTON] + vel_msg.angular.z = MAX_W * data.axes[ANG_VEL_BUTTON] + else: + vel_msg = Twist() + vel_pub.publish(vel_msg) + if bumper_button is not None: + bumper = bool(data.buttons[DEADMAN_SWITCH - 1]) + else: + bumper = False + + + +def main(): + rospy.init_node("Joy2Locobot", anonymous=False) + joy_sub = rospy.Subscriber("joy", 
Joy, callback_joy) + rate = rospy.Rate(RATE) + print("Registered with master node. Waiting for joystick input...") + while not rospy.is_shutdown(): + if button: + print(f"Teleoperating the robot:\n {vel_msg}") + vel_pub.publish(vel_msg) + rate.sleep() + bumper_msg = Bool() + bumper_msg.data = bumper + bumper_pub.publish(bumper_msg) + if bumper: + print("Bumper pressed!") + + +if __name__ == "__main__": + main() + diff --git a/deployment/src/joy_teleop.sh b/deployment/src/joy_teleop.sh new file mode 100755 index 0000000..5fe6bf5 --- /dev/null +++ b/deployment/src/joy_teleop.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Create a new tmux session +session_name="teleop_locobot_$(date +%s)" +tmux new-session -d -s $session_name + +# Split the window into two panes +tmux selectp -t 0 # select the first (0) pane +tmux splitw -v -p 50 # split it into two halves + +# Run the roslaunch command in the first pane +tmux select-pane -t 0 +tmux send-keys "roslaunch gnm_locobot.launch" Enter + +# Run the teleop.py script in the second pane +tmux select-pane -t 1 +tmux send-keys "conda activate gnm_deployment" Enter +tmux send-keys "python joy_teleop.py" Enter + +# Attach to the tmux session +tmux -2 attach-session -t $session_name \ No newline at end of file diff --git a/deployment/src/navigate.py b/deployment/src/navigate.py new file mode 100755 index 0000000..828b59f --- /dev/null +++ b/deployment/src/navigate.py @@ -0,0 +1,201 @@ +# ROS +import rospy +from sensor_msgs.msg import Image +from std_msgs.msg import Bool, Float32MultiArray + +# UTILS +from utils import msg_to_pil, to_numpy, transform_images, load_model + +import torch +from PIL import Image as PILImage +import numpy as np +import os +import argparse +import yaml +from topic_names import IMAGE_TOPIC + +TOPOMAP_IMAGES_DIR = "../topomaps/images" +MODEL_WEIGHTS_PATH = "../model_weights" +ROBOT_CONFIG_PATH ="../config/robot.yaml" +MODEL_CONFIG_PATH = "../config/models.yaml" +with open(ROBOT_CONFIG_PATH, "r") as f: + robot_config = yaml.safe_load(f) +MAX_V = robot_config["max_v"] +MAX_W = robot_config["max_w"] +RATE = robot_config["frame_rate"] + +# GLOBALS +context_queue = [] +context_size = None +recent_obs = None + + +# Load the model +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Using device:", device) + +def callback_obs(msg): + global recent_obs + obs_img = msg_to_pil(msg) + recent_obs = obs_img + + if context_size is not None: + if len(context_queue) < context_size + 1: + context_queue.append(obs_img) + else: + context_queue.pop(0) + context_queue.append(obs_img) + + +def main(args: argparse.Namespace): + global context_queue, recent_obs, context_size + # load topomap + topomap_filenames = sorted(os.listdir(os.path.join( + TOPOMAP_IMAGES_DIR, args.dir)), key=lambda x: int(x.split(".")[0])) + topomap_dir = f"{TOPOMAP_IMAGES_DIR}/{args.dir}" + num_nodes = len(os.listdir(topomap_dir)) + topomap = [] + for i in range(num_nodes): + image_path = os.path.join(topomap_dir, topomap_filenames[i]) + topomap.append(PILImage.open(image_path)) + + # load model parameters + with open(MODEL_CONFIG_PATH, "r") as f: + model_paths = yaml.safe_load(f) + + model_config_path = model_paths[args.model]["config_path"] + with open(model_config_path, "r") as f: + model_params = yaml.safe_load(f) + + # load model weights + ckpth_path = model_paths[args.model]["ckpt_path"] + if os.path.exists(ckpth_path): + print(f"Loading model from {ckpth_path}") + else: + raise FileNotFoundError(f"Model weights not found at {ckpth_path}") + model = 
load_model( + ckpth_path, + model_params, + device, + ) + model.eval() + + context_size = model_params["context_size"] + + # ROS + rospy.init_node("TOPOPLAN", anonymous=False) + rate = rospy.Rate(RATE) + image_curr_msg = rospy.Subscriber( + IMAGE_TOPIC, Image, callback_obs, queue_size=1) + waypoint_pub = rospy.Publisher( + "/waypoint", Float32MultiArray, queue_size=1) + goal_pub = rospy.Publisher("/topoplan/reached_goal", Bool, queue_size=1) + print("Registered with master node. Waiting for image observations...") + + closest_node = 0 + assert -1 <= args.goal_node < len(topomap), "Invalid goal index" + if args.goal_node == -1: + goal_node = len(topomap) - 1 + else: + goal_node = args.goal_node + reached_goal = False + + # navigation loop + while not rospy.is_shutdown(): + if len(context_queue) > context_size: + start = max(closest_node - args.radius, 0) + end = min(closest_node + args.radius + 1, goal_node) + distances = [] + waypoints = [] + batch_obs_imgs = [] + batch_goal_data = [] + for i, sg_img in enumerate(topomap[start: end + 1]): + transf_obs_img = transform_images(context_queue, model_params["image_size"]) + goal_data = transform_images(sg_img, model_params["image_size"]) + batch_obs_imgs.append(transf_obs_img) + batch_goal_data.append(goal_data) + + # predict distances and waypoints + batch_obs_imgs = torch.cat(batch_obs_imgs, dim=0).to(device) + batch_goal_data = torch.cat(batch_goal_data, dim=0).to(device) + + distances, waypoints = model(batch_obs_imgs, batch_goal_data) + distances = to_numpy(distances) + waypoints = to_numpy(waypoints) + # look for closest node + closest_node = np.argmin(distances) + # chose subgoal and output waypoints + if distances[closest_node] > args.close_threshold: + chosen_waypoint = waypoints[closest_node][args.waypoint] + sg_img = topomap[start + closest_node] + else: + chosen_waypoint = waypoints[min( + closest_node + 1, len(waypoints) - 1)][args.waypoint] + sg_img = topomap[start + min(closest_node + 1, len(waypoints) - 1)] + waypoint_msg = Float32MultiArray() + if model_params["normalize"]: + chosen_waypoint[:2] *= (MAX_V / RATE) + waypoint_msg.data = chosen_waypoint + waypoint_pub.publish(waypoint_msg) + closest_node += start + reached_goal = closest_node == goal_node + print("closest node:", closest_node) + goal_pub.publish(reached_goal) + if reached_goal: + print("Reached goal! 
Stopping...") + rate.sleep() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Code to run ViNT on the locobot") + parser.add_argument( + "--dir", + "-d", + default="topomap", + type=str, + help="path to topomap images", + ) + parser.add_argument( + "--model", + "-m", + default="vint", + type=str, + help="model name (hint: check ../config/models.yaml) (default: vint)", + ) + parser.add_argument( + "--close-threshold", + "-t", + default=3, + type=int, + help="""temporal distance within the next node in the topomap before + localizing to it (default: 3)""", + ) + parser.add_argument( + "--radius", + "-r", + default=4, + type=int, + help="""temporal number of locobal nodes to look at in the topopmap for + localization (default: 2)""", + ) + parser.add_argument( + "--waypoint", + "-w", + default=2, # close waypoints exihibit straight line motion (the middle waypoint is a good default) + type=int, + help=f"""index of the waypoint used for navigation (between 0 and 4 or + how many waypoints your model predicts) (default: 2)""", + ) + parser.add_argument( + "--goal-node", + "-g", + default=-1, + type=int, + help="""goal node index in the topomap (if -1, then the goal node is + the last node in the topomap) (default: -1)""", + ) + args = parser.parse_args() + print(f"Using {device}") + main(args) + diff --git a/deployment/src/navigate.sh b/deployment/src/navigate.sh new file mode 100755 index 0000000..6f37ddb --- /dev/null +++ b/deployment/src/navigate.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Create a new tmux session +session_name="vint_locobot_$(date +%s)" +tmux new-session -d -s $session_name + +# Split the window into four panes +tmux selectp -t 0 # select the first (0) pane +tmux splitw -h -p 50 # split it into two halves +tmux selectp -t 0 # select the first (0) pane +tmux splitw -v -p 50 # split it into two halves + +tmux selectp -t 2 # select the new, second (2) pane +tmux splitw -v -p 50 # split it into two halves +tmux selectp -t 0 # go back to the first pane + +# Run the roslaunch command in the first pane +tmux select-pane -t 0 +tmux send-keys "roslaunch vint_locobot.launch" Enter + +# Run the navigate.py script with command line args in the second pane +tmux select-pane -t 1 +# tmux send-keys "conda activate vint_deployment" Enter +tmux send-keys "python navigate.py $@" Enter + +# Run the teleop.py script in the third pane +tmux select-pane -t 2 +# tmux send-keys "conda activate vint_deployment" Enter +tmux send-keys "python joy_teleop.py" Enter + +# Run the pd_controller.py script in the fourth pane +tmux select-pane -t 3 +tmux send-keys "conda activate vint_deployment" Enter +tmux send-keys "python pd_controller.py" Enter + +# Attach to the tmux session +tmux -2 attach-session -t $session_name diff --git a/deployment/src/pd_controller.py b/deployment/src/pd_controller.py new file mode 100755 index 0000000..c704244 --- /dev/null +++ b/deployment/src/pd_controller.py @@ -0,0 +1,104 @@ +import numpy as np +import yaml +from typing import Tuple + +# ROS +import rospy +from geometry_msgs.msg import Twist +from std_msgs.msg import Float32MultiArray, Bool + +from topic_names import (WAYPOINT_TOPIC, + REACHED_GOAL_TOPIC) +from ros_data import ROSData +from utils import clip_angle + +# CONSTS +CONFIG_PATH = "../config/robot.yaml" +with open(CONFIG_PATH, "r") as f: + robot_config = yaml.safe_load(f) +MAX_V = robot_config["max_v"] +MAX_W = robot_config["max_w"] +VEL_TOPIC = robot_config["vel_navi_topic"] +DT = 1/robot_config["frame_rate"] +RATE = 9 +EPS = 1e-8 
+WAYPOINT_TIMEOUT = 1 # seconds # TODO: tune this +FLIP_ANG_VEL = np.pi/4 + +# GLOBALS +vel_msg = Twist() +waypoint = ROSData(WAYPOINT_TIMEOUT, name="waypoint") +reached_goal = False +reverse_mode = False +current_yaw = None + +def clip_angle(theta) -> float: + """Clip angle to [-pi, pi]""" + theta %= 2 * np.pi + if -np.pi < theta < np.pi: + return theta + return theta - 2 * np.pi + + +def pd_controller(waypoint: np.ndarray) -> Tuple[float]: + """PD controller for the robot""" + assert len(waypoint) == 2 or len(waypoint) == 4, "waypoint must be a 2D or 4D vector" + if len(waypoint) == 2: + dx, dy = waypoint + else: + dx, dy, hx, hy = waypoint + # this controller only uses the predicted heading if dx and dy near zero + if len(waypoint) == 4 and np.abs(dx) < EPS and np.abs(dy) < EPS: + v = 0 + w = clip_angle(np.arctan2(hy, hx))/DT + elif np.abs(dx) < EPS: + v = 0 + w = np.sign(dy) * np.pi/(2*DT) + else: + v = dx / DT + w = np.arctan(dy/dx) / DT + v = np.clip(v, 0, MAX_V) + w = np.clip(w, -MAX_W, MAX_W) + return v, w + + +def callback_drive(waypoint_msg: Float32MultiArray): + """Callback function for the waypoint subscriber""" + global vel_msg + print("seting waypoint") + waypoint.set(waypoint_msg.data) + + +def callback_reached_goal(reached_goal_msg: Bool): + """Callback function for the reached goal subscriber""" + global reached_goal + reached_goal = reached_goal_msg.data + + +def main(): + global vel_msg, reverse_mode + rospy.init_node("PD_CONTROLLER", anonymous=False) + waypoint_sub = rospy.Subscriber(WAYPOINT_TOPIC, Float32MultiArray, callback_drive, queue_size=1) + reached_goal_sub = rospy.Subscriber(REACHED_GOAL_TOPIC, Bool, callback_reached_goal, queue_size=1) + vel_out = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1) + rate = rospy.Rate(RATE) + print("Registered with master node. Waiting for waypoints...") + while not rospy.is_shutdown(): + vel_msg = Twist() + if reached_goal: + vel_out.publish(vel_msg) + print("Reached goal! 
Stopping...") + return + elif waypoint.is_valid(verbose=True): + v, w = pd_controller(waypoint.get()) + if reverse_mode: + v *= -1 + vel_msg.linear.x = v + vel_msg.angular.z = w + print(f"publishing new vel: {v}, {w}") + vel_out.publish(vel_msg) + rate.sleep() + + +if __name__ == '__main__': + main() diff --git a/deployment/src/record_bag.sh b/deployment/src/record_bag.sh new file mode 100755 index 0000000..b10e256 --- /dev/null +++ b/deployment/src/record_bag.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Create a new tmux session +session_name="record_bag_$(date +%s)" +tmux new-session -d -s $session_name + +# Split the window into three panes +tmux selectp -t 0 # select the first (0) pane +tmux splitw -v -p 50 # split it into two halves +tmux selectp -t 0 # go back to the first pane +tmux splitw -h -p 50 # split it into two halves + +# Run the roslaunch command in the first pane +tmux select-pane -t 0 +tmux send-keys "roslaunch vint_locobot.launch" Enter + +# Run the teleop.py script in the second pane +tmux select-pane -t 1 +tmux send-keys "conda activate vint_deployment" Enter +tmux send-keys "python joy_teleop.py" Enter + +# Change the directory to ../topomaps/bags and run the rosbag record command in the third pane +tmux select-pane -t 2 +tmux send-keys "cd ../topomaps/bags" Enter +tmux send-keys "rosbag record /usb_cam/image_raw -o $1" # change topic if necessary + +# Attach to the tmux session +tmux -2 attach-session -t $session_name \ No newline at end of file diff --git a/deployment/src/ros_data.py b/deployment/src/ros_data.py new file mode 100644 index 0000000..2b7da28 --- /dev/null +++ b/deployment/src/ros_data.py @@ -0,0 +1,34 @@ +import rospy + +class ROSData: + def __init__(self, timeout: int = 3, queue_size: int = 1, name: str = ""): + self.timout = timeout + self.last_time_received = float("-inf") + self.queue_size = queue_size + self.data = None + self.name = name + self.phantom = False + + def get(self): + return self.data + + def set(self, data): + time_waited = rospy.get_time() - self.last_time_received + if self.queue_size == 1: + self.data = data + else: + if self.data is None or time_waited > self.timout: # reset queue if timeout + self.data = [] + if len(self.data) == self.queue_size: + self.data.pop(0) + self.data.append(data) + self.last_time_received = rospy.get_time() + + def is_valid(self, verbose: bool = False): + time_waited = rospy.get_time() - self.last_time_received + valid = time_waited < self.timout + if self.queue_size > 1: + valid = valid and len(self.data) == self.queue_size + if verbose and not valid: + print(f"Not receiving {self.name} data for {time_waited} seconds (timeout: {self.timout} seconds)") + return valid \ No newline at end of file diff --git a/deployment/src/topic_names.py b/deployment/src/topic_names.py new file mode 100644 index 0000000..bff509e --- /dev/null +++ b/deployment/src/topic_names.py @@ -0,0 +1,36 @@ +# topic names for ROS communication + +# image obs topics +FRONT_IMAGE_TOPIC = "/usb_cam_front/image_raw" +REVERSE_IMAGE_TOPIC = "/usb_cam_reverse/image_raw" +IMAGE_TOPIC = "/usb_cam/image_raw" + +# exploration topics +SUBGOALS_TOPIC = "/subgoals" +GRAPH_NAME_TOPIC = "/graph_name" +WAYPOINT_TOPIC = "/waypoint" +REVERSE_MODE_TOPIC = "/reverse_mode" +SAMPLED_OUTPUTS_TOPIC = "/sampled_outputs" +REACHED_GOAL_TOPIC = "/topoplan/reached_goal" +SAMPLED_WAYPOINTS_GRAPH_TOPIC = "/sampled_waypoints_graph" +BACKTRACKING_IMAGE_TOPIC = "/backtracking_image" +FRONTIER_IMAGE_TOPIC = "/frontier_image" +SUBGOALS_SHAPE_TOPIC = "/subgoal_shape" 
+SAMPLED_ACTIONS_TOPIC = "/sampled_actions" +ANNOTATED_IMAGE_TOPIC = "/annotated_image" +CURRENT_NODE_IMAGE_TOPIC = "/current_node_image" +FLIP_DIRECTION_TOPIC = "/flip_direction" +TURNING_TOPIC = "/turning" +SUBGOAL_GEN_RATE_TOPIC = "/subgoal_gen_rate" +MARKER_TOPIC = "/visualization_marker_array" +VIZ_NAV_IMAGE_TOPIC = "/nav_image" + +# visualization topics +CHOSEN_SUBGOAL_TOPIC = "/chosen_subgoal" + +# recorded ont the robot +ODOM_TOPIC = "/odom" +BUMPER_TOPIC = "/mobile_base/events/bumper" +JOY_BUMPER_TOPIC = "/joy_bumper" + +# move the robot \ No newline at end of file diff --git a/deployment/src/utils.py b/deployment/src/utils.py new file mode 100644 index 0000000..a842f61 --- /dev/null +++ b/deployment/src/utils.py @@ -0,0 +1,117 @@ + +import os +import sys +import io +import matplotlib.pyplot as plt + +# ROS +from sensor_msgs.msg import Image + +# pytorch +import torch +import torch.nn as nn +from torchvision import transforms +import torchvision.transforms.functional as TF + +import numpy as np +from PIL import Image as PILImage +from typing import List, Tuple, Dict, Optional + +# models +from vint_train.models.gnm import GNM +from vint_train.models.vint import ViNT + +from vint_train.data.data_utils import IMAGE_ASPECT_RATIO + + +def load_model( + model_path: str, + config: dict, + device: torch.device = torch.device("cpu"), +) -> nn.Module: + """Load a model from a checkpoint file (works with models trained on multiple GPUs)""" + checkpoint = torch.load(model_path, map_location=device) + loaded_model = checkpoint["model"] + model_type = config["model_type"] + + if model_type == "gnm": + model = GNM( + config["context_size"], + config["len_traj_pred"], + config["learn_angle"], + config["obs_encoding_size"], + config["goal_encoding_size"], + ) + elif model_type == "vint": + model = ViNT( + context_size=config["context_size"], + len_traj_pred=config["len_traj_pred"], + learn_angle=config["learn_angle"], + obs_encoder=config["obs_encoder"], + obs_encoding_size=config["obs_encoding_size"], + late_fusion=config["late_fusion"], + mha_num_attention_heads=config["mha_num_attention_heads"], + mha_num_attention_layers=config["mha_num_attention_layers"], + mha_ff_dim_factor=config["mha_ff_dim_factor"], + ) + else: + raise ValueError(f"Invalid model type: {model_type}") + try: + state_dict = loaded_model.module.state_dict() + model.load_state_dict(state_dict, strict=False) + except AttributeError as e: + state_dict = loaded_model.state_dict() + model.load_state_dict(state_dict, strict=False) + model.to(device) + return model + + +def msg_to_pil(msg: Image) -> PILImage.Image: + img = np.frombuffer(msg.data, dtype=np.uint8).reshape( + msg.height, msg.width, -1) + pil_image = PILImage.fromarray(img) + return pil_image + + +def pil_to_msg(pil_img: PILImage.Image, encoding="mono8") -> Image: + img = np.asarray(pil_img) + ros_image = Image(encoding=encoding) + ros_image.height, ros_image.width, _ = img.shape + ros_image.data = img.ravel().tobytes() + ros_image.step = ros_image.width + return ros_image + + +def to_numpy(tensor): + return tensor.cpu().detach().numpy() + + +def transform_images(pil_imgs: List[PILImage.Image], image_size: List[int], center_crop: bool = False) -> torch.Tensor: + """Transforms a list of PIL image to a torch tensor.""" + transform_type = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]), + ] + ) + if type(pil_imgs) != list: + pil_imgs = [pil_imgs] + transf_imgs = [] + for pil_img in pil_imgs: 
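+        # per image: optional center-crop to IMAGE_ASPECT_RATIO, resize to image_size, then ToTensor + ImageNet normalization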
+ w, h = pil_img.size + if center_crop: + if w > h: + pil_img = TF.center_crop(pil_img, (h, int(h * IMAGE_ASPECT_RATIO))) # crop to the right ratio + else: + pil_img = TF.center_crop(pil_img, (int(w / IMAGE_ASPECT_RATIO), w)) + pil_img = pil_img.resize(image_size) + transf_img = transform_type(pil_img) + transf_img = torch.unsqueeze(transf_img, 0) + transf_imgs.append(transf_img) + return torch.cat(transf_imgs, dim=1) + + +# clip angle between -pi and pi +def clip_angle(angle): + return np.mod(angle + np.pi, 2 * np.pi) - np.pi diff --git a/deployment/src/vint_locobot.launch b/deployment/src/vint_locobot.launch new file mode 100644 index 0000000..d1b06bd --- /dev/null +++ b/deployment/src/vint_locobot.launch @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/deployment/topomaps/images/README.txt b/deployment/topomaps/images/README.txt new file mode 100644 index 0000000..7cb6b19 --- /dev/null +++ b/deployment/topomaps/images/README.txt @@ -0,0 +1 @@ +This is a directory to save folders of images for the topological graphs \ No newline at end of file diff --git a/train/README.md b/train/README.md new file mode 100644 index 0000000..92ebc3e --- /dev/null +++ b/train/README.md @@ -0,0 +1,2 @@ +# ViNT + diff --git a/train/config/defaults.yaml b/train/config/defaults.yaml new file mode 100644 index 0000000..3acd7c7 --- /dev/null +++ b/train/config/defaults.yaml @@ -0,0 +1,62 @@ +# defaults for training +project_name: vint +run_name: vint + +# training setup +use_wandb: True # set to false if you don't want to log to wandb +train: True +batch_size: 400 +eval_batch_size: 400 +epochs: 30 +gpu_ids: [0] +num_workers: 4 +lr: 5e-4 +optimizer: adam +seed: 0 +clipping: False +train_subset: 1. + +# model params +model_type: gnm +obs_encoding_size: 1024 +goal_encoding_size: 1024 + +# normalization for the action space +normalize: True + +# context +context_type: temporal +context_size: 5 + +# tradeoff between action and distance prediction loss +alpha: 0.5 + +# tradeoff between task loss and kld +beta: 0.1 + +obs_type: image +goal_type: image +scheduler: null + +# distance bounds for distance and action and distance predictions +distance: + min_dist_cat: 0 + max_dist_cat: 20 +action: + min_dist_cat: 2 + max_dist_cat: 10 +close_far_threshold: 10 # distance threshold used to seperate the close and the far subgoals that are sampled per datapoint + +# action output params +len_traj_pred: 5 +learn_angle: True + +# dataset specific parameters +image_size: [85, 64] # width, height + +# logging stuff +print_log_freq: 100 # in iterations +image_log_freq: 1000 # in iterations +num_images_log: 8 # number of images to log in a logging iteration +pairwise_test_freq: 10 # in epochs +eval_fraction: 0.25 # fraction of the dataset to use for evaluation diff --git a/train/config/gnm.yaml b/train/config/gnm.yaml new file mode 100644 index 0000000..62297e9 --- /dev/null +++ b/train/config/gnm.yaml @@ -0,0 +1,109 @@ +# NOTE: this model uses private datasets + +project_name: vint-release +run_name: gnm + +# training setup +use_wandb: True # set to false if you don't want to log to wandb +train: True +batch_size: 400 +eval_batch_size: 400 +epochs: 30 +gpu_ids: [0] +num_workers: 4 +lr: 7e-4 +optimizer: adam +seed: 0 + +# model params +model_type: gnm +obs_encoding_size: 1024 +goal_encoding_size: 1024 + +# normalization for the action space +normalize: True + +# context +context_type: temporal # [temporal, randomized] +context_size: 5 + +# tradeoff between action and 
distance prediction loss +alpha: 0.5 + +# distance bounds for distance and action and distance predictions +distance: + min_dist_cat: 0 + max_dist_cat: 20 +action: + min_dist_cat: 2 + max_dist_cat: 10 +close_far_threshold: 10 # distance threshold used to seperate the close and the far subgoals that are sampled per datapoint + +# action output params +len_traj_pred: 5 +learn_angle: True + +# dataset specific parameters +image_size: [85, 64] # width, height +goal_type: "image" + + +datasets: + recon: + data_folder: /home//vint_dataset/recon + train: /home//data_splits/recon/train/ # path to train folder with traj_names.txt + test: /home//data_splits/recon/test/ # path to test folder with traj_names.txt + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 # how many goals are sampled per observation + negative_mining: True # negative mining from the ViNG paper (Shah et al.) + go_stanford: + data_folder: /home//vint_dataset/go_stanford_cropped # datasets/stanford_go_new + train: /home//data_splits/go_stanford/train/ + test: /home//data_splits/go_stanford/test/ + end_slack: 0 + goals_per_obs: 2 # increase dataset size + negative_mining: True + cory_hall: + data_folder: /home//vint_dataset/cory_hall/ + train: /home//data_splits/cory_hall/train/ + test: /home//data_splits/cory_hall/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + tartan_drive: + data_folder: /home//vint_dataset/tartan_drive/ + train: /home//data_splits/tartan_drive/train/ + test: /home//data_splits/tartan_drive/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + sacson: + data_folder: /home//vint_dataset/sacson/ + train: /home//data_splits/sacson/train/ + test: /home//data_splits/sacson/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + + # private datasets (uncomment if you have access) + # seattle: + # data_folder: /home//vint_dataset/seattle/ + # train: /home//data_splits/seattle/train/ + # test: /home//data_splits/seattle/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + # scand: + # data_folder: /home//vint_dataset/scand/ + # train: /home//data_splits/scand/train/ + # test: /home//data_splits/scand/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + + +# logging stuff +print_log_freq: 100 # in iterations +image_log_freq: 1000 # in iterations +num_images_log: 8 # number of images to log in a logging iteration +pairwise_test_freq: 20 # in epochs diff --git a/train/config/late_fusion.yaml b/train/config/late_fusion.yaml new file mode 100644 index 0000000..2132caf --- /dev/null +++ b/train/config/late_fusion.yaml @@ -0,0 +1,119 @@ +project_name: vint-release +run_name: vint-late-fusion + +# training setup +use_wandb: True # set to false if you don't want to log to wandb +train: True +batch_size: 256 +epochs: 100 +gpu_ids: [0] +num_workers: 12 +lr: 5e-4 +optimizer: adamw +clipping: False +max_norm: 1. 
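+# NOTE (assumed): max_norm is the gradient-norm clipping threshold and only takes effect when clipping is True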
+scheduler: "cosine" +warmup: False +warmup_epochs: 4 +cyclic_period: 10 +plateau_patience: 3 +plateau_factor: 0.5 +seed: 0 + + +# model params +model_type: vint +obs_encoder: "efficientnet-b0" # by default, this is imagenet pretrained +obs_encoding_size: 512 +mha_num_attention_heads: 4 +mha_num_attention_layers: 4 +mha_ff_dim_factor: 4 +late_fusion: True + +# normalization for the action space +normalize: True + +# context +context_type: temporal +context_size: 0 + +# tradeoff between action and distance prediction loss +alpha: 0.5 + +# distance bounds for distance and action and distance predictions +distance: + min_dist_cat: 0 + max_dist_cat: 20 +action: + min_dist_cat: 0 + max_dist_cat: 10 +close_far_threshold: 10 # distance threshold used to seperate the close and the far subgoals that are sampled per datapoint + +# action output params +len_traj_pred: 5 +learn_angle: True + +# dataset specific parameters +image_size: [85, 64] # width, height +goal_type: "image" + +datasets: + recon: + data_folder: /home//vint_dataset/recon + train: /home//data_splits/recon/train/ # path to train folder with traj_names.txt + test: /home//data_splits/recon/test/ # path to test folder with traj_names.txt + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 # how many goals are sampled per observation + negative_mining: True # negative mining from the ViNG paper (Shah et al.) + go_stanford: + data_folder: /home//vint_dataset/go_stanford_cropped # datasets/stanford_go_new + train: /home//data_splits/go_stanford/train/ + test: /home//data_splits/go_stanford/test/ + end_slack: 0 + goals_per_obs: 2 # increase dataset size + negative_mining: True + cory_hall: + data_folder: /home//vint_dataset/cory_hall/ + train: /home//data_splits/cory_hall/train/ + test: /home//data_splits/cory_hall/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + tartan_drive: + data_folder: /home//vint_dataset/tartan_drive/ + train: /home//data_splits/tartan_drive/train/ + test: /home//data_splits/tartan_drive/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + sacson: + data_folder: /home//vint_dataset/sacson/ + train: /home//data_splits/sacson/train/ + test: /home//data_splits/sacson/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + + # private datasets (uncomment if you have access) + # seattle: + # data_folder: /home//vint_dataset/seattle/ + # train: /home//data_splits/seattle/train/ + # test: /home//data_splits/seattle/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + # scand: + # data_folder: /home//vint_dataset/scand/ + # train: /home//data_splits/scand/train/ + # test: /home//data_splits/scand/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + +# logging stuff +## =0 turns off +print_log_freq: 100 # in iterations +image_log_freq: 1000 #0 # in iterations +num_images_log: 8 #0 +pairwise_test_freq: 0 # in epochs +eval_fraction: 0.25 \ No newline at end of file diff --git a/train/config/vint.yaml b/train/config/vint.yaml new file mode 100644 index 0000000..9364ed5 --- /dev/null +++ b/train/config/vint.yaml @@ -0,0 +1,117 @@ +project_name: vint-release +run_name: vint-5c + +# training setup +use_wandb: True # set to false if you don't want to log to wandb +train: True +batch_size: 256 +epochs: 100 +gpu_ids: [0] +num_workers: 12 +lr: 5e-4 +optimizer: adamw +clipping: False +max_norm: 1. 
+scheduler: "cosine" +warmup: True +warmup_epochs: 4 +cyclic_period: 10 +plateau_patience: 3 +plateau_factor: 0.5 +seed: 0 + +# model params +model_type: vint +obs_encoder: "efficientnet-b0" # by default, this is imagenet pretrained +obs_encoding_size: 512 +mha_num_attention_heads: 4 +mha_num_attention_layers: 4 +mha_ff_dim_factor: 4 +late_fusion: False + +# normalization for the action space +normalize: True + +# context +context_type: temporal +context_size: 5 +# tradeoff between action and distance prediction loss +alpha: 0.5 + +# distance bounds for distance and action and distance predictions +distance: + min_dist_cat: 0 + max_dist_cat: 20 +action: + min_dist_cat: 0 + max_dist_cat: 10 +close_far_threshold: 10 # distance threshold used to seperate the close and the far subgoals that are sampled per datapoint + +# action output params +len_traj_pred: 5 +learn_angle: True + +# dataset specific parameters +image_size: [85, 64] # width, height +goal_type: "image" + +datasets: + recon: + data_folder: /home//vint_dataset/recon + train: /home//data_splits/recon/train/ # path to train folder with traj_names.txt + test: /home//data_splits/recon/test/ # path to test folder with traj_names.txt + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 # how many goals are sampled per observation + negative_mining: True # negative mining from the ViNG paper (Shah et al.) + go_stanford: + data_folder: /home//vint_dataset/go_stanford_cropped # datasets/stanford_go_new + train: /home//data_splits/go_stanford/train/ + test: /home//data_splits/go_stanford/test/ + end_slack: 0 + goals_per_obs: 2 # increase dataset size + negative_mining: True + cory_hall: + data_folder: /home//vint_dataset/cory_hall/ + train: /home//data_splits/cory_hall/train/ + test: /home//data_splits/cory_hall/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + tartan_drive: + data_folder: /home//vint_dataset/tartan_drive/ + train: /home//data_splits/tartan_drive/train/ + test: /home//data_splits/tartan_drive/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + sacson: + data_folder: /home//vint_dataset/sacson/ + train: /home//data_splits/sacson/train/ + test: /home//data_splits/sacson/test/ + end_slack: 3 # because many trajectories end in collisions + goals_per_obs: 1 + negative_mining: True + + # private datasets (uncomment if you have access) + # seattle: + # data_folder: /home//vint_dataset/seattle/ + # train: /home//data_splits/seattle/train/ + # test: /home//data_splits/seattle/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + # scand: + # data_folder: /home//vint_dataset/scand/ + # train: /home//data_splits/scand/train/ + # test: /home//data_splits/scand/test/ + # end_slack: 0 + # goals_per_obs: 1 + # negative_mining: True + +# logging stuff +## =0 turns off +print_log_freq: 100 # in iterations +image_log_freq: 1000 #0 # in iterations +num_images_log: 8 #0 +pairwise_test_freq: 0 # in epochs +eval_fraction: 0.25 \ No newline at end of file diff --git a/train/data_split.py b/train/data_split.py new file mode 100644 index 0000000..6e6f913 --- /dev/null +++ b/train/data_split.py @@ -0,0 +1,75 @@ +import argparse +import os +import shutil +import random + + +def remove_files_in_dir(dir_path: str): + for f in os.listdir(dir_path): + file_path = os.path.join(dir_path, f) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif 
os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + + +def main(args: argparse.Namespace): + # Get the names of the folders in the data directory that contain the file 'traj_data.pkl' + folder_names = [ + f + for f in os.listdir(args.data_dir) + if os.path.isdir(os.path.join(args.data_dir, f)) + and "traj_data.pkl" in os.listdir(os.path.join(args.data_dir, f)) + ] + + # Randomly shuffle the names of the folders + random.shuffle(folder_names) + + # Split the names of the folders into train and test sets + split_index = int(args.split * len(folder_names)) + train_folder_names = folder_names[:split_index] + test_folder_names = folder_names[split_index:] + + # Create directories for the train and test sets + train_dir = os.path.join(args.data_splits_dir, args.dataset_name, "train") + test_dir = os.path.join(args.data_splits_dir, args.dataset_name, "test") + for dir_path in [train_dir, test_dir]: + if os.path.exists(dir_path): + print(f"Clearing files from {dir_path} for new data split") + remove_files_in_dir(dir_path) + else: + print(f"Creating {dir_path}") + os.makedirs(dir_path) + + # Write the names of the train and test folders to files + with open(os.path.join(train_dir, "traj_names.txt"), "w") as f: + for folder_name in train_folder_names: + f.write(folder_name + "\n") + + with open(os.path.join(test_dir, "traj_names.txt"), "w") as f: + for folder_name in test_folder_names: + f.write(folder_name + "\n") + + +if __name__ == "__main__": + # Set up the command line argument parser + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data-dir", "-i", help="Directory containing the data", required=True + ) + parser.add_argument( + "--dataset-name", "-d", help="Name of the dataset", required=True + ) + parser.add_argument( + "--split", "-s", type=float, default=0.8, help="Train/test split (default: 0.8)" + ) + parser.add_argument( + "--data-splits-dir", "-o", default="vint_train/data/data_splits", help="Data splits directory" + ) + args = parser.parse_args() + main(args) + print("Done") diff --git a/train/process_bag_diff.py b/train/process_bag_diff.py new file mode 100644 index 0000000..7c39c2e --- /dev/null +++ b/train/process_bag_diff.py @@ -0,0 +1,138 @@ + +import os +import pickle +from PIL import Image +import io +import argparse +import tqdm +import yaml +import rosbag + +# utils +from vint_train.process_data.process_data_utils import * + + +def main(args: argparse.Namespace): + + # load the config file + with open("vint_train/process_data/process_bags_config.yaml", "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # create output dir if it doesn't exist + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # iterate recurisively through all the folders and get the path of files with .bag extension in the args.input_dir + bag_files = [] + for root, dirs, files in os.walk(args.input_dir): + for file in files: + if file.endswith(".bag") and "diff" in file: + bag_files.append(os.path.join(root, file)) + if args.num_trajs >= 0: + bag_files = bag_files[: args.num_trajs] + + # processing loop + for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"): + try: + b = rosbag.Bag(bag_path) + except rosbag.ROSBagException as e: + print(e) + print(f"Error loading {bag_path}. 
Skipping...") + continue + + # name is that folders separated by _ and then the last part of the path + traj_name = "_".join(bag_path.split("/")[-2:])[:-4] + + # load the bag file + bag_img_data, bag_traj_data = get_images_and_odom_2( + b, + ['/usb_cam_front/image_raw', '/chosen_subgoal'], + ['/odom'], + rate=args.sample_rate, + ) + + if bag_img_data is None: + print( + f"{bag_path} did not have the topics we were looking for. Skipping..." + ) + continue + # remove backwards movement + # cut_trajs = filter_backwards(bag_img_data, bag_traj_data) + + # for i, (img_data_i, traj_data_i) in enumerate(cut_trajs): + # traj_name_i = traj_name + f"_{i}" + # traj_folder_i = os.path.join(args.output_dir, traj_name_i) + # # make a folder for the traj + # if not os.path.exists(traj_folder_i): + # os.makedirs(traj_folder_i) + # with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f: + # pickle.dump(traj_data_i, f) + # # save the image data to disk + # for i, img in enumerate(img_data_i): + # img.save(os.path.join(traj_folder_i, f"{i}.jpg")) + + traj_folder = os.path.join(args.output_dir, traj_name) + if not os.path.exists(traj_folder): + os.makedirs(traj_folder) + + obs_images = bag_img_data["/usb_cam_front/image_raw"] + diff_images = bag_img_data["/chosen_subgoal"] + for i, img_data in enumerate(zip(obs_images, diff_images)): + obs_image, diff_image = img_data + # save the image data to disk + # save the image data to disk + obs_image.save(os.path.join(traj_folder, f"{i}.jpg")) + diff_image.save(os.path.join(traj_folder, f"diff_{i}.jpg")) + + with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f: + pickle.dump(bag_traj_data['/odom'], f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # get arguments for the recon input dir and the output dir + # add dataset name + # parser.add_argument( + # "--dataset-name", + # "-d", + # type=str, + # help="name of the dataset (must be in process_config.yaml)", + # default="tartan_drive", + # required=True, + # ) + parser.add_argument( + "--input-dir", + "-i", + type=str, + help="path of the datasets with rosbags", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + default="../datasets/tartan_drive/", + type=str, + help="path for processed dataset (default: ../datasets/tartan_drive/)", + ) + # number of trajs to process + parser.add_argument( + "--num-trajs", + "-n", + default=-1, + type=int, + help="number of bags to process (default: -1, all)", + ) + # sampling rate + parser.add_argument( + "--sample-rate", + "-s", + default=4.0, + type=float, + help="sampling rate (default: 4.0 hz)", + ) + + args = parser.parse_args() + # all caps for the dataset name + print(f"STARTING PROCESSING DIFF DATASET") + main(args) + print(f"FINISHED PROCESSING DIFF DATASET") diff --git a/train/process_bags.py b/train/process_bags.py new file mode 100644 index 0000000..9e3c899 --- /dev/null +++ b/train/process_bags.py @@ -0,0 +1,126 @@ + +import os +import pickle +from PIL import Image +import io +import argparse +import tqdm +import yaml +import rosbag + +# utils +from vint_train.process_data.process_data_utils import * + + +def main(args: argparse.Namespace): + + # load the config file + with open("vint_train/process_data/process_bags_config.yaml", "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # create output dir if it doesn't exist + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # iterate recurisively through all the folders and get the path of files with .bag extension in 
the args.input_dir + bag_files = [] + for root, dirs, files in os.walk(args.input_dir): + for file in files: + if file.endswith(".bag"): + bag_files.append(os.path.join(root, file)) + if args.num_trajs >= 0: + bag_files = bag_files[: args.num_trajs] + + # processing loop + for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"): + try: + b = rosbag.Bag(bag_path) + except rosbag.ROSBagException as e: + print(e) + print(f"Error loading {bag_path}. Skipping...") + continue + + # name is that folders separated by _ and then the last part of the path + traj_name = "_".join(bag_path.split("/")[-2:])[:-4] + + # load the hdf5 file + bag_img_data, bag_traj_data = get_images_and_odom( + b, + config[args.dataset_name]["imtopics"], + config[args.dataset_name]["odomtopics"], + eval(config[args.dataset_name]["img_process_func"]), + eval(config[args.dataset_name]["odom_process_func"]), + rate=args.sample_rate, + ang_offset=config[args.dataset_name]["ang_offset"], + ) + + + if bag_img_data is None or bag_traj_data is None: + print( + f"{bag_path} did not have the topics we were looking for. Skipping..." + ) + continue + # remove backwards movement + cut_trajs = filter_backwards(bag_img_data, bag_traj_data) + + for i, (img_data_i, traj_data_i) in enumerate(cut_trajs): + traj_name_i = traj_name + f"_{i}" + traj_folder_i = os.path.join(args.output_dir, traj_name_i) + # make a folder for the traj + if not os.path.exists(traj_folder_i): + os.makedirs(traj_folder_i) + with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f: + pickle.dump(traj_data_i, f) + # save the image data to disk + for i, img in enumerate(img_data_i): + img.save(os.path.join(traj_folder_i, f"{i}.jpg")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # get arguments for the recon input dir and the output dir + # add dataset name + parser.add_argument( + "--dataset-name", + "-d", + type=str, + help="name of the dataset (must be in process_config.yaml)", + default="tartan_drive", + required=True, + ) + parser.add_argument( + "--input-dir", + "-i", + type=str, + help="path of the datasets with rosbags", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + default="../datasets/tartan_drive/", + type=str, + help="path for processed dataset (default: ../datasets/tartan_drive/)", + ) + # number of trajs to process + parser.add_argument( + "--num-trajs", + "-n", + default=-1, + type=int, + help="number of bags to process (default: -1, all)", + ) + # sampling rate + parser.add_argument( + "--sample-rate", + "-s", + default=4.0, + type=float, + help="sampling rate (default: 4.0 hz)", + ) + + args = parser.parse_args() + # all caps for the dataset name + print(f"STARTING PROCESSING {args.dataset_name.upper()} DATASET") + main(args) + print(f"FINISHED PROCESSING {args.dataset_name.upper()} DATASET") diff --git a/train/process_recon.py b/train/process_recon.py new file mode 100644 index 0000000..b83e3df --- /dev/null +++ b/train/process_recon.py @@ -0,0 +1,80 @@ +import h5py +import os +import pickle +from PIL import Image +import io +import argparse +import tqdm + + +def main(args: argparse.Namespace): + recon_dir = os.path.join(args.input_dir, "recon_release") + output_dir = args.output_dir + + # create output dir if it doesn't exist + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # get all the folders in the recon dataset + filenames = os.listdir(recon_dir) + if args.num_trajs >= 0: + filenames = filenames[: args.num_trajs] + + # processing loop + for filename in 
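Together with `data_split.py` earlier in this diff, `process_bags.py` forms the rosbag preprocessing pipeline. A sketch of driving both scripts programmatically; all paths are placeholders, and the flags simply mirror the argparse definitions above:

```python
# Sketch only: placeholder paths; flags mirror the argparse definitions in
# process_bags.py and data_split.py above.
import subprocess
import sys

# 1) rosbags -> folders of numbered .jpgs + traj_data.pkl
subprocess.run([
    sys.executable, "process_bags.py",
    "--dataset-name", "tartan_drive",             # must have an entry in process_bags_config.yaml
    "--input-dir", "/path/to/tartan_drive_bags",
    "--output-dir", "../datasets/tartan_drive/",
    "--sample-rate", "4.0",
], check=True)

# 2) processed trajectories -> train/test split (writes traj_names.txt under the data_splits dir)
subprocess.run([
    sys.executable, "data_split.py",
    "--data-dir", "../datasets/tartan_drive/",
    "--dataset-name", "tartan_drive",
    "--split", "0.8",
], check=True)
```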
tqdm.tqdm(filenames, desc="Trajectories processed"): + # extract the name without the extension + traj_name = filename.split(".")[0] + # load the hdf5 file + try: + h5_f = h5py.File(os.path.join(recon_dir, filename), "r") + except OSError: + print(f"Error loading {filename}. Skipping...") + continue + # extract the position and yaw data + position_data = h5_f["jackal"]["position"][:, :2] + yaw_data = h5_f["jackal"]["yaw"][()] + # save the data to a dictionary + traj_data = {"position": position_data, "yaw": yaw_data} + traj_folder = os.path.join(output_dir, traj_name) + os.makedirs(traj_folder, exist_ok=True) + with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f: + pickle.dump(traj_data, f) + # make a folder for the file + if not os.path.exists(traj_folder): + os.makedirs(traj_folder) + # save the image data to disk + for i in range(h5_f["images"]["rgb_left"].shape[0]): + img = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][i])) + img.save(os.path.join(traj_folder, f"{i}.jpg")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # get arguments for the recon input dir and the output dir + parser.add_argument( + "--input-dir", + "-i", + type=str, + help="path of the recon_dataset", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + default="datasets/recon/", + type=str, + help="path for processed recon dataset (default: datasets/recon/)", + ) + # number of trajs to process + parser.add_argument( + "--num-trajs", + "-n", + default=-1, + type=int, + help="number of trajectories to process (default: -1, all)", + ) + + args = parser.parse_args() + print("STARTING PROCESSING RECON DATASET") + main(args) + print("FINISHED PROCESSING RECON DATASET") diff --git a/train/setup.py b/train/setup.py new file mode 100644 index 0000000..2706a31 --- /dev/null +++ b/train/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="vint_train", + version="0.1.0", + packages=find_packages(), +) diff --git a/train/train.py b/train/train.py new file mode 100644 index 0000000..8efcc56 --- /dev/null +++ b/train/train.py @@ -0,0 +1,316 @@ +import os +import wandb +import argparse +import numpy as np +import yaml +import time +import pdb + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, ConcatDataset +from torch.optim import Adam, AdamW +from torchvision import transforms +import torch.backends.cudnn as cudnn +from warmup_scheduler import GradualWarmupScheduler + +""" +IMPORT YOUR MODEL HERE +""" +from vint_train.models.gnm import GNM +from vint_train.models.vint import ViNT + +from vint_train.data.vint_dataset import ViNT_Dataset +from vint_train.training.train_eval_loop import ( + train_eval_loop, + load_model, + count_parameters, +) + + +def main(config): + assert config["distance"]["min_dist_cat"] < config["distance"]["max_dist_cat"] + assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"] + + if torch.cuda.is_available(): + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + if "gpu_ids" not in config: + config["gpu_ids"] = [0] + elif type(config["gpu_ids"]) == int: + config["gpu_ids"] = [config["gpu_ids"]] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + [str(x) for x in config["gpu_ids"]] + ) + print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"]) + else: + print("Using cpu") + + first_gpu_id = config["gpu_ids"][0] + device = torch.device( + f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu" + ) + + if "seed" in config: + np.random.seed(config["seed"]) + 
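For reference, a minimal sketch of inspecting a single RECON trajectory with the same HDF5 keys that `process_recon.py` reads above; the file path is a placeholder:

```python
# Sketch: same HDF5 keys as process_recon.py above; the path is a placeholder.
import io
import h5py
from PIL import Image

with h5py.File("/path/to/recon_release/some_traj.h5", "r") as h5_f:
    positions = h5_f["jackal"]["position"][:, :2]   # (T, 2) xy odometry
    yaws = h5_f["jackal"]["yaw"][()]                # (T,) headings
    img0 = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][0]))  # JPEG bytes per timestep
    print(positions.shape, yaws.shape, img0.size)
```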
torch.manual_seed(config["seed"]) + cudnn.deterministic = True + + cudnn.benchmark = True # good if input sizes don't vary + transform = ([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + transform = transforms.Compose(transform) + + # Load the data + train_dataset = [] + test_dataloaders = {} + + if "context_type" not in config: + config["context_type"] = "temporal" + + if "clip_goals" not in config: + config["clip_goals"] = False + + for dataset_name in config["datasets"]: + data_config = config["datasets"][dataset_name] + if "negative_mining" not in data_config: + data_config["negative_mining"] = True + if "goals_per_obs" not in data_config: + data_config["goals_per_obs"] = 1 + if "end_slack" not in data_config: + data_config["end_slack"] = 0 + if "waypoint_spacing" not in data_config: + data_config["waypoint_spacing"] = 1 + + for data_split_type in ["train", "test"]: + if data_split_type in data_config: + dataset = ViNT_Dataset( + data_folder=data_config["data_folder"], + data_split_folder=data_config[data_split_type], + dataset_name=dataset_name, + image_size=config["image_size"], + waypoint_spacing=data_config["waypoint_spacing"], + min_dist_cat=config["distance"]["min_dist_cat"], + max_dist_cat=config["distance"]["max_dist_cat"], + min_action_distance=config["action"]["min_dist_cat"], + max_action_distance=config["action"]["max_dist_cat"], + negative_mining=data_config["negative_mining"], + len_traj_pred=config["len_traj_pred"], + learn_angle=config["learn_angle"], + context_size=config["context_size"], + context_type=config["context_type"], + end_slack=data_config["end_slack"], + goals_per_obs=data_config["goals_per_obs"], + normalize=config["normalize"], + goal_type=config["goal_type"], + ) + if data_split_type == "train": + train_dataset.append(dataset) + else: + dataset_type = f"{dataset_name}_{data_split_type}" + if dataset_type not in test_dataloaders: + test_dataloaders[dataset_type] = {} + test_dataloaders[dataset_type] = dataset + + # combine all the datasets from different robots + train_dataset = ConcatDataset(train_dataset) + + train_loader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + shuffle=True, + num_workers=config["num_workers"], + drop_last=False, + persistent_workers=True, + ) + + if "eval_batch_size" not in config: + config["eval_batch_size"] = config["batch_size"] + + for dataset_type, dataset in test_dataloaders.items(): + test_dataloaders[dataset_type] = DataLoader( + dataset, + batch_size=config["eval_batch_size"], + shuffle=True, + num_workers=0, + drop_last=False, + ) + + # Create the model + if config["model_type"] == "gnm": + model = GNM( + config["context_size"], + config["len_traj_pred"], + config["learn_angle"], + config["obs_encoding_size"], + config["goal_encoding_size"], + ) + elif config["model_type"] == "vint": + model = ViNT( + context_size=config["context_size"], + len_traj_pred=config["len_traj_pred"], + learn_angle=config["learn_angle"], + obs_encoder=config["obs_encoder"], + obs_encoding_size=config["obs_encoding_size"], + late_fusion=config["late_fusion"], + mha_num_attention_heads=config["mha_num_attention_heads"], + mha_num_attention_layers=config["mha_num_attention_layers"], + mha_ff_dim_factor=config["mha_ff_dim_factor"], + ) + else: + raise ValueError(f"Model {config['model']} not supported") + + count_parameters(model) # print number of parameters + + if config["clipping"]: + print("Clipping gradients to", config["max_norm"]) + for p in model.parameters(): + if not 
p.requires_grad: + continue + p.register_hook( + lambda grad: torch.clamp( + grad, -1 * config["max_norm"], config["max_norm"] + ) + ) + + lr = float(config["lr"]) + config["optimizer"] = config["optimizer"].lower() + if config["optimizer"] == "adam": + optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.98)) + elif config["optimizer"] == "adamw": + optimizer = AdamW(model.parameters(), lr=lr) + elif config["optimizer"] == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) + else: + raise ValueError(f"Optimizer {config['optimizer']} not supported") + + scheduler = None + if config["scheduler"] is not None: + config["scheduler"] = config["scheduler"].lower() + if config["scheduler"] == "cosine": + print("Using cosine annealing with T_max", config["epochs"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=config["epochs"] + ) + elif config["scheduler"] == "cyclic": + print("Using cyclic LR with cycle", config["cyclic_period"]) + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=lr / 10., + max_lr=lr, + step_size_up=config["cyclic_period"] // 2, + cycle_momentum=False, + ) + elif config["scheduler"] == "plateau": + print("Using ReduceLROnPlateau") + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=config["plateau_factor"], + patience=config["plateau_patience"], + verbose=True, + ) + else: + raise ValueError(f"Scheduler {config['scheduler']} not supported") + + if config["warmup"]: + print("Using warmup scheduler") + scheduler = GradualWarmupScheduler( + optimizer, + multiplier=1, + total_epoch=config["warmup_epochs"], + after_scheduler=scheduler, + ) + + current_epoch = 0 + if "load_run" in config: + load_project_folder = os.path.join("logs", config["load_run"]) + print("Loading model from ", load_project_folder) + latest_path = os.path.join(load_project_folder, "latest.pth") + latest_checkpoint = torch.load(latest_path) #f"cuda:{}" if torch.cuda.is_available() else "cpu") + load_model(model, latest_checkpoint) + current_epoch = latest_checkpoint["epoch"] + 1 + + # Multi-GPU + if len(config["gpu_ids"]) > 1: + model = nn.DataParallel(model, device_ids=config["gpu_ids"]) + model = model.to(device) + + if "load_run" in config: # load optimizer and scheduler after data parallel + optimizer.load_state_dict(latest_checkpoint["optimizer"].state_dict()) + if scheduler is not None: + scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict()) + + train_eval_loop( + train_model=config["train"], + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_loader=train_loader, + test_dataloaders=test_dataloaders, + transform=transform, + epochs=config["epochs"], + device=device, + project_folder=config["project_folder"], + normalized=config["normalize"], + print_log_freq=config["print_log_freq"], + image_log_freq=config["image_log_freq"], + num_images_log=config["num_images_log"], + current_epoch=current_epoch, + learn_angle=config["learn_angle"], + alpha=config["alpha"], + use_wandb=config["use_wandb"], + eval_fraction=config["eval_fraction"], + ) + print("FINISHED TRAINING") + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + parser = argparse.ArgumentParser(description="Visual Navigation Transformer") + + # project setup + parser.add_argument( + "--config", + "-c", + default="config/vint.yaml", + type=str, + help="Path to the config file in train_config folder", + ) + args = parser.parse_args() + + with open("config/defaults.yaml", "r") as 
f: + default_config = yaml.safe_load(f) + + config = default_config + + with open(args.config, "r") as f: + user_config = yaml.safe_load(f) + + config.update(user_config) + + config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S") + config["project_folder"] = os.path.join( + "logs", config["project_name"], config["run_name"] + ) + os.makedirs( + config[ + "project_folder" + ], # should error if dir already exists to avoid overwriting and old project + ) + + if config["use_wandb"]: + wandb.login() + wandb.init( + project=config["project_name"], + settings=wandb.Settings(start_method="fork"), + entity="gnmv2", # TODO: change this to your wandb entity + ) + wandb.save(args.config, policy="now") # save the config file + wandb.run.name = config["run_name"] + # update the wandb args with the training configurations + if wandb.run: + wandb.config.update(config) + + print(config) + main(config) diff --git a/train/train_environment.yml b/train/train_environment.yml new file mode 100644 index 0000000..a2a5a1a --- /dev/null +++ b/train/train_environment.yml @@ -0,0 +1,25 @@ +name: vint_train +channels: +- pytorch +dependencies: +- python=3.8.5 +- cudatoolkit=11. +- numpy +- matplotlib +- ipykernel +- pip +- pip: + - torch + - torchvision + - tqdm==4.64.0 + - git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git + - opencv-python==4.6.0.66 + - h5py==3.6.0 + - wandb==0.12.18 + - --extra-index-url https://rospypi.github.io/simple/ + - rosbag + - roslz4 + - prettytable + - efficientnet_pytorch + - warmup_scheduler + - lmdb diff --git a/train/vint_train/__init__.py b/train/vint_train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/data/__init__.py b/train/vint_train/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/data/data_config.yaml b/train/vint_train/data/data_config.yaml new file mode 100644 index 0000000..54303f7 --- /dev/null +++ b/train/vint_train/data/data_config.yaml @@ -0,0 +1,63 @@ + +# global params for diffusion model +# normalized min and max +action_stats: + min: [-2.5, -4] + max: [5, 4] + +# data specific params +recon: + metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters) + + # OPTIONAL (FOR VISUALIZATION ONLY) + camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html + camera_height: 0.95 # meters + camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera + camera_matrix: + fx: 272.547000 + fy: 266.358000 + cx: 320.000000 + cy: 220.000000 + dist_coeffs: + k1: -0.038483 + k2: -0.010456 + p1: 0.003930 + p2: -0.001007 + k3: 0.0 + +scand: + metric_waypoint_spacing: 0.38 + +tartan_drive: + metric_waypoint_spacing: 0.72 + +go_stanford: + metric_waypoint_spacing: 0.12 + +# private datasets: +cory_hall: + metric_waypoint_spacing: 0.06 + +seattle: + metric_waypoint_spacing: 0.35 + +racer: + metric_waypoint_spacing: 0.38 + +carla_intvns: + metric_waypoint_spacing: 1.39 + +carla_cil: + metric_waypoint_spacing: 1.27 + +carla_intvns: + metric_waypoint_spacing: 1.39 + +carla: + metric_waypoint_spacing: 1.59 + image_path_func: get_image_path + +sacson: + metric_waypoint_spacing: 0.255 + +# add your own dataset params here: diff --git a/train/vint_train/data/data_utils.py b/train/vint_train/data/data_utils.py new file mode 100644 index 0000000..07f22f5 --- /dev/null +++ b/train/vint_train/data/data_utils.py @@ -0,0 +1,135 @@ +import numpy as np +import os +from PIL import Image +from typing import Any, Iterable, Tuple + 
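The helpers defined just below in `train/vint_train/data/data_utils.py` include `to_local_coords`, which rotates world-frame positions into the robot's egocentric frame before they become waypoint labels. A small worked example, assuming the `vint_train` package has been installed with `pip install -e train/`:

```python
# Worked example for to_local_coords (defined below in this module).
import numpy as np
from vint_train.data.data_utils import to_local_coords

curr_pos = np.array([0.0, 0.0])
curr_yaw = np.pi / 2                       # robot facing +y in the world frame
world_pts = np.array([[0.0, 1.0],          # 1 m directly ahead of the robot
                      [1.0, 0.0]])         # 1 m to the robot's right
local = to_local_coords(world_pts, curr_pos, curr_yaw)
print(local)  # ~[[1, 0], [0, -1]]: +x is forward, +y is left in the robot frame
```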
+import torch +from torchvision import transforms +import torchvision.transforms.functional as TF +import torch.nn.functional as F +import io +from typing import Union + +VISUALIZATION_IMAGE_SIZE = (160, 120) +IMAGE_ASPECT_RATIO = ( + 4 / 3 +) # all images are centered cropped to a 4:3 aspect ratio in training + + + +def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"): + data_ext = { + "image": ".jpg", + # add more data types here + } + return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}") + + +def yaw_rotmat(yaw: float) -> np.ndarray: + return np.array( + [ + [np.cos(yaw), -np.sin(yaw), 0.0], + [np.sin(yaw), np.cos(yaw), 0.0], + [0.0, 0.0, 1.0], + ], + ) + + +def to_local_coords( + positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float +) -> np.ndarray: + """ + Convert positions to local coordinates + + Args: + positions (np.ndarray): positions to convert + curr_pos (np.ndarray): current position + curr_yaw (float): current yaw + Returns: + np.ndarray: positions in local coordinates + """ + rotmat = yaw_rotmat(curr_yaw) + if positions.shape[-1] == 2: + rotmat = rotmat[:2, :2] + elif positions.shape[-1] == 3: + pass + else: + raise ValueError + + return (positions - curr_pos).dot(rotmat) + + +def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor: + """ + Calculate deltas between waypoints + + Args: + waypoints (torch.Tensor): waypoints + Returns: + torch.Tensor: deltas + """ + num_params = waypoints.shape[1] + origin = torch.zeros(1, num_params) + prev_waypoints = torch.concat((origin, waypoints[:-1]), axis=0) + deltas = waypoints - prev_waypoints + if num_params > 2: + return calculate_sin_cos(deltas) + return deltas + + +def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor: + """ + Calculate sin and cos of the angle + + Args: + waypoints (torch.Tensor): waypoints + Returns: + torch.Tensor: waypoints with sin and cos of the angle + """ + assert waypoints.shape[1] == 3 + angle_repr = torch.zeros_like(waypoints[:, :2]) + angle_repr[:, 0] = torch.cos(waypoints[:, 2]) + angle_repr[:, 1] = torch.sin(waypoints[:, 2]) + return torch.concat((waypoints[:, :2], angle_repr), axis=1) + + +def transform_images( + img: Image.Image, transform: transforms, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO +): + w, h = img.size + if w > h: + img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio + else: + img = TF.center_crop(img, (int(w / aspect_ratio), w)) + viz_img = img.resize(VISUALIZATION_IMAGE_SIZE) + viz_img = TF.to_tensor(viz_img) + img = img.resize(image_resize_size) + transf_img = transform(img) + return viz_img, transf_img + + +def resize_and_aspect_crop( + img: Image.Image, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO +): + w, h = img.size + if w > h: + img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio + else: + img = TF.center_crop(img, (int(w / aspect_ratio), w)) + img = img.resize(image_resize_size) + resize_img = TF.to_tensor(img) + return resize_img + + +def img_path_to_data(path: Union[str, io.BytesIO], image_resize_size: Tuple[int, int]) -> torch.Tensor: + """ + Load an image from a path and transform it + Args: + path (str): path to the image + image_resize_size (Tuple[int, int]): size to resize the image to + Returns: + torch.Tensor: resized image as tensor + """ + # return transform_images(Image.open(path), transform, image_resize_size, aspect_ratio) + return resize_and_aspect_crop(Image.open(path), 
image_resize_size)
+
diff --git a/train/vint_train/data/vint_dataset.py b/train/vint_train/data/vint_dataset.py
new file mode 100644
index 0000000..87b101c
--- /dev/null
+++ b/train/vint_train/data/vint_dataset.py
@@ -0,0 +1,362 @@
+import numpy as np
+import os
+import pickle
+import yaml
+from typing import Any, Dict, List, Optional, Tuple
+import tqdm
+import io
+import lmdb
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms.functional as TF
+
+from vint_train.data.data_utils import (
+    img_path_to_data,
+    calculate_sin_cos,
+    get_data_path,
+    to_local_coords,
+)
+
+class ViNT_Dataset(Dataset):
+    def __init__(
+        self,
+        data_folder: str,
+        data_split_folder: str,
+        dataset_name: str,
+        image_size: Tuple[int, int],
+        waypoint_spacing: int,
+        min_dist_cat: int,
+        max_dist_cat: int,
+        min_action_distance: int,
+        max_action_distance: int,
+        negative_mining: bool,
+        len_traj_pred: int,
+        learn_angle: bool,
+        context_size: int,
+        context_type: str = "temporal",
+        end_slack: int = 0,
+        goals_per_obs: int = 1,
+        normalize: bool = True,
+        obs_type: str = "image",
+        goal_type: str = "image",
+    ):
+        """
+        Main ViNT dataset class
+
+        Args:
+            data_folder (string): Directory with all the image data
+            data_split_folder (string): Directory with traj_names.txt, a list of all trajectory names in the dataset split, one per line
+            dataset_name (string): Name of the dataset [recon, go_stanford, scand, tartan_drive, etc.]
+            waypoint_spacing (int): Spacing between waypoints
+            min_dist_cat (int): Minimum distance category to use
+            max_dist_cat (int): Maximum distance category to use
+            negative_mining (bool): Whether to use negative mining from the ViNG paper (Shah et al.) (https://arxiv.org/abs/2012.09812)
+            len_traj_pred (int): Length of trajectory of waypoints to predict if this is an action dataset
+            learn_angle (bool): Whether to learn the yaw of the robot at each predicted waypoint if this is an action dataset
+            context_size (int): Number of previous observations to use as context
+            context_type (str): Whether to use temporal, randomized, or randomized temporal context
+            end_slack (int): Number of timesteps to ignore at the end of the trajectory
+            goals_per_obs (int): Number of goals to sample per observation
+            normalize (bool): Whether to normalize the distances or actions
+            goal_type (str): What data type to use for the goal. The only one supported is "image" for now.
+ """ + self.data_folder = data_folder + self.data_split_folder = data_split_folder + self.dataset_name = dataset_name + + traj_names_file = os.path.join(data_split_folder, "traj_names.txt") + with open(traj_names_file, "r") as f: + file_lines = f.read() + self.traj_names = file_lines.split("\n") + if "" in self.traj_names: + self.traj_names.remove("") + + self.image_size = image_size + self.waypoint_spacing = waypoint_spacing + self.distance_categories = list( + range(min_dist_cat, max_dist_cat + 1, self.waypoint_spacing) + ) + self.min_dist_cat = self.distance_categories[0] + self.max_dist_cat = self.distance_categories[-1] + self.negative_mining = negative_mining + if self.negative_mining: + self.distance_categories.append(-1) + self.len_traj_pred = len_traj_pred + self.learn_angle = learn_angle + + self.min_action_distance = min_action_distance + self.max_action_distance = max_action_distance + + self.context_size = context_size + assert context_type in { + "temporal", + "randomized", + "randomized_temporal", + }, "context_type must be one of temporal, randomized, randomized_temporal" + self.context_type = context_type + self.end_slack = end_slack + self.goals_per_obs = goals_per_obs + self.normalize = normalize + self.obs_type = obs_type + self.goal_type = goal_type + + # load data/data_config.yaml + with open( + os.path.join(os.path.dirname(__file__), "data_config.yaml"), "r" + ) as f: + all_data_config = yaml.safe_load(f) + assert ( + self.dataset_name in all_data_config + ), f"Dataset {self.dataset_name} not found in data_config.yaml" + dataset_names = list(all_data_config.keys()) + dataset_names.sort() + # use this index to retrieve the dataset name from the data_config.yaml + self.dataset_index = dataset_names.index(self.dataset_name) + self.data_config = all_data_config[self.dataset_name] + self.trajectory_cache = {} + self._load_index() + self._build_caches() + + if self.learn_angle: + self.num_action_params = 3 + else: + self.num_action_params = 2 + + def __getstate__(self): + state = self.__dict__.copy() + state["_image_cache"] = None + return state + + def __setstate__(self, state): + self.__dict__ = state + self._build_caches() + + def _build_caches(self, use_tqdm: bool = True): + """ + Build a cache of images for faster loading using LMDB + """ + cache_filename = os.path.join( + self.data_split_folder, + f"dataset_{self.dataset_name}.lmdb", + ) + + # Load all the trajectories into memory. These should already be loaded, but just in case. 
+ for traj_name in self.traj_names: + self._get_trajectory(traj_name) + + """ + If the cache file doesn't exist, create it by iterating through the dataset and writing each image to the cache + """ + if not os.path.exists(cache_filename): + tqdm_iterator = tqdm.tqdm( + self.goals_index, + disable=not use_tqdm, + dynamic_ncols=True, + desc=f"Building LMDB cache for {self.dataset_name}" + ) + with lmdb.open(cache_filename, map_size=2**40) as image_cache: + with image_cache.begin(write=True) as txn: + for traj_name, time in tqdm_iterator: + image_path = get_data_path(self.data_folder, traj_name, time) + with open(image_path, "rb") as f: + txn.put(image_path.encode(), f.read()) + + # Reopen the cache file in read-only mode + self._image_cache: lmdb.Environment = lmdb.open(cache_filename, readonly=True) + + def _build_index(self, use_tqdm: bool = False): + """ + Build an index consisting of tuples (trajectory name, time, max goal distance) + """ + samples_index = [] + goals_index = [] + + for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True): + traj_data = self._get_trajectory(traj_name) + traj_len = len(traj_data["position"]) + + for goal_time in range(0, traj_len): + goals_index.append((traj_name, goal_time)) + + begin_time = self.context_size * self.waypoint_spacing + end_time = traj_len - self.end_slack - self.len_traj_pred * self.waypoint_spacing + for curr_time in range(begin_time, end_time): + max_goal_distance = min(self.max_dist_cat * self.waypoint_spacing, traj_len - curr_time - 1) + samples_index.append((traj_name, curr_time, max_goal_distance)) + + return samples_index, goals_index + + def _sample_goal(self, trajectory_name, curr_time, max_goal_dist): + """ + Sample a goal from the future in the same trajectory. + Returns: (trajectory_name, goal_time, goal_is_negative) + """ + goal_offset = np.random.randint(0, max_goal_dist + 1) + if goal_offset == 0: + trajectory_name, goal_time = self._sample_negative() + return trajectory_name, goal_time, True + else: + goal_time = curr_time + int(goal_offset * self.waypoint_spacing) + return trajectory_name, goal_time, False + + def _sample_negative(self): + """ + Sample a goal from a (likely) different trajectory. 
+ """ + return self.goals_index[np.random.randint(0, len(self.goals_index))] + + def _load_index(self) -> None: + """ + Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset + """ + index_to_data_path = os.path.join( + self.data_split_folder, + f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_context_{self.context_type}_n{self.context_size}_slack_{self.end_slack}.pkl", + ) + try: + # load the index_to_data if it already exists (to save time) + with open(index_to_data_path, "rb") as f: + self.index_to_data, self.goals_index = pickle.load(f) + except: + # if the index_to_data file doesn't exist, create it + self.index_to_data, self.goals_index = self._build_index() + with open(index_to_data_path, "wb") as f: + pickle.dump((self.index_to_data, self.goals_index), f) + + def _load_image(self, trajectory_name, time): + image_path = get_data_path(self.data_folder, trajectory_name, time) + + try: + with self._image_cache.begin() as txn: + image_buffer = txn.get(image_path.encode()) + image_bytes = bytes(image_buffer) + image_bytes = io.BytesIO(image_bytes) + return img_path_to_data(image_bytes, self.image_size) + except TypeError: + print(f"Failed to load image {image_path}") + + def _compute_actions(self, traj_data, curr_time, goal_time): + start_index = curr_time + end_index = curr_time + self.len_traj_pred * self.waypoint_spacing + 1 + yaw = traj_data["yaw"][start_index:end_index:self.waypoint_spacing] + positions = traj_data["position"][start_index:end_index:self.waypoint_spacing] + goal_pos = traj_data["position"][min(goal_time, len(traj_data["position"]) - 1)] + + if len(yaw.shape) == 2: + yaw = yaw.squeeze(1) + + if yaw.shape != (self.len_traj_pred + 1,): + const_len = self.len_traj_pred + 1 - yaw.shape[0] + yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)]) + positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0) + + assert yaw.shape == (self.len_traj_pred + 1,), f"{yaw.shape} and {(self.len_traj_pred + 1,)} should be equal" + assert positions.shape == (self.len_traj_pred + 1, 2), f"{positions.shape} and {(self.len_traj_pred + 1, 2)} should be equal" + + waypoints = to_local_coords(positions, positions[0], yaw[0]) + goal_pos = to_local_coords(goal_pos, positions[0], yaw[0]) + + assert waypoints.shape == (self.len_traj_pred + 1, 2), f"{waypoints.shape} and {(self.len_traj_pred + 1, 2)} should be equal" + + if self.learn_angle: + yaw = yaw[1:] - yaw[0] + actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1) + else: + actions = waypoints[1:] + + if self.normalize: + actions[:, :2] /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing + goal_pos /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing + + assert actions.shape == (self.len_traj_pred, self.num_action_params), f"{actions.shape} and {(self.len_traj_pred, self.num_action_params)} should be equal" + + return actions, goal_pos + + def _get_trajectory(self, trajectory_name): + if trajectory_name in self.trajectory_cache: + return self.trajectory_cache[trajectory_name] + else: + with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f: + traj_data = pickle.load(f) + self.trajectory_cache[trajectory_name] = traj_data + return traj_data + + def __len__(self) -> int: + return len(self.index_to_data) + + def __getitem__(self, i: int) -> Tuple[torch.Tensor]: + """ + Args: + i (int): index to ith datapoint + Returns: + Tuple of tensors containing 
the context, observation, goal, transformed context, transformed observation, transformed goal, distance label, and action label + obs_image (torch.Tensor): tensor of shape [3, H, W] containing the image of the robot's observation + goal_image (torch.Tensor): tensor of shape [3, H, W] containing the subgoal image + dist_label (torch.Tensor): tensor of shape (1,) containing the distance labels from the observation to the goal + action_label (torch.Tensor): tensor of shape (5, 2) or (5, 4) (if training with angle) containing the action labels from the observation to the goal + which_dataset (torch.Tensor): index of the datapoint in the dataset [for identifying the dataset for visualization when using multiple datasets] + """ + f_curr, curr_time, max_goal_dist = self.index_to_data[i] + f_goal, goal_time, goal_is_negative = self._sample_goal(f_curr, curr_time, max_goal_dist) + + # Load images + context = [] + if self.context_type == "temporal": + # sample the last self.context_size times from interval [0, curr_time) + context_times = list( + range( + curr_time + -self.context_size * self.waypoint_spacing, + curr_time + 1, + self.waypoint_spacing, + ) + ) + context = [(f_curr, t) for t in context_times] + else: + raise ValueError(f"Invalid context type {self.context_type}") + + obs_image = torch.cat([ + self._load_image(f, t) for f, t in context + ]) + + # Load goal image + goal_image = self._load_image(f_goal, goal_time) + + # Load other trajectory data + curr_traj_data = self._get_trajectory(f_curr) + curr_traj_len = len(curr_traj_data["position"]) + assert curr_time < curr_traj_len, f"{curr_time} and {curr_traj_len}" + + goal_traj_data = self._get_trajectory(f_goal) + goal_traj_len = len(goal_traj_data["position"]) + assert goal_time < goal_traj_len, f"{goal_time} an {goal_traj_len}" + + # Compute actions + actions, goal_pos = self._compute_actions(curr_traj_data, curr_time, goal_time) + + # Compute distances + if goal_is_negative: + distance = self.max_dist_cat + else: + distance = (goal_time - curr_time) // self.waypoint_spacing + assert (goal_time - curr_time) % self.waypoint_spacing == 0, f"{goal_time} and {curr_time} should be separated by an integer multiple of {self.waypoint_spacing}" + + actions_torch = torch.as_tensor(actions, dtype=torch.float32) + if self.learn_angle: + actions_torch = calculate_sin_cos(actions_torch) + + action_mask = ( + (distance < self.max_action_distance) and + (distance > self.min_action_distance) and + (not goal_is_negative) + ) + + return ( + torch.as_tensor(obs_image, dtype=torch.float32), + torch.as_tensor(goal_image, dtype=torch.float32), + actions_torch, + torch.as_tensor(distance, dtype=torch.int64), + torch.as_tensor(goal_pos, dtype=torch.float32), + torch.as_tensor(self.dataset_index, dtype=torch.int64), + torch.as_tensor(action_mask, dtype=torch.float32), + ) diff --git a/train/vint_train/models/__init__.py b/train/vint_train/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/models/base_model.py b/train/vint_train/models/base_model.py new file mode 100644 index 0000000..5762516 --- /dev/null +++ b/train/vint_train/models/base_model.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn + +from typing import List, Dict, Optional, Tuple + + +class BaseModel(nn.Module): + def __init__( + self, + context_size: int = 5, + len_traj_pred: Optional[int] = 5, + learn_angle: Optional[bool] = True, + ) -> None: + """ + Base Model main class + Args: + context_size (int): how many previous observations to used for 
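`__getitem__` above returns a 7-tuple per sample. A short sketch of consuming it through a `DataLoader`; `dataset` is assumed to be a `ViNT_Dataset` constructed as in `train/train.py`, and the batch size is arbitrary:

```python
# Sketch: unpacking the 7-tuple returned by ViNT_Dataset.__getitem__ above.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, shuffle=True)  # `dataset` assumed to exist
for (obs_images,     # [B, 3*(context_size+1), H, W]: context frames stacked along channels
     goal_image,     # [B, 3, H, W]
     actions,        # [B, len_traj_pred, 2], or [B, len_traj_pred, 4] (x, y, cos, sin) if learn_angle
     distance,       # [B]: temporal distance label; equals max_dist_cat for negatives
     goal_pos,       # [B, 2]: goal position in the robot frame (normalized if normalize=True)
     dataset_index,  # [B]: which source dataset the sample came from (used for visualization)
     action_mask,    # [B]: 1.0 only for positives whose distance lies inside the action bounds
     ) in loader:
    break
```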
context + len_traj_pred (int): how many waypoints to predict in the future + learn_angle (bool): whether to predict the yaw of the robot + """ + super(BaseModel, self).__init__() + self.context_size = context_size + self.learn_angle = learn_angle + self.len_trajectory_pred = len_traj_pred + if self.learn_angle: + self.num_action_params = 4 # last two dims are the cos and sin of the angle + else: + self.num_action_params = 2 + + def flatten(self, z: torch.Tensor) -> torch.Tensor: + z = nn.functional.adaptive_avg_pool2d(z, (1, 1)) + z = torch.flatten(z, 1) + return z + + def forward( + self, obs_img: torch.tensor, goal_img: torch.tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the model + Args: + obs_img (torch.Tensor): batch of observations + goal_img (torch.Tensor): batch of goals + Returns: + dist_pred (torch.Tensor): predicted distance to goal + action_pred (torch.Tensor): predicted action + """ + raise NotImplementedError diff --git a/train/vint_train/models/gnm.py b/train/vint_train/models/gnm.py new file mode 100644 index 0000000..701bcd9 --- /dev/null +++ b/train/vint_train/models/gnm.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import List, Dict, Optional, Tuple +from vint_train.models.modified_mobilenetv2 import MobileNetEncoder +from vint_train.models.base_model import BaseModel + + +class GNM(BaseModel): + def __init__( + self, + context_size: int = 5, + len_traj_pred: Optional[int] = 5, + learn_angle: Optional[bool] = True, + obs_encoding_size: Optional[int] = 1024, + goal_encoding_size: Optional[int] = 1024, + ) -> None: + """ + GNM main class + Args: + context_size (int): how many previous observations to used for context + len_traj_pred (int): how many waypoints to predict in the future + learn_angle (bool): whether to predict the yaw of the robot + obs_encoding_size (int): size of the encoding of the observation images + goal_encoding_size (int): size of the encoding of the goal images + """ + super(GNM, self).__init__(context_size, len_traj_pred, learn_angle) + mobilenet = MobileNetEncoder(num_images=1 + self.context_size) + self.obs_mobilenet = mobilenet.features + self.obs_encoding_size = obs_encoding_size + self.compress_observation = nn.Sequential( + nn.Linear(mobilenet.last_channel, self.obs_encoding_size), + nn.ReLU(), + ) + stacked_mobilenet = MobileNetEncoder( + num_images=2 + self.context_size + ) # stack the goal and the current observation + self.goal_mobilenet = stacked_mobilenet.features + self.goal_encoding_size = goal_encoding_size + self.compress_goal = nn.Sequential( + nn.Linear(stacked_mobilenet.last_channel, 1024), + nn.ReLU(), + nn.Linear(1024, self.goal_encoding_size), + nn.ReLU(), + ) + self.linear_layers = nn.Sequential( + nn.Linear(self.goal_encoding_size + self.obs_encoding_size, 256), + nn.ReLU(), + nn.Linear(256, 32), + nn.ReLU(), + ) + self.dist_predictor = nn.Sequential( + nn.Linear(32, 1), + ) + self.action_predictor = nn.Sequential( + nn.Linear(32, self.len_trajectory_pred * self.num_action_params), + ) + + def forward( + self, obs_img: torch.tensor, goal_img: torch.tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + obs_encoding = self.obs_mobilenet(obs_img) + obs_encoding = self.flatten(obs_encoding) + obs_encoding = self.compress_observation(obs_encoding) + + obs_goal_input = torch.cat([obs_img, goal_img], dim=1) + goal_encoding = self.goal_mobilenet(obs_goal_input) + goal_encoding = self.flatten(goal_encoding) + goal_encoding = 
self.compress_goal(goal_encoding) + + z = torch.cat([obs_encoding, goal_encoding], dim=1) + z = self.linear_layers(z) + dist_pred = self.dist_predictor(z) + action_pred = self.action_predictor(z) + + # augment outputs to match labels size-wise + action_pred = action_pred.reshape( + (action_pred.shape[0], self.len_trajectory_pred, self.num_action_params) + ) + action_pred[:, :, :2] = torch.cumsum( + action_pred[:, :, :2], dim=1 + ) # convert position deltas into waypoints + if self.learn_angle: + action_pred[:, :, 2:] = F.normalize( + action_pred[:, :, 2:].clone(), dim=-1 + ) # normalize the angle prediction + return dist_pred, action_pred diff --git a/train/vint_train/models/modified_mobilenetv2.py b/train/vint_train/models/modified_mobilenetv2.py new file mode 100644 index 0000000..2a0d3e8 --- /dev/null +++ b/train/vint_train/models/modified_mobilenetv2.py @@ -0,0 +1,143 @@ +# modified from PyTorch torchvision library +from typing import Callable, Any, Optional, List + +import torch +from torch import Tensor +from torch import nn + +from torchvision.ops.misc import ConvNormActivation +from torchvision.models._utils import _make_divisible +from torchvision.models.mobilenetv2 import InvertedResidual + + +class MobileNetEncoder(nn.Module): + def __init__( + self, + num_images: int = 1, + num_classes: int = 1000, + width_mult: float = 1.0, + inverted_residual_setting: Optional[List[List[int]]] = None, + round_nearest: int = 8, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + dropout: float = 0.2, + ) -> None: + """ + MobileNet V2 main class + Args: + num_images (int): number of images stacked in the input tensor + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + dropout (float): The droupout probability + """ + super().__init__() + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if ( + len(inverted_residual_setting) == 0 + or len(inverted_residual_setting[0]) != 4 + ): + raise ValueError( + f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" + ) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible( + last_channel * max(1.0, width_mult), round_nearest + ) + features: List[nn.Module] = [ + ConvNormActivation( + num_images * 3, + input_channel, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ) + ] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append( + block( + input_channel, + output_channel, + stride, + expand_ratio=t, + 
norm_layer=norm_layer, + ) + ) + input_channel = output_channel + # building last several layers + features.append( + # Conv2dNormActivation( + ConvNormActivation( + input_channel, + self.last_channel, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ) + ) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # # building classifier + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + x = self.features(x) + # Cannot use "squeeze" as batch-size can be 1 + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) diff --git a/train/vint_train/models/self_attention.py b/train/vint_train/models/self_attention.py new file mode 100644 index 0000000..a161254 --- /dev/null +++ b/train/vint_train/models/self_attention.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, max_seq_len=6): + super().__init__() + + # Compute the positional encoding once + pos_enc = torch.zeros(max_seq_len, d_model) + pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pos_enc[:, 0::2] = torch.sin(pos * div_term) + pos_enc[:, 1::2] = torch.cos(pos * div_term) + pos_enc = pos_enc.unsqueeze(0) + + # Register the positional encoding as a buffer to avoid it being + # considered a parameter when saving the model + self.register_buffer('pos_enc', pos_enc) + + def forward(self, x): + # Add the positional encoding to the input + x = x + self.pos_enc[:, :x.size(1), :] + return x + +class MultiLayerDecoder(nn.Module): + def __init__(self, embed_dim=512, seq_len=6, output_layers=[256, 128, 64], nhead=8, num_layers=8, ff_dim_factor=4): + super(MultiLayerDecoder, self).__init__() + self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=seq_len) + self.sa_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=ff_dim_factor*embed_dim, activation="gelu", batch_first=True, norm_first=True) + self.sa_decoder = nn.TransformerEncoder(self.sa_layer, num_layers=num_layers) + self.output_layers = nn.ModuleList([nn.Linear(seq_len*embed_dim, embed_dim)]) + self.output_layers.append(nn.Linear(embed_dim, output_layers[0])) + for i in range(len(output_layers)-1): + self.output_layers.append(nn.Linear(output_layers[i], output_layers[i+1])) + + def forward(self, x): + if self.positional_encoding: x = self.positional_encoding(x) + x = self.sa_decoder(x) + # currently, x is [batch_size, seq_len, embed_dim] + x = x.reshape(x.shape[0], -1) + for i in range(len(self.output_layers)): + x = self.output_layers[i](x) + x = F.relu(x) + return x diff --git 
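`MultiLayerDecoder` above runs self-attention over the token sequence, flattens it, and funnels it through the output MLP, so the final feature width equals the last entry of `output_layers`. A quick shape check using the token layout ViNT builds later in this diff (`context_size + 2` tokens):

```python
# Shape sketch for MultiLayerDecoder above; dimensions follow the ViNT defaults in this diff.
import torch
from vint_train.models.self_attention import MultiLayerDecoder

decoder = MultiLayerDecoder(
    embed_dim=512, seq_len=7,              # context_size (5) past frames + current obs + goal token
    output_layers=[256, 128, 64, 32],
    nhead=4, num_layers=4, ff_dim_factor=4,
)
tokens = torch.randn(2, 7, 512)            # [batch, seq_len, embed_dim]
print(decoder(tokens).shape)               # torch.Size([2, 32])
```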
a/train/vint_train/models/vint.py b/train/vint_train/models/vint.py new file mode 100644 index 0000000..33b0304 --- /dev/null +++ b/train/vint_train/models/vint.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Dict, Optional, Tuple +from efficientnet_pytorch import EfficientNet +from vint_train.models.base_model import BaseModel +from vint_train.models.self_attention import MultiLayerDecoder + +class ViNT(BaseModel): + def __init__( + self, + context_size: int = 5, + len_traj_pred: Optional[int] = 5, + learn_angle: Optional[bool] = True, + obs_encoder: Optional[str] = "efficientnet-b0", + obs_encoding_size: Optional[int] = 512, + late_fusion: Optional[bool] = False, + mha_num_attention_heads: Optional[int] = 2, + mha_num_attention_layers: Optional[int] = 2, + mha_ff_dim_factor: Optional[int] = 4, + ) -> None: + """ + ViNT class: uses a Transformer-based architecture to encode (current and past) visual observations + and goals using an EfficientNet CNN, and predicts temporal distance and normalized actions + in an embodiment-agnostic manner + Args: + context_size (int): how many previous observations to used for context + len_traj_pred (int): how many waypoints to predict in the future + learn_angle (bool): whether to predict the yaw of the robot + obs_encoder (str): name of the EfficientNet architecture to use for encoding observations (ex. "efficientnet-b0") + obs_encoding_size (int): size of the encoding of the observation images + goal_encoding_size (int): size of the encoding of the goal images + """ + super(ViNT, self).__init__(context_size, len_traj_pred, learn_angle) + self.obs_encoding_size = obs_encoding_size + self.goal_encoding_size = obs_encoding_size + + self.late_fusion = late_fusion + if obs_encoder.split("-")[0] == "efficientnet": + self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3) # context + self.num_obs_features = self.obs_encoder._fc.in_features + if self.late_fusion: + self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=3) + else: + self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=6) # obs+goal + self.num_goal_features = self.goal_encoder._fc.in_features + else: + raise NotImplementedError + + if self.num_obs_features != self.obs_encoding_size: + self.compress_obs_enc = nn.Linear(self.num_obs_features, self.obs_encoding_size) + else: + self.compress_obs_enc = nn.Identity() + + if self.num_goal_features != self.goal_encoding_size: + self.compress_goal_enc = nn.Linear(self.num_goal_features, self.goal_encoding_size) + else: + self.compress_goal_enc = nn.Identity() + + self.decoder = MultiLayerDecoder( + embed_dim=self.obs_encoding_size, + seq_len=self.context_size+2, + output_layers=[256, 128, 64, 32], + nhead=mha_num_attention_heads, + num_layers=mha_num_attention_layers, + ff_dim_factor=mha_ff_dim_factor, + ) + self.dist_predictor = nn.Sequential( + nn.Linear(32, 1), + ) + self.action_predictor = nn.Sequential( + nn.Linear(32, self.len_trajectory_pred * self.num_action_params), + ) + + def forward( + self, obs_img: torch.tensor, goal_img: torch.tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + + # get the fused observation and goal encoding + if self.late_fusion: + goal_encoding = self.goal_encoder.extract_features(goal_img) + else: + obsgoal_img = torch.cat([obs_img[:, 3*self.context_size:, :, :], goal_img], dim=1) + goal_encoding = self.goal_encoder.extract_features(obsgoal_img) + goal_encoding = 
self.goal_encoder._avg_pooling(goal_encoding) + if self.goal_encoder._global_params.include_top: + goal_encoding = goal_encoding.flatten(start_dim=1) + goal_encoding = self.goal_encoder._dropout(goal_encoding) + # currently, the size of goal_encoding is [batch_size, num_goal_features] + goal_encoding = self.compress_goal_enc(goal_encoding) + if len(goal_encoding.shape) == 2: + goal_encoding = goal_encoding.unsqueeze(1) + # currently, the size of goal_encoding is [batch_size, 1, self.goal_encoding_size] + assert goal_encoding.shape[2] == self.goal_encoding_size + + # split the observation into context based on the context size + # image size is [batch_size, 3*self.context_size, H, W] + obs_img = torch.split(obs_img, 3, dim=1) + + # image size is [batch_size*self.context_size, 3, H, W] + obs_img = torch.concat(obs_img, dim=0) + + # get the observation encoding + obs_encoding = self.obs_encoder.extract_features(obs_img) + # currently the size is [batch_size*(self.context_size + 1), 1280, H/32, W/32] + obs_encoding = self.obs_encoder._avg_pooling(obs_encoding) + # currently the size is [batch_size*(self.context_size + 1), 1280, 1, 1] + if self.obs_encoder._global_params.include_top: + obs_encoding = obs_encoding.flatten(start_dim=1) + obs_encoding = self.obs_encoder._dropout(obs_encoding) + # currently, the size is [batch_size, self.context_size+2, self.obs_encoding_size] + + obs_encoding = self.compress_obs_enc(obs_encoding) + # currently, the size is [batch_size*(self.context_size + 1), self.obs_encoding_size] + # reshape the obs_encoding to [context + 1, batch, encoding_size], note that the order is flipped + obs_encoding = obs_encoding.reshape((self.context_size+1, -1, self.obs_encoding_size)) + obs_encoding = torch.transpose(obs_encoding, 0, 1) + # currently, the size is [batch_size, self.context_size+1, self.obs_encoding_size] + + # concatenate the goal encoding to the observation encoding + tokens = torch.cat((obs_encoding, goal_encoding), dim=1) + final_repr = self.decoder(tokens) + # currently, the size is [batch_size, 32] + + dist_pred = self.dist_predictor(final_repr) + action_pred = self.action_predictor(final_repr) + + # augment outputs to match labels size-wise + action_pred = action_pred.reshape( + (action_pred.shape[0], self.len_trajectory_pred, self.num_action_params) + ) + action_pred[:, :, :2] = torch.cumsum( + action_pred[:, :, :2], dim=1 + ) # convert position deltas into waypoints + if self.learn_angle: + action_pred[:, :, 2:] = F.normalize( + action_pred[:, :, 2:].clone(), dim=-1 + ) # normalize the angle prediction + return dist_pred, action_pred \ No newline at end of file diff --git a/train/vint_train/process_data/__init__.py b/train/vint_train/process_data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/process_data/process_bags_config.yaml b/train/vint_train/process_data/process_bags_config.yaml new file mode 100644 index 0000000..5d801a7 --- /dev/null +++ b/train/vint_train/process_data/process_bags_config.yaml @@ -0,0 +1,29 @@ +tartan_drive: + odomtopics: "/odometry/filtered_odom" + imtopics: "/multisense/left/image_rect_color" + ang_offset: 1.5707963267948966 # pi/2 + img_process_func: "process_tartan_img" + odom_process_func: "nav_to_xy_yaw" + +scand: + odomtopics: ["/odom", "/jackal_velocity_controller/odom"] + imtopics: ["/image_raw/compressed", "/camera/rgb/image_raw/compressed"] + ang_offset: 0.0 + img_process_func: "process_scand_img" + odom_process_func: "nav_to_xy_yaw" + +locobot: + odomtopics: "/odom" + imtopics: 
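To make the tensor shapes concrete, a dummy forward pass through the `ViNT` class above; the resolution follows the `image_size: [85, 64]` (width, height) entries in the configs earlier in this diff, and the weights are randomly initialized rather than loaded from a checkpoint:

```python
# Dummy forward pass through ViNT (defined above); random weights, shapes only.
import torch
from vint_train.models.vint import ViNT

context_size = 5
model = ViNT(
    context_size=context_size, len_traj_pred=5, learn_angle=True,
    obs_encoder="efficientnet-b0", obs_encoding_size=512, late_fusion=False,
    mha_num_attention_heads=4, mha_num_attention_layers=4, mha_ff_dim_factor=4,
).eval()

B, H, W = 2, 64, 85                                      # image_size is [width, height] = [85, 64]
obs_img = torch.randn(B, 3 * (context_size + 1), H, W)   # past context + current frame, channel-stacked
goal_img = torch.randn(B, 3, H, W)

with torch.no_grad():
    dist_pred, action_pred = model(obs_img, goal_img)
print(dist_pred.shape)    # torch.Size([2, 1])
print(action_pred.shape)  # torch.Size([2, 5, 4]) -> (x, y, cos(yaw), sin(yaw)) per waypoint
```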
"/usb_cam/image_raw" + ang_offset: 0.0 + img_process_func: "process_locobot_img" + odom_process_func: "nav_to_xy_yaw" + +sacson: + odomtopics: "/odometry" + imtopics: "/fisheye_image/compressed" + ang_offset: 0.0 + img_process_func: "process_sacson_img" + odom_process_func: "nav_to_xy_yaw" + +# add your own datasets below: diff --git a/train/vint_train/process_data/process_data_utils.py b/train/vint_train/process_data/process_data_utils.py new file mode 100644 index 0000000..c63f6c0 --- /dev/null +++ b/train/vint_train/process_data/process_data_utils.py @@ -0,0 +1,317 @@ +import numpy as np +import io +import os +import rosbag +from PIL import Image +import cv2 +from typing import Any, Tuple, List, Dict +import torchvision.transforms.functional as TF + +IMAGE_SIZE = (160, 120) +IMAGE_ASPECT_RATIO = 4 / 3 + + +def process_images(im_list: List, img_process_func) -> List: + """ + Process image data from a topic that publishes ros images into a list of PIL images + """ + images = [] + for img_msg in im_list: + img = img_process_func(img_msg) + images.append(img) + return images + + +def process_tartan_img(msg) -> Image: + """ + Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the tartan_drive dataset + """ + img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255 + img = img.astype(np.uint8) + # reverse the axis order to get the image in the right orientation + img = np.moveaxis(img, 0, -1) + # convert rgb to bgr + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + img = Image.fromarray(img) + return img + + +def process_locobot_img(msg) -> Image: + """ + Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the locobot dataset + """ + img = np.frombuffer(msg.data, dtype=np.uint8).reshape( + msg.height, msg.width, -1) + pil_image = Image.fromarray(img) + return pil_image + + +def process_scand_img(msg) -> Image: + """ + Process image data from a topic that publishes sensor_msgs/CompressedImage to a PIL image for the scand dataset + """ + # convert sensor_msgs/CompressedImage to PIL image + img = Image.open(io.BytesIO(msg.data)) + # center crop image to 4:3 aspect ratio + w, h = img.size + img = TF.center_crop( + img, (h, int(h * IMAGE_ASPECT_RATIO)) + ) # crop to the right ratio + # resize image to IMAGE_SIZE + img = img.resize(IMAGE_SIZE) + return img + + +############## Add custom image processing functions here ############# + +def process_sacson_img(msg) -> Image: + np_arr = np.fromstring(msg.data, np.uint8) + image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(image_np) + return pil_image + + +####################################################################### + + +def process_odom( + odom_list: List, + odom_process_func: Any, + ang_offset: float = 0.0, +) -> Dict[np.ndarray, np.ndarray]: + """ + Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw + """ + xys = [] + yaws = [] + for odom_msg in odom_list: + xy, yaw = odom_process_func(odom_msg, ang_offset) + xys.append(xy) + yaws.append(yaw) + return {"position": np.array(xys), "yaw": np.array(yaws)} + + +def nav_to_xy_yaw(odom_msg, ang_offset: float) -> Tuple[List[float], float]: + """ + Process odom data from a topic that publishes nav_msgs/Odometry into position + """ + + position = odom_msg.pose.pose.position + orientation = odom_msg.pose.pose.orientation + yaw = ( + quat_to_yaw(orientation.x, orientation.y, orientation.z, orientation.w) + + ang_offset + 
) + return [position.x, position.y], yaw + + +############ Add custom odometry processing functions here ############ + + +####################################################################### + + +def get_images_and_odom( + bag: rosbag.Bag, + imtopics: List[str] or str, + odomtopics: List[str] or str, + img_process_func: Any, + odom_process_func: Any, + rate: float = 4.0, + ang_offset: float = 0.0, +): + """ + Get image and odom data from a bag file + + Args: + bag (rosbag.Bag): bag file + imtopics (list[str] or str): topic name(s) for image data + odomtopics (list[str] or str): topic name(s) for odom data + img_process_func (Any): function to process image data + odom_process_func (Any): function to process odom data + rate (float, optional): rate to sample data. Defaults to 4.0. + ang_offset (float, optional): angle offset to add to odom data. Defaults to 0.0. + Returns: + img_data (list): list of PIL images + traj_data (list): list of odom data + """ + # check if bag has both topics + odomtopic = None + imtopic = None + if type(imtopics) == str: + imtopic = imtopics + else: + for imt in imtopics: + if bag.get_message_count(imt) > 0: + imtopic = imt + break + if type(odomtopics) == str: + odomtopic = odomtopics + else: + for ot in odomtopics: + if bag.get_message_count(ot) > 0: + odomtopic = ot + break + if not (imtopic and odomtopic): + # bag doesn't have both topics + return None, None + + synced_imdata = [] + synced_odomdata = [] + # get start time of bag in seconds + currtime = bag.get_start_time() + starttime = currtime + + curr_imdata = None + curr_odomdata = None + times = [] + + for topic, msg, t in bag.read_messages(topics=[imtopic, odomtopic]): + if topic == imtopic: + curr_imdata = msg + elif topic == odomtopic: + curr_odomdata = msg + if (t.to_sec() - currtime) >= 1.0 / rate: + if curr_imdata is not None and curr_odomdata is not None: + synced_imdata.append(curr_imdata) + synced_odomdata.append(curr_odomdata) + currtime = t.to_sec() + times.append(currtime - starttime) + + img_data = process_images(synced_imdata, img_process_func) + traj_data = process_odom( + synced_odomdata, + odom_process_func, + ang_offset=ang_offset, + ) + + return img_data, traj_data + + +def is_backwards( + pos1: np.ndarray, yaw1: float, pos2: np.ndarray, eps: float = 1e-5 +) -> bool: + """ + Check if the trajectory is going backwards given the position and yaw of two points + Args: + pos1: position of the first point + + """ + dx, dy = pos2 - pos1 + return dx * np.cos(yaw1) + dy * np.sin(yaw1) < eps + + +# cut out non-positive velocity segments of the trajectory +def filter_backwards( + img_list: List[Image.Image], + traj_data: Dict[str, np.ndarray], + start_slack: int = 0, + end_slack: int = 0, +) -> Tuple[List[np.ndarray], List[int]]: + """ + Cut out non-positive velocity segments of the trajectory + Args: + traj_type: type of trajectory to cut + img_list: list of images + traj_data: dictionary of position and yaw data + start_slack: number of points to ignore at the start of the trajectory + end_slack: number of points to ignore at the end of the trajectory + Returns: + cut_trajs: list of cut trajectories + start_times: list of start times of the cut trajectories + """ + traj_pos = traj_data["position"] + traj_yaws = traj_data["yaw"] + cut_trajs = [] + start = True + + def process_pair(traj_pair: list) -> Tuple[List, Dict]: + new_img_list, new_traj_data = zip(*traj_pair) + new_traj_data = np.array(new_traj_data) + new_traj_pos = new_traj_data[:, :2] + new_traj_yaws = new_traj_data[:, 2] + return 
(new_img_list, {"position": new_traj_pos, "yaw": new_traj_yaws}) + + for i in range(max(start_slack, 1), len(traj_pos) - end_slack): + pos1 = traj_pos[i - 1] + yaw1 = traj_yaws[i - 1] + pos2 = traj_pos[i] + if not is_backwards(pos1, yaw1, pos2): + if start: + new_traj_pairs = [ + (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) + ] + start = False + elif i == len(traj_pos) - end_slack - 1: + cut_trajs.append(process_pair(new_traj_pairs)) + else: + new_traj_pairs.append( + (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) + ) + elif not start: + cut_trajs.append(process_pair(new_traj_pairs)) + start = True + return cut_trajs + + +def quat_to_yaw( + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + w: np.ndarray, +) -> np.ndarray: + """ + Convert a batch quaternion into a yaw angle + yaw is rotation around z in radians (counterclockwise) + """ + t3 = 2.0 * (w * z + x * y) + t4 = 1.0 - 2.0 * (y * y + z * z) + yaw = np.arctan2(t3, t4) + return yaw + + +def ros_to_numpy( + msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none" +): + """ + Convert a ROS image message to a numpy array + """ + if output_resolution is None: + output_resolution = (msg.width, msg.height) + + is_rgb = "8" in msg.encoding + if is_rgb: + data = np.frombuffer(msg.data, dtype=np.uint8).copy() + else: + data = np.frombuffer(msg.data, dtype=np.float32).copy() + + data = data.reshape(msg.height, msg.width, nchannels) + + if empty_value: + mask = np.isclose(abs(data), empty_value) + fill_value = np.percentile(data[~mask], 99) + data[mask] = fill_value + + data = cv2.resize( + data, + dsize=(output_resolution[0], output_resolution[1]), + interpolation=cv2.INTER_AREA, + ) + + if aggregate == "littleendian": + data = sum([data[:, :, i] * (256**i) for i in range(nchannels)]) + elif aggregate == "bigendian": + data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)]) + + if len(data.shape) == 2: + data = np.expand_dims(data, axis=0) + else: + data = np.moveaxis(data, 2, 0) # Switch to channels-first + + if is_rgb: + data = data.astype(np.float32) / ( + 255.0 if aggregate == "none" else 255.0**nchannels + ) + + return data diff --git a/train/vint_train/training/__init__.py b/train/vint_train/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/training/logger.py b/train/vint_train/training/logger.py new file mode 100644 index 0000000..83747cc --- /dev/null +++ b/train/vint_train/training/logger.py @@ -0,0 +1,52 @@ +import numpy as np + + +class Logger: + def __init__( + self, + name: str, + dataset: str, + window_size: int = 10, + rounding: int = 4, + ): + """ + Args: + name (str): Name of the metric + dataset (str): Name of the dataset + window_size (int, optional): Size of the moving average window. Defaults to 10. + rounding (int, optional): Number of decimals to round to. Defaults to 4. 
+ """ + self.data = [] + self.name = name + self.dataset = dataset + self.rounding = rounding + self.window_size = window_size + + def display(self) -> str: + latest = round(self.latest(), self.rounding) + average = round(self.average(), self.rounding) + moving_average = round(self.moving_average(), self.rounding) + output = f"{self.full_name()}: {latest} ({self.window_size}pt moving_avg: {moving_average}) (avg: {average})" + return output + + def log_data(self, data: float): + if not np.isnan(data): + self.data.append(data) + + def full_name(self) -> str: + return f"{self.name} ({self.dataset})" + + def latest(self) -> float: + if len(self.data) > 0: + return self.data[-1] + return np.nan + + def average(self) -> float: + if len(self.data) > 0: + return np.mean(self.data) + return np.nan + + def moving_average(self) -> float: + if len(self.data) > self.window_size: + return np.mean(self.data[-self.window_size :]) + return self.average() \ No newline at end of file diff --git a/train/vint_train/training/train_eval_loop.py b/train/vint_train/training/train_eval_loop.py new file mode 100644 index 0000000..f9e1a00 --- /dev/null +++ b/train/vint_train/training/train_eval_loop.py @@ -0,0 +1,167 @@ +import wandb +import os +import numpy as np +from typing import List, Optional, Dict +from prettytable import PrettyTable + +from vint_train.training.train_utils import train, evaluate + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.optim import Adam +from torchvision import transforms + +def train_eval_loop( + train_model: bool, + model: nn.Module, + optimizer: Adam, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], + train_loader: DataLoader, + test_dataloaders: Dict[str, DataLoader], + transform: transforms, + epochs: int, + device: torch.device, + project_folder: str, + normalized: bool, + wandb_log_freq: int = 10, + print_log_freq: int = 100, + image_log_freq: int = 1000, + num_images_log: int = 8, + current_epoch: int = 0, + alpha: float = 0.5, + learn_angle: bool = True, + use_wandb: bool = True, + eval_fraction: float = 0.25, +): + """ + Train and evaluate the model for several epochs. 
+ + Args: + model: model to train + optimizer: optimizer to use + train_dist_loader: dataloader for training distance predictions + train_action_loader: dataloader for training action predictions + test_dataloaders: dict of dataloaders for testing + transform: transform to apply to images + epochs: number of epochs to train + device: device to train on + project_folder: folder to save checkpoints and logs + log_freq: frequency of logging to wandb + image_log_freq: frequency of logging images to wandb + num_images_log: number of images to log to wandb + current_epoch: epoch to start training from + alpha: tradeoff between distance and action loss + learn_angle: whether to learn the angle or not + use_wandb: whether to log to wandb or not + eval_fraction: fraction of training data to use for evaluation + """ + assert 0 <= alpha <= 1 + latest_path = os.path.join(project_folder, f"latest.pth") + + for epoch in range(current_epoch, current_epoch + epochs): + if train_model: + print( + f"Start ViNT Training Epoch {epoch}/{current_epoch + epochs - 1}" + ) + train( + model=model, + optimizer=optimizer, + train_loader=train_loader, + transform=transform, + device=device, + project_folder=project_folder, + normalized=normalized, + epoch=epoch, + alpha=alpha, + learn_angle=learn_angle, + print_log_freq=print_log_freq, + wandb_log_freq=wandb_log_freq, + image_log_freq=image_log_freq, + num_images_log=num_images_log, + use_wandb=use_wandb, + ) + + avg_total_test_loss = [] + for dataset_type in test_dataloaders: + print( + f"Start {dataset_type} ViNT Testing Epoch {epoch}/{current_epoch + epochs - 1}" + ) + loader = test_dataloaders[dataset_type] + + test_dist_loss, test_action_loss, total_eval_loss = evaluate( + eval_type=dataset_type, + model=model, + eval_loader=loader, + transform=transform, + device=device, + project_folder=project_folder, + normalized=normalized, + epoch=epoch, + alpha=alpha, + learn_angle=learn_angle, + num_images_log=num_images_log, + use_wandb=use_wandb, + eval_fraction=eval_fraction, + ) + + avg_total_test_loss.append(total_eval_loss) + + checkpoint = { + "epoch": epoch, + "model": model, + "optimizer": optimizer, + "avg_total_test_loss": np.mean(avg_total_test_loss), + "scheduler": scheduler + } + # log average eval loss + wandb.log({}, commit=False) + + if scheduler is not None: + # scheduler calls based on the type of scheduler + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + scheduler.step(np.mean(avg_total_test_loss)) + else: + scheduler.step() + wandb.log({ + "avg_total_test_loss": np.mean(avg_total_test_loss), + "lr": optimizer.param_groups[0]["lr"], + }, commit=False) + + numbered_path = os.path.join(project_folder, f"{epoch}.pth") + torch.save(checkpoint, latest_path) + torch.save(checkpoint, numbered_path) # keep track of model at every epoch + + # Flush the last set of eval logs + wandb.log({}) + print() + + +def load_model(model, checkpoint: dict) -> None: + """Load model from checkpoint.""" + loaded_model = checkpoint["model"] + try: # for DataParallel + state_dict = loaded_model.module.state_dict() + model.load_state_dict(state_dict) + except (RuntimeError, AttributeError) as e: + state_dict = loaded_model.state_dict() + model.load_state_dict(state_dict) + + +def load_ema_model(ema_model, state_dict: dict) -> None: + """Load model from checkpoint.""" + ema_model.load_state_dict(state_dict) + + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in 
model.named_parameters(): + if not parameter.requires_grad: continue + params = parameter.numel() + table.add_row([name, params]) + total_params+=params + # print(table) + print(f"Total Trainable Params: {total_params/1e6:.2f}M") + return total_params \ No newline at end of file diff --git a/train/vint_train/training/train_utils.py b/train/vint_train/training/train_utils.py new file mode 100644 index 0000000..0c7fca3 --- /dev/null +++ b/train/vint_train/training/train_utils.py @@ -0,0 +1,431 @@ +import wandb +import os +import numpy as np +from typing import List, Optional, Dict +from prettytable import PrettyTable +import tqdm +import itertools + +from vint_train.visualizing.action_utils import visualize_traj_pred +from vint_train.visualizing.distance_utils import visualize_dist_pred +from vint_train.visualizing.visualize_utils import to_numpy +from vint_train.training.logger import Logger +from vint_train.data.data_utils import VISUALIZATION_IMAGE_SIZE + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.optim import Adam +from torchvision import transforms +import torchvision.transforms.functional as TF + + +def _compute_losses( + dist_label: torch.Tensor, + action_label: torch.Tensor, + dist_pred: torch.Tensor, + action_pred: torch.Tensor, + alpha: float, + learn_angle: bool, + action_mask: torch.Tensor = None, +): + """ + Compute losses for distance and action prediction. + """ + dist_loss = F.mse_loss(dist_pred.squeeze(-1), dist_label.float()) + + def action_reduce(unreduced_loss: torch.Tensor): + # Reduce over non-batch dimensions to get loss per batch element + while unreduced_loss.dim() > 1: + unreduced_loss = unreduced_loss.mean(dim=-1) + assert unreduced_loss.shape == action_mask.shape, f"{unreduced_loss.shape} != {action_mask.shape}" + return (unreduced_loss * action_mask).mean() / (action_mask.mean() + 1e-2) + + # Mask out invalid inputs (for negatives, or when the distance between obs and goal is large) + assert action_pred.shape == action_label.shape, f"{action_pred.shape} != {action_label.shape}" + action_loss = action_reduce(F.mse_loss(action_pred, action_label, reduction="none")) + + action_waypts_cos_similairity = action_reduce(F.cosine_similarity( + action_pred[:, :, :2], action_label[:, :, :2], dim=-1 + )) + multi_action_waypts_cos_sim = action_reduce(F.cosine_similarity( + torch.flatten(action_pred[:, :, :2], start_dim=1), + torch.flatten(action_label[:, :, :2], start_dim=1), + dim=-1, + )) + + results = { + "dist_loss": dist_loss, + "action_loss": action_loss, + "action_waypts_cos_sim": action_waypts_cos_similairity, + "multi_action_waypts_cos_sim": multi_action_waypts_cos_sim, + } + + if learn_angle: + action_orien_cos_sim = action_reduce(F.cosine_similarity( + action_pred[:, :, 2:], action_label[:, :, 2:], dim=-1 + )) + multi_action_orien_cos_sim = action_reduce(F.cosine_similarity( + torch.flatten(action_pred[:, :, 2:], start_dim=1), + torch.flatten(action_label[:, :, 2:], start_dim=1), + dim=-1, + ) + ) + results["action_orien_cos_sim"] = action_orien_cos_sim + results["multi_action_orien_cos_sim"] = multi_action_orien_cos_sim + + total_loss = alpha * 1e-2 * dist_loss + (1 - alpha) * action_loss + results["total_loss"] = total_loss + + return results + + +def _log_data( + i, + epoch, + num_batches, + normalized, + project_folder, + num_images_log, + loggers, + obs_image, + goal_image, + action_pred, + action_label, + dist_pred, + dist_label, + goal_pos, + dataset_index, + use_wandb, + mode, + 
use_latest, + wandb_log_freq=1, + print_log_freq=1, + image_log_freq=1, + wandb_increment_step=True, +): + """ + Log data to wandb and print to console. + """ + data_log = {} + for key, logger in loggers.items(): + if use_latest: + data_log[logger.full_name()] = logger.latest() + if i % print_log_freq == 0 and print_log_freq != 0: + print(f"(epoch {epoch}) (batch {i}/{num_batches - 1}) {logger.display()}") + else: + data_log[logger.full_name()] = logger.average() + if i % print_log_freq == 0 and print_log_freq != 0: + print(f"(epoch {epoch}) {logger.full_name()} {logger.average()}") + + if use_wandb and i % wandb_log_freq == 0 and wandb_log_freq != 0: + wandb.log(data_log, commit=wandb_increment_step) + + if image_log_freq != 0 and i % image_log_freq == 0: + visualize_dist_pred( + to_numpy(obs_image), + to_numpy(goal_image), + to_numpy(dist_pred), + to_numpy(dist_label), + mode, + project_folder, + epoch, + num_images_log, + use_wandb=use_wandb, + ) + visualize_traj_pred( + to_numpy(obs_image), + to_numpy(goal_image), + to_numpy(dataset_index), + to_numpy(goal_pos), + to_numpy(action_pred), + to_numpy(action_label), + mode, + normalized, + project_folder, + epoch, + num_images_log, + use_wandb=use_wandb, + ) + + +def train( + model: nn.Module, + optimizer: Adam, + train_loader: DataLoader, + transform: transforms, + device: torch.device, + project_folder: str, + normalized: bool, + epoch: int, + alpha: float = 0.5, + learn_angle: bool = True, + print_log_freq: int = 100, + wandb_log_freq: int = 10, + image_log_freq: int = 1000, + num_images_log: int = 8, + use_wandb: bool = True, + use_tqdm: bool = True, +): + """ + Train the model for one epoch. + + Args: + model: model to train + optimizer: optimizer to use + train_loader: dataloader for training + transform: transform to use + device: device to use + project_folder: folder to save images to + epoch: current epoch + alpha: weight of action loss + learn_angle: whether to learn the angle of the action + print_log_freq: how often to print loss + image_log_freq: how often to log images + num_images_log: number of images to log + use_wandb: whether to use wandb + use_tqdm: whether to use tqdm + """ + model.train() + dist_loss_logger = Logger("dist_loss", "train", window_size=print_log_freq) + action_loss_logger = Logger("action_loss", "train", window_size=print_log_freq) + action_waypts_cos_sim_logger = Logger( + "action_waypts_cos_sim", "train", window_size=print_log_freq + ) + multi_action_waypts_cos_sim_logger = Logger( + "multi_action_waypts_cos_sim", "train", window_size=print_log_freq + ) + total_loss_logger = Logger("total_loss", "train", window_size=print_log_freq) + loggers = { + "dist_loss": dist_loss_logger, + "action_loss": action_loss_logger, + "action_waypts_cos_sim": action_waypts_cos_sim_logger, + "multi_action_waypts_cos_sim": multi_action_waypts_cos_sim_logger, + "total_loss": total_loss_logger, + } + + if learn_angle: + action_orien_cos_sim_logger = Logger( + "action_orien_cos_sim", "train", window_size=print_log_freq + ) + multi_action_orien_cos_sim_logger = Logger( + "multi_action_orien_cos_sim", "train", window_size=print_log_freq + ) + loggers["action_orien_cos_sim"] = action_orien_cos_sim_logger + loggers["multi_action_orien_cos_sim"] = multi_action_orien_cos_sim_logger + + num_batches = len(train_loader) + tqdm_iter = tqdm.tqdm( + train_loader, + disable=not use_tqdm, + dynamic_ncols=True, + desc=f"Training epoch {epoch}", + ) + for i, data in enumerate(tqdm_iter): + ( + obs_image, + goal_image, + action_label, + 
dist_label, + goal_pos, + dataset_index, + action_mask, + ) = data + + obs_images = torch.split(obs_image, 3, dim=1) + viz_obs_image = TF.resize(obs_images[-1], VISUALIZATION_IMAGE_SIZE) + obs_images = [transform(obs_image).to(device) for obs_image in obs_images] + obs_image = torch.cat(obs_images, dim=1) + + viz_goal_image = TF.resize(goal_image, VISUALIZATION_IMAGE_SIZE) + + goal_image = transform(goal_image).to(device) + model_outputs = model(obs_image, goal_image) + + dist_label = dist_label.to(device) + action_label = action_label.to(device) + action_mask = action_mask.to(device) + + optimizer.zero_grad() + + dist_pred, action_pred = model_outputs + + losses = _compute_losses( + dist_label=dist_label, + action_label=action_label, + dist_pred=dist_pred, + action_pred=action_pred, + alpha=alpha, + learn_angle=learn_angle, + action_mask=action_mask, + ) + + losses["total_loss"].backward() + optimizer.step() + + for key, value in losses.items(): + if key in loggers: + logger = loggers[key] + logger.log_data(value.item()) + + _log_data( + i=i, + epoch=epoch, + num_batches=num_batches, + normalized=normalized, + project_folder=project_folder, + num_images_log=num_images_log, + loggers=loggers, + obs_image=viz_obs_image, + goal_image=viz_goal_image, + action_pred=action_pred, + action_label=action_label, + dist_pred=dist_pred, + dist_label=dist_label, + goal_pos=goal_pos, + dataset_index=dataset_index, + wandb_log_freq=wandb_log_freq, + print_log_freq=print_log_freq, + image_log_freq=image_log_freq, + use_wandb=use_wandb, + mode="train", + use_latest=True, + ) + + +def evaluate( + eval_type: str, + model: nn.Module, + eval_loader: DataLoader, + transform: transforms, + device: torch.device, + project_folder: str, + normalized: bool, + epoch: int = 0, + alpha: float = 0.5, + learn_angle: bool = True, + num_images_log: int = 8, + use_wandb: bool = True, + eval_fraction: float = 1.0, + use_tqdm: bool = True, + +): + """ + Evaluate the model on the given evaluation dataset. + + Args: + eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.) 
+ model (nn.Module): model to evaluate + eval_loader (DataLoader): dataloader for eval + transform (transforms): transform to apply to images + device (torch.device): device to use for evaluation + project_folder (string): path to project folder + epoch (int): current epoch + alpha (float): weight for action loss + learn_angle (bool): whether to learn the angle of the action + print_log_freq (int): frequency of printing loss + image_log_freq (int): frequency of logging images + num_images_log (int): number of images to log + use_wandb (bool): whether to use wandb for logging + eval_fraction (float): fraction of data to use for evaluation + use_tqdm (bool): whether to use tqdm for logging + """ + model.eval() + dist_loss_logger = Logger("dist_loss", eval_type) + action_loss_logger = Logger("action_loss", eval_type) + action_waypts_cos_sim_logger = Logger("action_waypts_cos_sim", eval_type) + multi_action_waypts_cos_sim_logger = Logger("multi_action_waypts_cos_sim", eval_type) + total_loss_logger = Logger("total_loss", eval_type) + loggers = { + "dist_loss": dist_loss_logger, + "action_loss": action_loss_logger, + "action_waypts_cos_sim": action_waypts_cos_sim_logger, + "multi_action_waypts_cos_sim": multi_action_waypts_cos_sim_logger, + "total_loss": total_loss_logger, + } + + if learn_angle: + action_orien_cos_sim_logger = Logger("action_orien_cos_sim", eval_type) + multi_action_orien_cos_sim_logger = Logger("multi_action_orien_cos_sim", eval_type) + loggers["action_orien_cos_sim"] = action_orien_cos_sim_logger + loggers["multi_action_orien_cos_sim"] = multi_action_orien_cos_sim_logger + + num_batches = len(eval_loader) + num_batches = max(int(num_batches * eval_fraction), 1) + + viz_obs_image = None + with torch.no_grad(): + tqdm_iter = tqdm.tqdm( + itertools.islice(eval_loader, num_batches), + total=num_batches, + disable=not use_tqdm, + dynamic_ncols=True, + desc=f"Evaluating {eval_type} for epoch {epoch}", + ) + for i, data in enumerate(tqdm_iter): + ( + obs_image, + goal_image, + action_label, + dist_label, + goal_pos, + dataset_index, + action_mask, + ) = data + + obs_images = torch.split(obs_image, 3, dim=1) + viz_obs_image = TF.resize(obs_images[-1], VISUALIZATION_IMAGE_SIZE) + obs_images = [transform(obs_image).to(device) for obs_image in obs_images] + obs_image = torch.cat(obs_images, dim=1) + + viz_goal_image = TF.resize(goal_image, VISUALIZATION_IMAGE_SIZE) + + goal_image = transform(goal_image).to(device) + model_outputs = model(obs_image, goal_image) + + dist_label = dist_label.to(device) + action_label = action_label.to(device) + action_mask = action_mask.to(device) + + dist_pred, action_pred = model_outputs + + losses = _compute_losses( + dist_label=dist_label, + action_label=action_label, + dist_pred=dist_pred, + action_pred=action_pred, + alpha=alpha, + learn_angle=learn_angle, + action_mask=action_mask, + ) + + for key, value in losses.items(): + if key in loggers: + logger = loggers[key] + logger.log_data(value.item()) + + # Log data to wandb/console, with visualizations selected from the last batch + _log_data( + i=i, + epoch=epoch, + num_batches=num_batches, + normalized=normalized, + project_folder=project_folder, + num_images_log=num_images_log, + loggers=loggers, + obs_image=viz_obs_image, + goal_image=viz_goal_image, + action_pred=action_pred, + action_label=action_label, + goal_pos=goal_pos, + dist_pred=dist_pred, + dist_label=dist_label, + dataset_index=dataset_index, + use_wandb=use_wandb, + mode=eval_type, + use_latest=False, + wandb_increment_step=False, + ) 
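+    # After the eval loop: return the epoch averages of the distance, action, and total losses;
+    # train_eval_loop() uses the averaged total loss for the LR scheduler and checkpoint bookkeeping.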
+ + return dist_loss_logger.average(), action_loss_logger.average(), total_loss_logger.average() diff --git a/train/vint_train/visualizing/__init__.py b/train/vint_train/visualizing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/vint_train/visualizing/action_utils.py b/train/vint_train/visualizing/action_utils.py new file mode 100644 index 0000000..50c341e --- /dev/null +++ b/train/vint_train/visualizing/action_utils.py @@ -0,0 +1,476 @@ +import os +import numpy as np +import matplotlib.pyplot as plt +import cv2 +from typing import Optional, List +import wandb +import yaml +import torch +import torch.nn as nn +from vint_train.visualizing.visualize_utils import ( + to_numpy, + numpy_to_img, + VIZ_IMAGE_SIZE, + RED, + GREEN, + BLUE, + CYAN, + YELLOW, + MAGENTA, +) + +# load data_config.yaml +with open(os.path.join(os.path.dirname(__file__), "../data/data_config.yaml"), "r") as f: + data_config = yaml.safe_load(f) + + +def visualize_traj_pred( + batch_obs_images: np.ndarray, + batch_goal_images: np.ndarray, + dataset_indices: np.ndarray, + batch_goals: np.ndarray, + batch_pred_waypoints: np.ndarray, + batch_label_waypoints: np.ndarray, + eval_type: str, + normalized: bool, + save_folder: str, + epoch: int, + num_images_preds: int = 8, + use_wandb: bool = True, + display: bool = False, +): + """ + Compare predicted path with the gt path of waypoints using egocentric visualization. This visualization is for the last batch in the dataset. + + Args: + batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] + batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] + dataset_names: indices corresponding to the dataset name + batch_goals (np.ndarray): batch of goal positions [batch_size, 2] + batch_pred_waypoints (np.ndarray): batch of predicted waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2] or [batch_size, num_trajs_sampled horizon, {2 or 4}] + batch_label_waypoints (np.ndarray): batch of label waypoints [batch_size, T, 4] or [batch_size, horizon, 2] + eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.) + normalized (bool): whether the waypoints are normalized + save_folder (str): folder to save the images. 
If None, will not save the images + epoch (int): current epoch number + num_images_preds (int): number of images to visualize + use_wandb (bool): whether to use wandb to log the images + display (bool): whether to display the images + """ + visualize_path = None + if save_folder is not None: + visualize_path = os.path.join( + save_folder, "visualize", eval_type, f"epoch{epoch}", "action_prediction" + ) + + if not os.path.exists(visualize_path): + os.makedirs(visualize_path) + + assert ( + len(batch_obs_images) + == len(batch_goal_images) + == len(batch_goals) + == len(batch_pred_waypoints) + == len(batch_label_waypoints) + ) + + dataset_names = list(data_config.keys()) + dataset_names.sort() + + batch_size = batch_obs_images.shape[0] + wandb_list = [] + for i in range(min(batch_size, num_images_preds)): + obs_img = numpy_to_img(batch_obs_images[i]) + goal_img = numpy_to_img(batch_goal_images[i]) + dataset_name = dataset_names[int(dataset_indices[i])] + goal_pos = batch_goals[i] + pred_waypoints = batch_pred_waypoints[i] + label_waypoints = batch_label_waypoints[i] + + if normalized: + pred_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] + label_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] + goal_pos *= data_config[dataset_name]["metric_waypoint_spacing"] + + save_path = None + if visualize_path is not None: + save_path = os.path.join(visualize_path, f"{str(i).zfill(4)}.png") + + compare_waypoints_pred_to_label( + obs_img, + goal_img, + dataset_name, + goal_pos, + pred_waypoints, + label_waypoints, + save_path, + display, + ) + if use_wandb: + wandb_list.append(wandb.Image(save_path)) + if use_wandb: + wandb.log({f"{eval_type}_action_prediction": wandb_list}, commit=False) + + +def compare_waypoints_pred_to_label( + obs_img, + goal_img, + dataset_name: str, + goal_pos: np.ndarray, + pred_waypoints: np.ndarray, + label_waypoints: np.ndarray, + save_path: Optional[str] = None, + display: Optional[bool] = False, +): + """ + Compare predicted path with the gt path of waypoints using egocentric visualization. + + Args: + obs_img: image of the observation + goal_img: image of the goal + dataset_name: name of the dataset found in data_config.yaml (e.g. "recon") + goal_pos: goal position in the image + pred_waypoints: predicted waypoints in the image + label_waypoints: label waypoints in the image + save_path: path to save the figure + display: whether to display the figure + """ + + fig, ax = plt.subplots(1, 3) + start_pos = np.array([0, 0]) + if len(pred_waypoints.shape) > 2: + trajs = [*pred_waypoints, label_waypoints] + else: + trajs = [pred_waypoints, label_waypoints] + plot_trajs_and_points( + ax[0], + trajs, + [start_pos, goal_pos], + traj_colors=[CYAN, MAGENTA], + point_colors=[GREEN, RED], + ) + plot_trajs_and_points_on_image( + ax[1], + obs_img, + dataset_name, + trajs, + [start_pos, goal_pos], + traj_colors=[CYAN, MAGENTA], + point_colors=[GREEN, RED], + ) + ax[2].imshow(goal_img) + + fig.set_size_inches(18.5, 10.5) + ax[0].set_title(f"Action Prediction") + ax[1].set_title(f"Observation") + ax[2].set_title(f"Goal") + + if save_path is not None: + fig.savefig( + save_path, + bbox_inches="tight", + ) + + if not display: + plt.close(fig) + + +def plot_trajs_and_points_on_image( + ax: plt.Axes, + img: np.ndarray, + dataset_name: str, + list_trajs: list, + list_points: list, + traj_colors: list = [CYAN, MAGENTA], + point_colors: list = [RED, GREEN], +): + """ + Plot trajectories and points on an image. 
If there is no configuration for the camera interinstics of the dataset, the image will be plotted as is. + Args: + ax: matplotlib axis + img: image to plot + dataset_name: name of the dataset found in data_config.yaml (e.g. "recon") + list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) + list_points: list of points, each point is a numpy array of shape (2,) + traj_colors: list of colors for trajectories + point_colors: list of colors for points + """ + assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories" + assert len(list_points) <= len(point_colors), "Not enough colors for points" + assert ( + dataset_name in data_config + ), f"Dataset {dataset_name} not found in data/data_config.yaml" + + ax.imshow(img) + if ( + "camera_metrics" in data_config[dataset_name] + and "camera_height" in data_config[dataset_name]["camera_metrics"] + and "camera_matrix" in data_config[dataset_name]["camera_metrics"] + and "dist_coeffs" in data_config[dataset_name]["camera_metrics"] + ): + camera_height = data_config[dataset_name]["camera_metrics"]["camera_height"] + camera_x_offset = data_config[dataset_name]["camera_metrics"]["camera_x_offset"] + + fx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fx"] + fy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fy"] + cx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cx"] + cy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cy"] + camera_matrix = gen_camera_matrix(fx, fy, cx, cy) + + k1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k1"] + k2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k2"] + p1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p1"] + p2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p2"] + k3 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k3"] + dist_coeffs = np.array([k1, k2, p1, p2, k3, 0.0, 0.0, 0.0]) + + for i, traj in enumerate(list_trajs): + xy_coords = traj[:, :2] # (horizon, 2) + traj_pixels = get_pos_pixels( + xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False + ) + if len(traj_pixels.shape) == 2: + ax.plot( + traj_pixels[:250, 0], + traj_pixels[:250, 1], + color=traj_colors[i], + lw=2.5, + ) + + for i, point in enumerate(list_points): + if len(point.shape) == 1: + # add a dimension to the front of point + point = point[None, :2] + else: + point = point[:, :2] + pt_pixels = get_pos_pixels( + point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True + ) + ax.plot( + pt_pixels[:250, 0], + pt_pixels[:250, 1], + color=point_colors[i], + marker="o", + markersize=10.0, + ) + ax.xaxis.set_visible(False) + ax.yaxis.set_visible(False) + ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5)) + ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5)) + + +def plot_trajs_and_points( + ax: plt.Axes, + list_trajs: list, + list_points: list, + traj_colors: list = [CYAN, MAGENTA], + point_colors: list = [RED, GREEN], + traj_labels: Optional[list] = ["prediction", "ground truth"], + point_labels: Optional[list] = ["robot", "goal"], + traj_alphas: Optional[list] = None, + point_alphas: Optional[list] = None, + quiver_freq: int = 1, + default_coloring: bool = True, +): + """ + Plot trajectories and points that could potentially have a yaw. 
+ + Args: + ax: matplotlib axis + list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) + list_points: list of points, each point is a numpy array of shape (2,) + traj_colors: list of colors for trajectories + point_colors: list of colors for points + traj_labels: list of labels for trajectories + point_labels: list of labels for points + traj_alphas: list of alphas for trajectories + point_alphas: list of alphas for points + quiver_freq: frequency of quiver plot (if the trajectory data includes the yaw of the robot) + """ + assert ( + len(list_trajs) <= len(traj_colors) or default_coloring + ), "Not enough colors for trajectories" + assert len(list_points) <= len(point_colors), "Not enough colors for points" + assert ( + traj_labels is None or len(list_trajs) == len(traj_labels) or default_coloring + ), "Not enough labels for trajectories" + assert point_labels is None or len(list_points) == len(point_labels), "Not enough labels for points" + + for i, traj in enumerate(list_trajs): + if traj_labels is None: + ax.plot( + traj[:, 0], + traj[:, 1], + color=traj_colors[i], + alpha=traj_alphas[i] if traj_alphas is not None else 1.0, + marker="o", + ) + else: + ax.plot( + traj[:, 0], + traj[:, 1], + color=traj_colors[i], + label=traj_labels[i], + alpha=traj_alphas[i] if traj_alphas is not None else 1.0, + marker="o", + ) + if traj.shape[1] > 2 and quiver_freq > 0: # traj data also includes yaw of the robot + bearings = gen_bearings_from_waypoints(traj) + ax.quiver( + traj[::quiver_freq, 0], + traj[::quiver_freq, 1], + bearings[::quiver_freq, 0], + bearings[::quiver_freq, 1], + color=traj_colors[i] * 0.5, + scale=1.0, + ) + for i, pt in enumerate(list_points): + if point_labels is None: + ax.plot( + pt[0], + pt[1], + color=point_colors[i], + alpha=point_alphas[i] if point_alphas is not None else 1.0, + marker="o", + markersize=7.0 + ) + else: + ax.plot( + pt[0], + pt[1], + color=point_colors[i], + alpha=point_alphas[i] if point_alphas is not None else 1.0, + marker="o", + markersize=7.0, + label=point_labels[i], + ) + + + # put the legend below the plot + if traj_labels is not None or point_labels is not None: + ax.legend() + ax.legend(bbox_to_anchor=(0.0, -0.5), loc="upper left", ncol=2) + ax.set_aspect("equal", "box") + + +def angle_to_unit_vector(theta): + """Converts an angle to a unit vector.""" + return np.array([np.cos(theta), np.sin(theta)]) + + +def gen_bearings_from_waypoints( + waypoints: np.ndarray, + mag=0.2, +) -> np.ndarray: + """Generate bearings from waypoints, (x, y, sin(theta), cos(theta)).""" + bearing = [] + for i in range(0, len(waypoints)): + if waypoints.shape[1] > 3: # label is sin/cos repr + v = waypoints[i, 2:] + # normalize v + v = v / np.linalg.norm(v) + v = v * mag + else: # label is radians repr + v = mag * angle_to_unit_vector(waypoints[i, 2]) + bearing.append(v) + bearing = np.array(bearing) + return bearing + + +def project_points( + xy: np.ndarray, + camera_height: float, + camera_x_offset: float, + camera_matrix: np.ndarray, + dist_coeffs: np.ndarray, +): + """ + Projects 3D coordinates onto a 2D image plane using the provided camera parameters. 
+ + Args: + xy: array of shape (batch_size, horizon, 2) representing (x, y) coordinates + camera_height: height of the camera above the ground (in meters) + camera_x_offset: offset of the camera from the center of the car (in meters) + camera_matrix: 3x3 matrix representing the camera's intrinsic parameters + dist_coeffs: vector of distortion coefficients + + + Returns: + uv: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane + """ + batch_size, horizon, _ = xy.shape + + # create 3D coordinates with the camera positioned at the given height + xyz = np.concatenate( + [xy, -camera_height * np.ones(list(xy.shape[:-1]) + [1])], axis=-1 + ) + + # create dummy rotation and translation vectors + rvec = tvec = (0, 0, 0) + + xyz[..., 0] += camera_x_offset + xyz_cv = np.stack([xyz[..., 1], -xyz[..., 2], xyz[..., 0]], axis=-1) + uv, _ = cv2.projectPoints( + xyz_cv.reshape(batch_size * horizon, 3), rvec, tvec, camera_matrix, dist_coeffs + ) + uv = uv.reshape(batch_size, horizon, 2) + + return uv + + +def get_pos_pixels( + points: np.ndarray, + camera_height: float, + camera_x_offset: float, + camera_matrix: np.ndarray, + dist_coeffs: np.ndarray, + clip: Optional[bool] = False, +): + """ + Projects 3D coordinates onto a 2D image plane using the provided camera parameters. + Args: + points: array of shape (batch_size, horizon, 2) representing (x, y) coordinates + camera_height: height of the camera above the ground (in meters) + camera_x_offset: offset of the camera from the center of the car (in meters) + camera_matrix: 3x3 matrix representing the camera's intrinsic parameters + dist_coeffs: vector of distortion coefficients + + Returns: + pixels: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane + """ + pixels = project_points( + points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs + )[0] + pixels[:, 0] = VIZ_IMAGE_SIZE[0] - pixels[:, 0] + if clip: + pixels = np.array( + [ + [ + np.clip(p[0], 0, VIZ_IMAGE_SIZE[0]), + np.clip(p[1], 0, VIZ_IMAGE_SIZE[1]), + ] + for p in pixels + ] + ) + else: + pixels = np.array( + [ + p + for p in pixels + if np.all(p > 0) and np.all(p < [VIZ_IMAGE_SIZE[0], VIZ_IMAGE_SIZE[1]]) + ] + ) + return pixels + + +def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray: + """ + Args: + fx: focal length in x direction + fy: focal length in y direction + cx: principal point x coordinate + cy: principal point y coordinate + Returns: + camera matrix + """ + return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]) diff --git a/train/vint_train/visualizing/distance_utils.py b/train/vint_train/visualizing/distance_utils.py new file mode 100644 index 0000000..ff25b2c --- /dev/null +++ b/train/vint_train/visualizing/distance_utils.py @@ -0,0 +1,202 @@ +import os +import wandb +import numpy as np +from typing import List, Optional, Tuple +from vint_train.visualizing.visualize_utils import numpy_to_img +import matplotlib.pyplot as plt + + +def visualize_dist_pred( + batch_obs_images: np.ndarray, + batch_goal_images: np.ndarray, + batch_dist_preds: np.ndarray, + batch_dist_labels: np.ndarray, + eval_type: str, + save_folder: str, + epoch: int, + num_images_preds: int = 8, + use_wandb: bool = True, + display: bool = False, + rounding: int = 4, + dist_error_threshold: float = 3.0, +): + """ + Visualize the distance classification predictions and labels for an observation-goal image pair. 
+ + Args: + batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] + batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] + batch_dist_preds (np.ndarray): batch of distance predictions [batch_size] + batch_dist_labels (np.ndarray): batch of distance labels [batch_size] + eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) + epoch (int): current epoch number + num_images_preds (int): number of images to visualize + use_wandb (bool): whether to use wandb to log the images + save_folder (str): folder to save the images. If None, will not save the images + display (bool): whether to display the images + rounding (int): number of decimal places to round the distance predictions and labels + dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes) + """ + visualize_path = os.path.join( + save_folder, + "visualize", + eval_type, + f"epoch{epoch}", + "dist_classification", + ) + if not os.path.isdir(visualize_path): + os.makedirs(visualize_path) + assert ( + len(batch_obs_images) + == len(batch_goal_images) + == len(batch_dist_preds) + == len(batch_dist_labels) + ) + batch_size = batch_obs_images.shape[0] + wandb_list = [] + for i in range(min(batch_size, num_images_preds)): + dist_pred = np.round(batch_dist_preds[i], rounding) + dist_label = np.round(batch_dist_labels[i], rounding) + obs_image = numpy_to_img(batch_obs_images[i]) + goal_image = numpy_to_img(batch_goal_images[i]) + + save_path = None + if save_folder is not None: + save_path = os.path.join(visualize_path, f"{i}.png") + text_color = "black" + if abs(dist_pred - dist_label) > dist_error_threshold: + text_color = "red" + + display_distance_pred( + [obs_image, goal_image], + ["Observation", "Goal"], + dist_pred, + dist_label, + text_color, + save_path, + display, + ) + if use_wandb: + wandb_list.append(wandb.Image(save_path)) + if use_wandb: + wandb.log({f"{eval_type}_dist_prediction": wandb_list}, commit=False) + + +def visualize_dist_pairwise_pred( + batch_obs_images: np.ndarray, + batch_close_images: np.ndarray, + batch_far_images: np.ndarray, + batch_close_preds: np.ndarray, + batch_far_preds: np.ndarray, + batch_close_labels: np.ndarray, + batch_far_labels: np.ndarray, + eval_type: str, + save_folder: str, + epoch: int, + num_images_preds: int = 8, + use_wandb: bool = True, + display: bool = False, + rounding: int = 4, +): + """ + Visualize the distance classification predictions and labels for an observation-goal image pair. + + Args: + batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] + batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels] + batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels] + batch_close_preds (np.ndarray): batch of close predictions [batch_size] + batch_far_preds (np.ndarray): batch of far predictions [batch_size] + batch_close_labels (np.ndarray): batch of close labels [batch_size] + batch_far_labels (np.ndarray): batch of far labels [batch_size] + eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) + save_folder (str): folder to save the images. 
If None, will not save the images + epoch (int): current epoch number + num_images_preds (int): number of images to visualize + use_wandb (bool): whether to use wandb to log the images + display (bool): whether to display the images + rounding (int): number of decimal places to round the distance predictions and labels + """ + visualize_path = os.path.join( + save_folder, + "visualize", + eval_type, + f"epoch{epoch}", + "pairwise_dist_classification", + ) + if not os.path.isdir(visualize_path): + os.makedirs(visualize_path) + assert ( + len(batch_obs_images) + == len(batch_close_images) + == len(batch_far_images) + == len(batch_close_preds) + == len(batch_far_preds) + == len(batch_close_labels) + == len(batch_far_labels) + ) + batch_size = batch_obs_images.shape[0] + wandb_list = [] + for i in range(min(batch_size, num_images_preds)): + close_dist_pred = np.round(batch_close_preds[i], rounding) + far_dist_pred = np.round(batch_far_preds[i], rounding) + close_dist_label = np.round(batch_close_labels[i], rounding) + far_dist_label = np.round(batch_far_labels[i], rounding) + obs_image = numpy_to_img(batch_obs_images[i]) + close_image = numpy_to_img(batch_close_images[i]) + far_image = numpy_to_img(batch_far_images[i]) + + save_path = None + if save_folder is not None: + save_path = os.path.join(visualize_path, f"{i}.png") + + if close_dist_pred < far_dist_pred: + text_color = "black" + else: + text_color = "red" + + display_distance_pred( + [obs_image, close_image, far_image], + ["Observation", "Close Goal", "Far Goal"], + f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}", + f"close_label = {close_dist_label}, far_label = {far_dist_label}", + text_color, + save_path, + display, + ) + if use_wandb: + wandb_list.append(wandb.Image(save_path)) + if use_wandb: + wandb.log({f"{eval_type}_pairwise_classification": wandb_list}, commit=False) + + +def display_distance_pred( + imgs: list, + titles: list, + dist_pred: float, + dist_label: float, + text_color: str = "black", + save_path: Optional[str] = None, + display: bool = False, +): + plt.figure() + fig, ax = plt.subplots(1, len(imgs)) + + plt.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color) + + for axis, img, title in zip(ax, imgs, titles): + axis.imshow(img) + axis.set_title(title) + axis.xaxis.set_visible(False) + axis.yaxis.set_visible(False) + + # make the plot large + fig.set_size_inches((18.5 / 3) * len(imgs), 10.5) + + if save_path is not None: + fig.savefig( + save_path, + bbox_inches="tight", + ) + if not display: + plt.close(fig) diff --git a/train/vint_train/visualizing/visualize_utils.py b/train/vint_train/visualizing/visualize_utils.py new file mode 100644 index 0000000..d4d7def --- /dev/null +++ b/train/vint_train/visualizing/visualize_utils.py @@ -0,0 +1,25 @@ +import numpy as np +from PIL import Image +import torch + +VIZ_IMAGE_SIZE = (640, 480) +RED = np.array([1, 0, 0]) +GREEN = np.array([0, 1, 0]) +BLUE = np.array([0, 0, 1]) +CYAN = np.array([0, 1, 1]) +YELLOW = np.array([1, 1, 0]) +MAGENTA = np.array([1, 0, 1]) + + +def numpy_to_img(arr: np.ndarray) -> Image: + img = Image.fromarray(np.transpose(np.uint8(255 * arr), (1, 2, 0))) + img = img.resize(VIZ_IMAGE_SIZE) + return img + + +def to_numpy(tensor: torch.Tensor) -> np.ndarray: + return tensor.detach().cpu().numpy() + + +def from_numpy(array: np.ndarray) -> torch.Tensor: + return torch.from_numpy(array).float()
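
As a quick sanity check on the model definition above, here is a minimal sketch (not part of the patch) that instantiates `ViNT` and runs a dummy forward pass. It assumes the `vint_train` package from this diff is installed (`pip install -e train/`) together with `efficientnet_pytorch`, and that `BaseModel` (not shown in this diff) sets `num_action_params` to 4 when `learn_angle=True` (x, y, sin(yaw), cos(yaw)), matching the waypoint format used by the visualization utilities; the batch size and image resolution are arbitrary.

```python
# Minimal shape-check sketch for the ViNT model in train/vint_train/models/vint.py.
# Assumes `pip install -e train/` has been run and efficientnet_pytorch is available.
import torch

from vint_train.models.vint import ViNT

context_size = 5   # number of past observations used as context
len_traj_pred = 5  # number of future waypoints to predict

model = ViNT(
    context_size=context_size,
    len_traj_pred=len_traj_pred,
    learn_angle=True,
    obs_encoder="efficientnet-b0",
    obs_encoding_size=512,
)
model.eval()

# EfficientNet pools adaptively, so the resolution below is only an example.
batch_size, height, width = 2, 120, 160
# obs_img stacks the current frame and the `context_size` past frames along the channel
# dimension, the same layout train() in train_utils.py builds with torch.split / torch.cat.
obs_img = torch.randn(batch_size, 3 * (context_size + 1), height, width)
goal_img = torch.randn(batch_size, 3, height, width)

with torch.no_grad():
    dist_pred, action_pred = model(obs_img, goal_img)

print(dist_pred.shape)    # (batch_size, 1): predicted temporal distance to the goal
print(action_pred.shape)  # (batch_size, len_traj_pred, num_action_params), e.g. (2, 5, 4)
```

During training, `train()` in `train_utils.py` produces `obs_img` in exactly this layout: it splits the stacked context frames into 3-channel chunks, applies the image transform to each, and re-concatenates them along the channel dimension before calling the model.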