This is the official repository for "ReflecTool: Towards Reflection-Aware Tool-Augmented Clinical Agents".
[Webpage] [Paper] [Huggingface Dataset] [Leaderboard]
- About
- LeaderBoard
- News
- Installations
- Dataset
- Supported Models
- Evaluation
- Quick Usage
- Acknowledgements
- Citation
Although clinical agents have succeeded in interacting with diverse signals, they are typically oriented to a single clinical scenario and therefore fall short in broader applications. To evaluate clinical agents holistically, we propose ClinicalAgent Bench, a comprehensive medical agent benchmark consisting of 18 tasks across five key realistic clinical dimensions.
Building on this, we introduce ReflecTool, a novel framework that excels at utilizing domain-specific tools in two stages. In the optimization stage, ReflecTool builds a long-term memory of successful solving processes, together with tool-wise experience, from a small training set. In the inference stage, ReflecTool searches this long-term memory for supportive successful demonstrations to guide its tool selection strategy, and a verifier improves tool usage according to the tool-wise experience with one of two verification methods: Iterative Refinement and Candidate Selection.
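The memory search and the two verification methods can be contrasted in a toy sketch. This is not the repository's implementation: the proposer/verifier interfaces, the acceptance threshold, the tool-call strings, and the word-overlap (Jaccard) retrieval standing in for ReflecTool's memory search are all our own assumptions.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical interfaces (not ReflecTool's real API):
# - a proposer drafts a tool call as a string, optionally using verifier feedback;
# - a verifier scores a drafted call and returns (score, feedback or None).
Proposer = Callable[[Optional[str]], str]
Verifier = Callable[[str], Tuple[float, Optional[str]]]


def retrieve_demos(query: str, memory: List[str], n: int = 1) -> List[str]:
    """Return the n stored demonstrations with the highest word-level Jaccard overlap."""
    q = set(query.lower().split())

    def jaccard(demo: str) -> float:
        d = set(demo.lower().split())
        union = q | d
        return len(q & d) / len(union) if union else 0.0

    return sorted(memory, key=jaccard, reverse=True)[:n]


def iterative_refinement(propose: Proposer, verify: Verifier,
                         k: int = 2, threshold: float = 0.5) -> str:
    """Draft one call, then redraft it with verifier feedback up to k times."""
    action = propose(None)
    for _ in range(k):
        score, feedback = verify(action)
        if score >= threshold:  # the verifier accepts the current call
            break
        action = propose(feedback)  # redraft, conditioning on the feedback
    return action


def candidate_selection(propose: Proposer, verify: Verifier, k: int = 2) -> str:
    """Draft k independent calls and keep the one the verifier scores highest."""
    candidates = [propose(None) for _ in range(k)]
    return max(candidates, key=lambda action: verify(action)[0])


if __name__ == "__main__":
    memory = ["compute bmi from weight and height", "describe the chest x-ray"]
    print(retrieve_demos("what is the bmi for this weight", memory))

    scores = {"calc(bmi)": 0.2, "calc(bmi, height_m=1.8)": 0.9}

    def verify(action: str) -> Tuple[float, Optional[str]]:
        score = scores.get(action, 0.0)
        return (score, None) if score >= 0.5 else (score, "missing height argument")

    drafts = iter(["calc(bmi)", "calc(bmi, height_m=1.8)"])
    print(iterative_refinement(lambda fb: next(drafts), verify))
    drafts = iter(["calc(bmi)", "calc(bmi, height_m=1.8)"])
    print(candidate_selection(lambda fb: next(drafts), verify))
```

In this sketch, Iterative Refinement spends its budget revising a single call sequentially, while Candidate Selection spends it on independent drafts and uses the verifier only to rank them.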
Methods | Type | Total | MedQA | MMLU | BioASQ | PubMedQA | Knowledge & Reasoning Avg. | VQARAD | SLAKE | OmniMedQA | Multimodal Avg. | MedCalc | EHRSQL | MIMIC-III | eICU | Numerical Analysis Avg. | MedMentions | emrQA | LongHealthQA | Data Understanding Avg. | MedHalt-Rht | MedVQA-Halt | EHR-Halt | LongHalt | Trustworthiness Avg. |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MedLlama3-8b | LLM | 25.14 | 58.13 | 72.82 | 66.18 | 42.40 | 59.88 | - | - | - | - | 22.45 | 7.92 | 8.65 | 7.40 | 11.61 | 22.93 | 23.10 | - | 23.02 | 9.90 | - | 2.23 | - | 6.07 |
Qwen2-7B | LLM | 38.01 | 54.04 | 69.51 | 71.68 | 48.20 | 60.86 | - | - | - | - | 14.42 | 18.09 | 19.70 | 21.73 | 18.49 | 18.38 | 39.86 | 75.25 | 44.50 | 14.11 | - | 3.97 | 66.50 | 28.19 |
Llama3-8b | LLM | 35.62 | 56.32 | 70.34 | 72.82 | 53.80 | 63.32 | - | - | - | - | 28.08 | 17.02 | 12.53 | 21.83 | 19.87 | 28.54 | 41.92 | - | 35.23 | 29.22 | - | 18.90 | - | 24.06 |
Llama3.1-8b | LLM | 42.46 | 65.20 | 76.58 | 74.92 | 53.00 | 67.43 | - | - | - | - | 37.44 | 11.35 | 17.42 |  | 22.07 | 32.08 | 42.41 | 74.25 | 49.58 | 27.78 | - | 3.97 | 60.50 | 30.75 |
Qwen2-72B* | LLM | 48.76 | 71.25 | 84.48 | 82.85 | 53.00 | 72.90 | - | - | - | - | 32.19 | 23.98 | 33.95 | 34.15 | 31.07 | 29.20 | 42.89 | 79.75 | 50.61 | 31.56 | - | 31.30 | 58.50 | 40.45 |
Llama3.1-70b* | LLM | 47.59 | 79.58 | 88.15 | 82.52 | 57.40 | 76.91 | - | - | - | - | 48.52 | 16.49 | 25.44 | 26.47 | 29.23 | 25.71 | 31.69 | 80.00 | 45.80 | 28.22 | - | 22.48 | 64.50 | 38.40 |
GPT-3.5-turbo | LLM | 31.31 | 58.68 | 69.88 | 75.40 | 50.60 | 63.64 | - | - | - | - | 20.53 | 17.57 | 24.31 | 14.30 | 19.18 | 26.88 | 21.64 | - | 24.26 | 9.78 | - | 26.55 | - | 18.17 |
MiniCPM-V-2.6 | MLLM | 29.28 | 46.58 | 61.16 | 70.23 | 47.20 | 56.29 | 48.78 | 47.12 | 73.70 | 56.53 | 13.28 | 1.61 | 1.63 | 1.88 | 4.60 | 18.92 | 17.42 | 5.25 | 13.86 | 12.44 | 36.89 | 8.91 | 2.25 | 15.12 |
InternVL-Chat-V1.5 | MLLM | 37.02 | 50.82 | 65.56 | 64.89 | 30.40 | 52.92 | 49.67 | 41.47 | 68.50 | 53.21 | 18.91 | 17.99 | 18.05 | 20.83 | 18.95 | 26.47 | 42.53 | * | 34.50 | 25.78 | 50.78 | 0.00 | * | 25.52 |
HuatuoGPT-Vision-7B | MLLM | 41.11 | 50.43 | 66.12 | 73.30 | 54.00 | 60.96 | 53.65 | 52.97 | 92.10 | 66.24 | 13.56 | 4.39 | 9.27 | 9.76 | 9.25 | 16.74 | 38.44 | 73.00 | 42.73 | 14.33 | 23.44 | 2.42 | 65.25 | 26.36 |
HuatuoGPT-Vision-34B | MLLM | 41.31 | 54.83 | 72.36 | 73.79 | 48.00 | 62.25 | 56.76 | 53.72 | 91.50 | 67.33 | 25.79 | 8.14 | 9.15 | 9.79 | 13.22 | 16.34 | 41.68 | * | 29.01 | 28.44 | 43.77 | 32.07 | * | 34.76 |
GPT-4o-mini | MLLM | 52.65 | 76.90 | 85.67 | 82.84 | 49.20 | 73.65 | 50.47 | 46.47 | 59.20 | 52.05 | 50.43 | 21.73 | 28.07 | 18.19 | 29.61 | 31.66 | 40.00 | 78.50 | 50.05 | 62.11 | 53.78 | 45.74 | 70.00 | 57.91 |
COT | Qwen2-7b Agent | 39.94 | 52.47 | 69.97 | 72.21 | 41.00 | 58.91 | - | - | - | - | 19.10 | 16.06 | 23.81 | 20.95 | 19.98 | 22.83 | 19.92 | 65.75 | 36.17 | 45.56 | - | 31.49 | 57.00 | 44.68 |
ReAct | Qwen2-7b Agent | 42.73 | 51.61 | 67.68 | 80.24 | 48.60 | 62.03 | 35.92 | 39.59 | 72.90 | 49.47 | 18.62 | 18.52 | 25.06 | 34.00 | 24.05 | 22.19 | 24.04 | 41.50 | 29.24 | 49.89 | 35.33 | 61.49 | 48.75 | 48.87 |
CRITIC | Qwen2-7b Agent | 43.97 | 52.87 | 58.68 | 71.68 | 43.20 | 56.61 | 48.12 | 42.70 | 70.80 | 53.87 | 13.09 | 23.55 | 28.20 | 33.12 | 24.49 | 25.47 | 36.33 | 50.25 | 37.35 | 30.44 | 28.33 | 57.64 | 73.75 | 47.54 |
Reflexion | Qwen2-7b Agent | 45.25 | 51.78 | 66.48 | 74.60 | 50.80 | 60.92 | 45.68 | 47.97 | 77.20 | 56.95 | 13.37 | 17.56 | 22.16 | 30.23 | 20.83 | 30.30 | 28.92 | 53.00 | 37.41 | 50.55 | 36.33 | 62.91 | 50.75 | 50.14 |
ReflecTool (Iterative Refinement, k=2) | Qwen2-7b Agent | 49.37 | 50.12 | 65.47 | 76.37 | 63.20 | 63.79 | 53.88 | 45.71 | 82.90 | 60.83 | 24.68 | 24.20 | 16.92 | 22.08 | 21.97 | 43.69 | 50.27 | 61.00 | 51.65 | 54.44 | 37.99 | 56.73 | 45.25 | 48.60 |
ReflecTool (Candidates Selection, k=2) | Qwen2-7b Agent | 49.08 | 50.81 | 64.84 | 72.98 | 62.60 | 62.81 | 56.76 | 49.76 | 79.20 | 61.91 | 24.64 | 29.87 | 25.13 | 27.48 | 26.78 | 59.21 | 43.40 | 54.00 | 52.20 | 53.44 | 34.21 | 55.97 | 23.25 | 41.72 |
COT | Qwen2-72b Agent | 50.58 | 72.51 | 85.45 | 81.07 | 37.40 | 69.11 | - | - | - | - | 29.89 | 18.73 | 23.55 | 25.72 | 24.47 | 24.43 | 52.09 | 81.00 | 52.51 | 57.11 | - | 40.79 | 70.75 | 56.22 |
ReAct | Qwen2-72b Agent | 53.31 | 72.43 | 85.67 | 85.38 | 62.40 | 76.47 | 50.09 | 46.01 | 73.00 | 56.37 | 26.46 | 35.97 | 31.45 | 31.87 | 31.44 | 42.11 | 55.02 | 62.75 | 53.29 | 56.89 | 24.75 | 55.78 | 58.50 | 48.98 |
CRITIC | Qwen2-72b Agent | 52.35 | 71.85 | 85.31 | 87.86 | 51.00 | 74.01 | 40.58 | 50.80 | 73.50 | 54.96 | 22.44 | 35.48 | 32.47 | 33.28 | 30.92 | 33.60 | 56.36 | 75.50 | 55.15 | 51.77 | 24.22 | 52.03 | 58.75 | 46.69 |
Reflexion | Qwen2-72b Agent | 56.37 | 70.78 | 84.30 | 87.06 | 65.00 | 76.79 | 54.99 | 50.05 | 77.80 | 60.95 | 22.45 | 40.14 | 31.94 | 33.42 | 31.99 | 52.37 | 58.73 | 64.00 | 58.37 | 59.00 | 33.00 | 59.00 | 64.00 | 53.75 |
ReflecTool (Iterative Refinement, k=2) | Qwen2-72b Agent | 59.43 | 73.30 | 84.11 | 84.63 | 65.20 | 76.81 | 57.21 | 48.82 | 85.20 | 63.74 | 36.01 | 47.43 | 33.96 | 36.39 | 38.45 | 54.57 | 67.96 | 68.00 | 63.51 | 59.55 | 38.21 | 57.82 | 63.00 | 54.65 |
ReflecTool (Candidates Selection, k=2) | Qwen2-72b Agent | 59.66 | 71.37 | 84.00 | 83.50 | 66.20 | 76.27 | 57.66 | 48.54 | 81.90 | 62.70 | 36.77 | 49.89 | 31.20 | 34.38 | 38.06 | 60.51 | 66.61 | 66.50 | 64.54 | 60.77 | 39.63 | 59.78 | 66.75 | 56.73 |
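Checking a few rows against the header, each dimension's Avg. column is the mean of that dimension's available task scores, and Total appears to be the unweighted mean of the available dimension averages (MedLlama3-8b, for instance, has no Multimodal score, and its Total of 25.14 averages the other four). This is our reading of the numbers rather than a definition given by the authors; the sketch encodes the task grouping from the header and recomputes Total for two rows.

```python
from statistics import mean
from typing import Dict, List, Optional

# Task names copied from the leaderboard header, grouped by dimension.
DIMENSIONS: Dict[str, List[str]] = {
    "Knowledge & Reasoning": ["MedQA", "MMLU", "BioASQ", "PubMedQA"],
    "Multimodal": ["VQARAD", "SLAKE", "OmniMedQA"],
    "Numerical Analysis": ["MedCalc", "EHRSQL", "MIMIC-III", "eICU"],
    "Data Understanding": ["MedMentions", "emrQA", "LongHealthQA"],
    "Trustworthiness": ["MedHalt-Rht", "MedVQA-Halt", "EHR-Halt", "LongHalt"],
}


def total_score(dimension_avgs: List[Optional[float]]) -> float:
    """Unweighted mean of the dimension averages that exist (None stands for '-')."""
    return mean(a for a in dimension_avgs if a is not None)


if __name__ == "__main__":
    print(sum(len(tasks) for tasks in DIMENSIONS.values()))  # 18 tasks
    # ReflecTool (Iterative Refinement, k=2), Qwen2-72b Agent row
    print(round(total_score([76.81, 63.74, 38.45, 63.51, 54.65]), 2))  # 59.43
    # MedLlama3-8b row: the Multimodal dimension is "-"
    print(round(total_score([59.88, None, 11.61, 23.02, 6.07]), 2))
```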
- 🔥 [2024/10/24] We release ClinicalAgent Bench, comprising 18 tasks across five capability dimensions!
- 🔥 [2024/10/24] We release the code of ReflecTool!
conda create -n reflectool python=3.10.8
conda activate reflectool
# install torch
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
# install java
conda install -c conda-forge openjdk=21.0.4
git clone https://github.com/BlueZeros/ReflecTool.git
cd ReflecTool
pip install -e .
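After installation, a quick sanity check can confirm that the pinned components are visible from the active environment. This is a hypothetical helper (check_environment is not part of the repository); it only reads installed package metadata and calls java -version, without importing torch. Compare its output against the versions pinned in the commands above.

```python
import importlib.metadata
import shutil
import subprocess
import sys


def check_environment() -> dict:
    """Report the Python version, the installed torch version and Java availability."""
    report = {"python": "%d.%d.%d" % sys.version_info[:3]}
    try:
        report["torch"] = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        report["torch"] = None  # torch is not installed in this environment
    java = shutil.which("java")
    report["java"] = java
    if java:
        # `java -version` prints its banner to stderr, not stdout
        out = subprocess.run([java, "-version"], capture_output=True, text=True)
        report["java_version"] = out.stderr.splitlines()[0] if out.stderr else None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```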
Details of the tool installation can be found here.
Download the ClinicalAgentBench dataset from here and put it in the ReflecTool root folder. ClinicalAgentBench has the following structure:
ReflecTool/
├── ClinicalAgentBench/
│   ├── ablations/   # subset used for analysis
│   ├── train/       # dataset used for the optimization stage
│   ├── test/        # dataset used for evaluation
│   └── memory/      # few-shot samples, long-term memory and tool experience
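A small hypothetical helper (missing_subdirs is our own name) can confirm the dataset landed in the expected place before running any script. It assumes the default DATA_PATH=./ClinicalAgentBench used by the scripts and that it is run from the repository root.

```python
from pathlib import Path
from typing import List

# Sub-folders listed in the ClinicalAgentBench tree above.
EXPECTED_SUBDIRS = ["ablations", "train", "test", "memory"]


def missing_subdirs(root: str = "./ClinicalAgentBench") -> List[str]:
    """Return the expected sub-folders that are absent under `root`."""
    base = Path(root)
    return [name for name in EXPECTED_SUBDIRS if not (base / name).is_dir()]


if __name__ == "__main__":
    missing = missing_subdirs()
    print("OK" if not missing else f"missing: {missing}")
```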
The list of supported models can be found in model_config. Before using a model, fill in its corresponding folder path there. The supported model types are listed below:
- Local Models: Hugging Face models loaded with AutoModelForCausalLM. vLLM acceleration is also supported for these models and is used automatically if vllm is installed.
- Llama3
- Llama3.1
- Qwen1.5
- Qwen2
- Qwen2.5
- OpenAI Models: see the OpenAI model list.
- GPT-3.5-turbo
- GPT-4o-mini
- GPT-4o
- Note: set the environment variable OPENAI_API_KEY={your_openai_api_key} before using these models.
- Multimodal LLMs:
- HuatuoGPT-Vision
- MiniCPM-V-2.6
- InternVL-Chat-V1.5
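The backend rules above can be summarised in a sketch. This is not how the repository dispatches models (that is configured in model_config); the helper name describe_backend and the rule that model names starting with gpt- go to OpenAI are our own assumptions, and the multimodal models are left out. It shows the two stated prerequisites: OPENAI_API_KEY must be set for OpenAI models, and vLLM is used for local models only when the vllm package is importable.

```python
import importlib.util
import os

# Assumption for illustration only: OpenAI models are recognised by a "gpt-" prefix.
OPENAI_PREFIXES = ("gpt-",)


def describe_backend(model: str) -> str:
    """Return "openai", "vllm" or "transformers" for a model name, checking prerequisites."""
    if model.startswith(OPENAI_PREFIXES):
        if not os.environ.get("OPENAI_API_KEY"):
            raise RuntimeError("set OPENAI_API_KEY before using OpenAI models")
        return "openai"
    # Local Hugging Face models: vLLM is used automatically when it is installed.
    if importlib.util.find_spec("vllm") is not None:
        return "vllm"
    return "transformers"  # plain AutoModelForCausalLM loading


if __name__ == "__main__":
    for name in ["qwen2-7b", "gpt-4o-mini"]:
        try:
            print(name, "->", describe_backend(name))
        except RuntimeError as err:
            print(name, "->", err)
```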
Models
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test
OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory
DATASET=medqa
MODEL=qwen2-7b
EXP_NAME=${MODEL}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}
python run.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK_PATH} \
--test-split ${DATASET} \
--exp-name ${EXP_NAME} \
--model ${MODEL} \
--log-print \
--prompt-debug \
--resume
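To evaluate the same model on several datasets, the shell script can be looped from Python. This is a hypothetical driver (build_eval_cmd and run_all are our own names) that passes only the flags shown in the script and mirrors its mkdir -p step. Run it from the repository root, and check the split names available under ClinicalAgentBench/test before adding datasets other than medqa.

```python
import shlex
import subprocess
import sys
from pathlib import Path
from typing import List


def build_eval_cmd(dataset: str, model: str = "qwen2-7b", task: str = "test",
                   data_path: str = "./ClinicalAgentBench") -> List[str]:
    """Return the run.py argument list used by the script above for one dataset."""
    return [
        sys.executable, "run.py",
        "--data-path", data_path,
        "--output-path", f"./results/{task}",
        "--task-name", task,
        "--test-split", dataset,
        "--exp-name", model,  # EXP_NAME=${MODEL} in the script
        "--model", model,
        "--log-print", "--prompt-debug", "--resume",
    ]


def run_all(datasets: List[str], model: str = "qwen2-7b", task: str = "test") -> None:
    """Create each output folder (like `mkdir -p`) and run the evaluations one by one."""
    for dataset in datasets:
        Path("results", task, dataset, model).mkdir(parents=True, exist_ok=True)
        cmd = build_eval_cmd(dataset, model, task)
        print(shlex.join(cmd))
        subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(["medqa"])
```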
Reflexion Agent
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test
OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory
DATASET="medqa"
MODEL=qwen2-7b
ACTIONS="all_wo_mm"
FEWSHOT=1
EXP_NAME=reflexion_${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}
python run.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK_PATH} \
--test-split ${DATASET} \
--exp-name ${EXP_NAME} \
--agent "reflexion" \
--model ${MODEL} \
--actions ${ACTIONS} \
--memory-path ${MEMORY_PATH} \
--memory-type "reflexion_standard" \
--few-shot ${FEWSHOT} \
--resume
CRITIC Agent
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test
OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory
DATASET="medqa"
MODEL=qwen2-72b-int4
ACTIONS="all_wo_mm"
FEWSHOT=1
EXP_NAME=critic_${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}
python run.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK_PATH} \
--test-split ${DATASET} \
--exp-name ${EXP_NAME} \
--agent "critic" \
--model ${MODEL} \
--actions ${ACTIONS} \
--memory-path ${MEMORY_PATH} \
--memory-type "critic_standard" \
--few-shot ${FEWSHOT} \
--resume \
--log-print \
--prompt-debug
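The Reflexion and CRITIC scripts share one naming pattern, which decides the output folder the scripts create. A small sketch reproducing that pattern can help locate results; the helper names are hypothetical, and both the assumption that run.py writes into the pre-created folder and the `<agent>_standard` memory-type rule are inferred from the two scripts above.

```python
from pathlib import Path


def baseline_exp_name(agent: str, model: str, few_shot: int = 1, actions: str = "all_wo_mm") -> str:
    """EXP_NAME pattern shared by the Reflexion and CRITIC scripts."""
    return f"{agent}_{model}-few_shot_{few_shot}-{actions}"


def baseline_memory_type(agent: str) -> str:
    """Assumed rule, inferred from reflexion_standard and critic_standard."""
    return f"{agent}_standard"


def baseline_output_dir(agent: str, model: str, dataset: str, few_shot: int = 1,
                        actions: str = "all_wo_mm", task: str = "test") -> Path:
    """Folder made by `mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}` in those scripts."""
    return Path("results", task, dataset, baseline_exp_name(agent, model, few_shot, actions))


if __name__ == "__main__":
    print(baseline_output_dir("critic", "qwen2-72b-int4", "medqa"))
    print(baseline_memory_type("reflexion"))
```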
Optimization on the medqa dataset (Knowledge task example)
DATA_PATH=./ClinicalAgentBench
TASK_PATH=train
OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/my_memory
domain=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool set; note that the `all` and `mm` action sets need 2 GPUs.
FEWSHOT=0 # number of few-shot demonstrations
EXP_NAME=train-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${domain}/${EXP_NAME}
python train.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK_PATH} \
--test-split ${domain} \
--exp-name ${EXP_NAME} \
--model ${MODEL} \
--actions ${ACTIONS} \
--memory-path ${MEMORY_PATH} \
--max-exec-steps 15 \
--few-shot ${FEWSHOT} \
--write-memory \
--memory-type task \
--log-print \
--resume
Optimization on the slake dataset (Multimodal task example)
DATA_PATH=./ClinicalAgentBench
TASK_PATH=train
OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/my_memory
domain=slake
MODEL=qwen2-7b
ACTIONS="all" # tool set; note that the `all` and `mm` action sets need 2 GPUs.
FEWSHOT=0 # number of few-shot demonstrations
EXP_NAME=train-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${domain}/${EXP_NAME}
python train.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK_PATH} \
--test-split ${domain} \
--exp-name ${EXP_NAME} \
--model ${MODEL} \
--actions ${ACTIONS} \
--memory-path ${MEMORY_PATH} \
--max-exec-steps 15 \
--few-shot ${FEWSHOT} \
--write-memory \
--memory-type task \
--log-print \
--resume
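The two optimization scripts differ only in the domain and the tool set: the multimodal slake domain uses ACTIONS="all" (2 GPUs), while the text-only medqa domain uses all_wo_mm. A hypothetical sweep helper makes that choice explicit; MULTIMODAL_DOMAINS contains only slake because that is the only multimodal split named here, and any other entry would be your own assumption.

```python
import sys
from typing import List

# Only "slake" is confirmed as multimodal by the scripts above; anything you add is an assumption.
MULTIMODAL_DOMAINS = {"slake"}


def pick_actions(domain: str) -> str:
    """Full tool set (incl. multimodal tools, 2 GPUs) for image domains, else text-only tools."""
    return "all" if domain in MULTIMODAL_DOMAINS else "all_wo_mm"


def build_train_cmd(domain: str, model: str = "qwen2-7b", few_shot: int = 0,
                    data_path: str = "./ClinicalAgentBench") -> List[str]:
    """Return the train.py argument list used by the optimization scripts above."""
    actions = pick_actions(domain)
    return [
        sys.executable, "train.py",
        "--data-path", data_path,
        "--output-path", "./results/train",
        "--task-name", "train",
        "--test-split", domain,
        "--exp-name", f"train-{model}-few_shot_{few_shot}-{actions}",
        "--model", model,
        "--actions", actions,
        "--memory-path", f"{data_path}/my_memory",
        "--max-exec-steps", "15",
        "--few-shot", str(few_shot),
        "--write-memory",
        "--memory-type", "task",
        "--log-print",
        "--resume",
    ]


if __name__ == "__main__":
    for domain in ["medqa", "slake"]:
        print(" ".join(build_train_cmd(domain)))
```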
Evaluation on the medqa dataset with ReflecTool (Iterative Refinement)
DATA_PATH=./ClinicalAgentBench
TASK=test
OUTPUT_PATH=./results/${TASK}
DATASET=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool set; note that the `all` and `mm` action sets need 2 GPUs.
FEWSHOT=1 # number of few-shot demonstrations
MEMORY_PATH=${DATA_PATH}/memory/task/long_term_memory/${DATASET}
EXP_NAME=reflectool_refine-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}-task_memory
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}
python run.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK} \
--test-split ${DATASET} \
--exp-name ${EXP_NAME} \
--agent "reflectool" \
--model ${MODEL} \
--actions ${ACTIONS} \
--action-guide-path ${DATA_PATH}/memory/task/tool_experience/${DATASET}.json \
--memory-path ${MEMORY_PATH} \
--force-action \
--action-search "refine" \
--few-shot ${FEWSHOT} \
--log-print \
--prompt-debug \
--resume
Evaluation on the medqa dataset with ReflecTool (Candidate Selection)
DATA_PATH=./ClinicalAgentBench
TASK=test
OUTPUT_PATH=./results/${TASK}
DATASET=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool set; note that the `all` and `mm` action sets need 2 GPUs.
FEWSHOT=1 # number of few-shot demonstrations
MEMORY_PATH=${DATA_PATH}/memory/task/long_term_memory/${DATASET}
EXP_NAME=reflectool_select-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}-task_memory
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}
python run.py \
--data-path ${DATA_PATH} \
--output-path ${OUTPUT_PATH} \
--task-name ${TASK} \
--test-split ${DATASET} \
--exp-name ${EXP_NAME} \
--agent "reflectool" \
--model ${MODEL} \
--actions ${ACTIONS} \
--action-guide-path ${DATA_PATH}/memory/task/tool_experience/${DATASET}.json \
--memory-path ${MEMORY_PATH} \
--force-action \
--action-search "select" \
--few-shot ${FEWSHOT} \
--log-print \
--prompt-debug \
--resume
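The Iterative Refinement and Candidate Selection scripts differ only in --action-search and EXP_NAME; both read the tool-wise experience file and the long-term memory folder for the chosen dataset. The following hypothetical builder (build_reflectool_cmd is our own name) assembles either call with the flags shown above, which makes it easy to run both variants back to back.

```python
import sys
from typing import List


def build_reflectool_cmd(dataset: str, search: str, model: str = "qwen2-7b",
                         actions: str = "all_wo_mm", few_shot: int = 1,
                         data_path: str = "./ClinicalAgentBench") -> List[str]:
    """Assemble the run.py call from the two scripts above; `search` is "refine" or "select"."""
    if search not in ("refine", "select"):
        raise ValueError("search must be 'refine' or 'select'")
    exp_name = f"reflectool_{search}-{model}-few_shot_{few_shot}-{actions}-task_memory"
    return [
        sys.executable, "run.py",
        "--data-path", data_path,
        "--output-path", "./results/test",
        "--task-name", "test",
        "--test-split", dataset,
        "--exp-name", exp_name,
        "--agent", "reflectool",
        "--model", model,
        "--actions", actions,
        # tool-wise experience file and long-term memory folder, both keyed by dataset
        "--action-guide-path", f"{data_path}/memory/task/tool_experience/{dataset}.json",
        "--memory-path", f"{data_path}/memory/task/long_term_memory/{dataset}",
        "--force-action",
        "--action-search", search,
        "--few-shot", str(few_shot),
        "--log-print", "--prompt-debug", "--resume",
    ]


if __name__ == "__main__":
    for search in ("refine", "select"):
        print(" ".join(build_reflectool_cmd("medqa", search)))
```

Because the memory and tool-experience paths are derived from the dataset name, using another dataset assumes matching files exist under memory/task/.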
- Question Answering
from reflectool.agents.TaskAgent import TaskAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args
# Parse the default arguments, then choose the backbone model.
args = parse_args()
args.model = "gpt-4o-mini"
agent = TaskAgent(args)
# Wrap the question in a TaskPackage and run the agent on it.
task = TaskPackage(
    inputs="Can you explain fever to me?",
    instruction="",
)
print(agent(task))
- Visual Question Answering
from reflectool.agents.TaskAgent import TaskAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args
# Parse the default arguments and choose a model that accepts image inputs.
args = parse_args()
args.model = "gpt-4o-mini"
agent = TaskAgent(args)
# Images are passed through multimodal_inputs; replace the placeholder with a real path.
task = TaskPackage(
    inputs="Can you describe the image?",
    instruction="",
    multimodal_inputs={
        "image": "{path to the image}",
    },
)
print(agent(task))
- Question Answering with ReflecTool
from reflectool.agents.ReflecToolAgent import ReflecToolAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args
args = parse_args()
args.model = "gpt-4o-mini"
# Verification method: "refine" (Iterative Refinement) or "select" (Candidate Selection).
args.action_search = "refine"
agent = ReflecToolAgent(args)
task = TaskPackage(
    inputs="Can you explain fever to me?",
    instruction="",
)
print(agent(task))
Thanks to the codebase we built upon:
@misc{liao2024reflectoolreflectionawaretoolaugmentedclinical,
title={ReflecTool: Towards Reflection-Aware Tool-Augmented Clinical Agents},
author={Yusheng Liao and Shuyang Jiang and Yanfeng Wang and Yu Wang},
year={2024},
eprint={2410.17657},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2410.17657},
}