OpenVLA: An Open-Source Vision-Language-Action Model

Getting Started | Pretrained VLAs | Installation | Training VLAs from Scratch | Project Website

A simple and scalable codebase for training and fine-tuning vision-language-action models (VLAs) for generalist robotic manipulation:

  • Different Dataset Mixtures: We natively support arbitrary datasets in RLDS format, including arbitrary mixtures of data from the Open X-Embodiment Dataset.
  • Easy Scaling: Powered by PyTorch FSDP and Flash-Attention, we can quickly and efficiently train models from 1B to 34B parameters, with easily adaptable model architectures.
  • Native Fine-Tuning Support: Built-in support (with examples) for various forms of fine-tuning (full, partial, LoRA).

Built on top of Prismatic VLMs.

Getting Started

To get started with loading and running OpenVLA models for inference, we provide a lightweight interface that leverages HuggingFace transformers AutoClasses, with minimal dependencies.

For example, to load openvla-7b for zero-shot instruction following in the BridgeData V2 environments with a Widow-X robot:

# Install minimal dependencies (`torch`, `transformers`, `timm`, `tokenizers`, ...)
# > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

# Load Processor & VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b", 
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16, 
    low_cpu_mem_usage=True, 
    trust_remote_code=True
).to("cuda:0")

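# Grab image input & format prompt
#   =>> `get_from_camera` is a placeholder for your own camera interface
image: Image.Image = get_from_camera(...)
#   =>> Replace `{<INSTRUCTION>}` with the natural-language task instruction
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"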

# Predict Action (7-DoF; un-normalize for BridgeData V2)
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# Execute...
robot.act(action, ...)
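The two steps around the model call, filling the prompt template and un-normalizing the predicted action, can be sketched in plain Python. This is a minimal illustration and not the repository's implementation: `build_prompt` and `unnormalize` are hypothetical helper names, and the un-normalization assumes that predicted actions lie in [-1, 1] and are mapped back to robot units with per-dimension lower and upper bounds (the kind of dataset statistics that `unnorm_key` would select).

```python
from typing import List, Sequence

# Template copied from the inference example; only the instruction slot varies.
PROMPT_TEMPLATE = "In: What action should the robot take to {instruction}?\nOut:"


def build_prompt(instruction: str) -> str:
    """Fill the prompt template with a task instruction (hypothetical helper)."""
    # Strip whitespace and trailing punctuation so the template's own "?"
    # is not doubled. This cleanup is an illustrative choice.
    cleaned = instruction.strip().rstrip(".?!")
    return PROMPT_TEMPLATE.format(instruction=cleaned)


def unnormalize(
    normalized: Sequence[float], low: Sequence[float], high: Sequence[float]
) -> List[float]:
    """Map values in [-1, 1] linearly onto [low, high], per action dimension.

    ASSUMPTION: a linear [-1, 1] normalization with per-dimension bounds;
    the bounds passed in here are made up for illustration.
    """
    return [
        0.5 * (a + 1.0) * (hi - lo) + lo
        for a, lo, hi in zip(normalized, low, high)
    ]


if __name__ == "__main__":
    print(build_prompt("pick up the cup."))
    print(unnormalize([-1.0, 0.0, 1.0], [0.0, -2.0, 10.0], [1.0, 2.0, 20.0]))
```

In the example above, `predict_action` already performs the un-normalization when given `unnorm_key`, so the second helper is only meant to show what that step amounts to under the stated assumption.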

We also provide an example script for fine-tuning OpenVLA models for new tasks and embodiments; this script supports different fine-tuning modes -- including (quantized) low-rank adaptation (LoRA) supported by HuggingFace's PEFT library.

For deployment, we provide a lightweight script for serving OpenVLA models over a REST API. This makes it easy to integrate OpenVLA models into existing robot control stacks and removes any requirement for powerful on-device compute.
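On the robot side, talking to such a server comes down to posting an observation and an instruction, then reading back an action. The sketch below uses only the Python standard library. The endpoint path (`/act`), the payload field names, and the nested-list image encoding are assumptions made for illustration; check the serving script for the schema it actually expects.

```python
import json
import urllib.request
from typing import Any, List


def build_act_request(
    base_url: str,
    image_rows: List[Any],
    instruction: str,
    unnorm_key: str = "bridge_orig",
) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for one control step.

    ASSUMPTION: the "/act" path and the payload keys are hypothetical.
    """
    payload = {
        "image": image_rows,  # e.g. H x W x 3 nested lists of uint8 values
        "instruction": instruction,
        "unnorm_key": unnorm_key,
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/act",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def request_action(req: urllib.request.Request, timeout: float = 5.0) -> Any:
    """Send the request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Keeping request construction separate from sending makes the payload easy to inspect or log without a server running; only `request_action` touches the network.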

Pretrained VLAs

We release two OpenVLA models trained as part of our work, with checkpoints, configs, and model cards available on our HuggingFace page:

  • openvla-7b: The flagship model from our paper, trained from the Prismatic prism-dinosiglip-224px VLM (based on a fused DINOv2 and SigLIP vision backbone, and Llama-2 LLM). Trained on a large mixture of datasets from Open X-Embodiment spanning 970K trajectories (mixture details - see "Open-X Magic Soup++").
  • openvla-v01-7b: An early model used during development, trained from the Prismatic siglip-224px VLM (single SigLIP vision backbone, and a Vicuña v1.5 LLM). Trained on the same mixture of datasets as Octo, but for significantly fewer GPU hours than our final model (mixture details - see "Open-X Magic Soup").

Explicit Notes on Model Licensing & Commercial Use: While all code in this repository is released under an MIT License, our pretrained models may inherit restrictions from the underlying base models we use. Specifically, both of the above models are derived from Llama-2 and are therefore subject to the Llama Community License.


Installation

Note: These installation instructions are for full-scale pretraining (and distributed fine-tuning); if you only want to run inference with OpenVLA models (or perform lightweight fine-tuning), see the instructions above!

This repository was built using Python 3.10, but should be backwards compatible with any Python >= 3.8. We require PyTorch 2.2.* -- installation instructions can be found here. The latest version of this repository was developed and thoroughly tested with:

  • PyTorch 2.2.0, torchvision 0.17.0, transformers 4.40.1, tokenizers 0.19.1, timm 0.9.10, and flash-attn 2.5.5

[5/21/24] Note: Following reported regressions and breaking changes in later versions of transformers, timm, and tokenizers, we explicitly pin the dependency versions listed above. We are working on implementing thorough tests, and plan on relaxing these constraints as soon as we can.
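Because these versions are pinned, it can help to check an environment against them before launching a long job. The following is a small sketch using the standard library's `importlib.metadata`; the helper names are hypothetical and the script is not part of this repository. It ignores local build suffixes such as `+cu121` when comparing.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional, Tuple

# Versions copied from the tested configuration listed above.
PINNED: Dict[str, str] = {
    "torch": "2.2.0",
    "torchvision": "0.17.0",
    "transformers": "4.40.1",
    "tokenizers": "0.19.1",
    "timm": "0.9.10",
    "flash-attn": "2.5.5",
}


def base_version(version: str) -> str:
    """Drop a local build suffix, e.g. '2.2.0+cu121' -> '2.2.0'."""
    return version.split("+", 1)[0]


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Look up installed versions; None means the package is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def find_mismatches(
    installed: Dict[str, Optional[str]], pinned: Dict[str, str] = PINNED
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (wanted, found)} for every package off its pin."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, wanted in pinned.items():
        found = installed.get(name)
        if found is None or base_version(found) != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(installed_versions(PINNED)).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

An empty result means every pinned package is present at the expected version.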

Once PyTorch has been properly installed, you can install this package locally via an editable installation (or via pip install git+https://github.com/openvla/openvla):

cd openvla
pip install -e .

# Training additionally requires Flash-Attention 2 (https://github.com/Dao-AILab/flash-attention)
pip install packaging ninja

# Verify Ninja --> should return exit code "0"
ninja --version; echo $?

# Install Flash Attention 2
#   =>> If you run into difficulty, try `pip cache remove flash_attn` first
pip install "flash-attn==2.5.5" --no-build-isolation

If you run into any problems during the installation process, please file a GitHub Issue.

Note: See vla-scripts/ for full training and verification scripts for OpenVLA models. The scripts/ directory is mostly a holdover from the original (base) prismatic-vlms repository, with support for training and evaluating visually-conditioned language models. While you can use this repo to train both VLMs and VLAs, trying to generate language (via scripts/generate.py) with existing OpenVLA models will not work, because current OpenVLA models are trained to generate actions, and actions alone.

Training VLAs from Scratch

We provide full instructions and configurations for training OpenVLA models on (arbitrary subsets of) the Open X-Embodiment (OXE) Dataset. If you run into any issues with the following, see VLA Troubleshooting below (or file a GitHub Issue).

VLA Pretraining Datasets

We download and preprocess individual datasets from Open X-Embodiment in RLDS format following this custom script. See mixtures.py for the full list of component datasets (and mixture weights) we use to train openvla-7b.

  • Important: For the BridgeData V2 component, the version in OXE is out of date (as of 12/20/2023). Instead, you should download the dataset from the official website and place it under the subdirectory bridge_orig/. Replace any reference to bridge in the OXE code with bridge_orig.

VLA Configuration & Training Script

The entry point for VLA training is vla-scripts/train.py. We use draccus to provide a modular, dataclass-based interface for specifying VLA training configurations; existing VLA configurations are in prismatic/conf/vla.py. You can add your own training configuration and refer to it using the --vla.type command line argument.
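To make the idea of a dataclass-based configuration selected by an ID more concrete, here is a minimal sketch written with plain `dataclasses` rather than draccus. It does not reproduce the classes in prismatic/conf/vla.py: the class name, the fields, the default values, and the second configuration ID are all hypothetical. Only the ID `prism-dinosiglip-224px+mx-bridge` is taken from this README.

```python
from dataclasses import dataclass, replace
from typing import Dict


@dataclass(frozen=True)
class VLAConfigSketch:
    """Hypothetical stand-in for a VLA training configuration."""

    vla_id: str                    # value a user would pass as --vla.type
    base_vlm: str                  # pretrained VLM to start from
    data_mix: str                  # name of the dataset mixture
    learning_rate: float = 2e-5    # placeholder value, not from the repo
    global_batch_size: int = 256   # placeholder value, not from the repo


REGISTRY: Dict[str, VLAConfigSketch] = {}


def register(cfg: VLAConfigSketch) -> VLAConfigSketch:
    """Add a configuration to the registry, rejecting duplicate IDs."""
    if cfg.vla_id in REGISTRY:
        raise ValueError(f"Duplicate config id: {cfg.vla_id!r}")
    REGISTRY[cfg.vla_id] = cfg
    return cfg


def resolve(vla_type: str) -> VLAConfigSketch:
    """Look up a configuration by ID, listing valid choices on failure."""
    try:
        return REGISTRY[vla_type]
    except KeyError:
        choices = ", ".join(sorted(REGISTRY))
        raise ValueError(f"Unknown config {vla_type!r}; choices: {choices}") from None


# An existing-style entry, then a custom variant derived from it.
base = register(
    VLAConfigSketch(
        vla_id="prism-dinosiglip-224px+mx-bridge",
        base_vlm="prism-dinosiglip-224px",
        data_mix="bridge",
    )
)
register(replace(base, vla_id="my-robot+mx-bridge-lowlr", learning_rate=5e-6))
```

Deriving a new configuration from an existing one with `replace` keeps overrides small and explicit, which is the main appeal of the dataclass approach.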

We use PyTorch Fully Sharded Data Parallel (FSDP) to distribute training across GPUs. Launch training via torchrun:

# Train VLA on BridgeData V2 with the Prismatic DINO-SigLIP 224px Backbone on a Single Node (w/ 8 GPUs)
torchrun --standalone --nnodes 1 --nproc-per-node 8 vla-scripts/train.py \
  --vla.type "prism-dinosiglip-224px+mx-bridge" \
  --data_root_dir <PATH TO OXE DATA ROOT> \
  --run_root_dir <PATH TO LOG/CHECKPOINT ROOT> \
  --wandb_project "<PROJECT>" \
  --wandb_entity "<ENTITY>"

VLA Troubleshooting

The following is a list of known problems and their corresponding fixes:

FileNotFoundError: Failed to construct dataset "fractal20220817_data", builder_kwargs "{'data_dir': '/path/to/processed/datasets/'}": Could not load dataset info from fractal20220817_data/0.1.0/dataset_info.json
  • Fix: Downgrade tensorflow-datasets via pip install tensorflow-datasets==4.9.3.
AttributeError: 'DLataset' object has no attribute 'traj_map'. Did you mean: 'flat_map'?
  • Fix: Upgrade dlimp to the newest version. You may have to --force-reinstall like so: pip install --no-deps --force-reinstall git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972

Repository Structure

High-level overview of repository/project file-tree:

  • prismatic - Package source; provides core utilities for model loading, training, data preprocessing, etc.
  • vla-scripts/ - Core scripts for training, fine-tuning, and deploying VLAs.
  • LICENSE - All code is made available under the MIT License; happy hacking!
  • Makefile - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
  • pyproject.toml - Full project configuration details (including dependencies), as well as tool configurations.
  • README.md - You are here!

Citation

If you find our code or models useful in your work, please cite our paper:

@article{kim24openvla,
    title={OpenVLA: An Open-Source Vision-Language-Action Model},
    author={{Moo Jin} Kim and Karl Pertsch and Siddharth Karamcheti and Ted Xiao and Ashwin Balakrishna and Suraj Nair and Rafael Rafailov and Ethan Foster and Grace Lam and Pannag Sanketi and Quan Vuong and Thomas Kollar and Benjamin Burchfiel and Russ Tedrake and Dorsa Sadigh and Sergey Levine and Percy Liang and Chelsea Finn},
    journal={arXiv preprint arXiv:2406.09246},
    year={2024}
}