Benchmark, Toolbox, and Reflection-based Method for Clinical Agent


REFLECTOOL: Towards Reflection-Aware Tool-Augmented Clinical Agents

This is the official repository for "REFLECTOOL: Towards Reflection-Aware Tool-Augmented Clinical Agents"

[Webpage] [Paper] [Huggingface Dataset] [Leaderboard]

🔦 About

ClinicalAgent Bench

Although existing clinical agents handle diverse signal interactions well, each is tailored to a single clinical scenario and therefore generalizes poorly to broader applications. To evaluate clinical agents holistically, we propose ClinicalAgent Bench, a comprehensive medical agent benchmark comprising 18 tasks across five key realistic clinical dimensions.

ReflecTool

Building on this benchmark, we introduce ReflecTool, a novel framework that excels at utilizing domain-specific tools in two stages. ReflecTool searches its long-term memory, built during an optimization stage, for supportive successful demonstrations that guide the tool selection strategy; a verifier then improves the tool usage according to tool-wise experience via two verification methods, Iterative Refinement and Candidate Selection.
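As a rough illustration of the two verification strategies, the sketch below shows their control flow in the abstract. This is not the repository's implementation; `propose`, `refine_once`, and `verify` are hypothetical stand-ins for the agent's action generator and the tool-wise verifier.

```python
from typing import Callable


def candidate_selection(propose: Callable[[], str],
                        verify: Callable[[str], float],
                        k: int = 2) -> str:
    """Sample k candidate actions and keep the one the verifier scores highest."""
    candidates = [propose() for _ in range(k)]
    return max(candidates, key=verify)


def iterative_refinement(action: str,
                         refine_once: Callable[[str], str],
                         verify: Callable[[str], float],
                         k: int = 2) -> str:
    """Refine the action up to k times, keeping a revision only if it scores better."""
    for _ in range(k):
        revised = refine_once(action)
        if verify(revised) > verify(action):
            action = revised
    return action
```

Candidate Selection trades extra generation calls for a single verifier pass per candidate, while Iterative Refinement spends its budget revising one trajectory.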

πŸ† LeaderBoard

(The five dimension groups are Knowledge & Reasoning, Multimodal, Numerical Analysis, Data Understanding, and Trustworthiness; each group ends with its Avg. column. "-" indicates no result.)

| Methods | Type | Total | MedQA | MMLU | BioASQ | PubMedQA | Avg. | VQARAD | SLAKE | OmniMedQA | Avg. | MedCalc | EHRSQL | MIMIC-III | eICU | Avg. | MedMentions | emrQA | LongHealthQA | Avg. | MedHalt-Rht | MedVQA-Halt | EHR-Halt | LongHalt | Avg. |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| MedLlama3-8b | LLM | 25.14 | 58.13 | 72.82 | 66.18 | 42.40 | 59.88 | - | - | - | - | 22.45 | 7.92 | 8.65 | 7.40 | 11.61 | 22.93 | 23.10 | - | 23.02 | 9.90 | - | 2.23 | - | 6.07 |
| Qwen2-7B | LLM | 38.01 | 54.04 | 69.51 | 71.68 | 48.20 | 60.86 | - | - | - | - | 14.42 | 18.09 | 19.70 | 21.73 | 18.49 | 18.38 | 39.86 | 75.25 | 44.50 | 14.11 | - | 3.97 | 66.50 | 28.19 |
| Llama3-8b | LLM | 35.62 | 56.32 | 70.34 | 72.82 | 53.80 | 63.32 | - | - | - | - | 28.08 | 17.02 | 12.53 | 21.83 | 19.87 | 28.54 | 41.92 | - | 35.23 | 29.22 | - | 18.90 | - | 24.06 |
| Llama3.1-8b | LLM | 42.46 | 65.20 | 76.58 | 74.92 | 53.00 | 67.43 | - | - | - | - | 37.44 | 11.35 | 17.42 | 22.07 | 22.07 | 32.08 | 42.41 | 74.25 | 49.58 | 27.78 | - | 3.97 | 60.50 | 30.75 |
| Qwen2-72B* | LLM | 48.76 | 71.25 | 84.48 | 82.85 | 53.00 | 72.90 | - | - | - | - | 32.19 | 23.98 | 33.95 | 34.15 | 31.07 | 29.20 | 42.89 | 79.75 | 50.61 | 31.56 | - | 31.30 | 58.50 | 40.45 |
| Llama3.1-70b* | LLM | 47.59 | 79.58 | 88.15 | 82.52 | 57.40 | 76.91 | - | - | - | - | 48.52 | 16.49 | 25.44 | 26.47 | 29.23 | 25.71 | 31.69 | 80.00 | 45.80 | 28.22 | - | 22.48 | 64.50 | 38.40 |
| GPT-3.5-turbo | LLM | 31.31 | 58.68 | 69.88 | 75.40 | 50.60 | 63.64 | - | - | - | - | 20.53 | 17.57 | 24.31 | 14.30 | 19.18 | 26.88 | 21.64 | - | 24.26 | 9.78 | - | 26.55 | - | 18.17 |
| MiniCPM-V-2.6 | MLLM | 29.28 | 46.58 | 61.16 | 70.23 | 47.20 | 56.29 | 48.78 | 47.12 | 73.70 | 56.53 | 13.28 | 1.61 | 1.63 | 1.88 | 4.60 | 18.92 | 17.42 | 5.25 | 13.86 | 12.44 | 36.89 | 8.91 | 2.25 | 15.12 |
| InternVL-Chat-V1.5 | MLLM | 37.02 | 50.82 | 65.56 | 64.89 | 30.40 | 52.92 | 49.67 | 41.47 | 68.50 | 53.21 | 18.91 | 17.99 | 18.05 | 20.83 | 18.95 | 26.47 | 42.53 | * | 34.50 | 25.78 | 50.78 | 0.00 | * | 25.52 |
| HuatuoGPT-Vision-7B | MLLM | 41.11 | 50.43 | 66.12 | 73.30 | 54.00 | 60.96 | 53.65 | 52.97 | 92.10 | 66.24 | 13.56 | 4.39 | 9.27 | 9.76 | 9.25 | 16.74 | 38.44 | 73.00 | 42.73 | 14.33 | 23.44 | 2.42 | 65.25 | 26.36 |
| HuatuoGPT-Vision-34B | MLLM | 41.31 | 54.83 | 72.36 | 73.79 | 48.00 | 62.25 | 56.76 | 53.72 | 91.50 | 67.33 | 25.79 | 8.14 | 9.15 | 9.79 | 13.22 | 16.34 | 41.68 | * | 29.01 | 28.44 | 43.77 | 32.07 | * | 34.76 |
| GPT-4o-mini | MLLM | 52.65 | 76.90 | 85.67 | 82.84 | 49.20 | 73.65 | 50.47 | 46.47 | 59.20 | 52.05 | 50.43 | 21.73 | 28.07 | 18.19 | 29.61 | 31.66 | 40.00 | 78.50 | 50.05 | 62.11 | 53.78 | 45.74 | 70.00 | 57.91 |
| COT Qwen2-7b | Agent | 39.94 | 52.47 | 69.97 | 72.21 | 41.00 | 58.91 | - | - | - | - | 19.10 | 16.06 | 23.81 | 20.95 | 19.98 | 22.83 | 19.92 | 65.75 | 36.17 | 45.56 | - | 31.49 | 57.00 | 44.68 |
| ReAct Qwen2-7b | Agent | 42.73 | 51.61 | 67.68 | 80.24 | 48.60 | 62.03 | 35.92 | 39.59 | 72.90 | 49.47 | 18.62 | 18.52 | 25.06 | 34.00 | 24.05 | 22.19 | 24.04 | 41.50 | 29.24 | 49.89 | 35.33 | 61.49 | 48.75 | 48.87 |
| CRITIC Qwen2-7b | Agent | 43.97 | 52.87 | 58.68 | 71.68 | 43.20 | 56.61 | 48.12 | 42.70 | 70.80 | 53.87 | 13.09 | 23.55 | 28.20 | 33.12 | 24.49 | 25.47 | 36.33 | 50.25 | 37.35 | 30.44 | 28.33 | 57.64 | 73.75 | 47.54 |
| Reflexion Qwen2-7b | Agent | 45.25 | 51.78 | 66.48 | 74.60 | 50.80 | 60.92 | 45.68 | 47.97 | 77.20 | 56.95 | 13.37 | 17.56 | 22.16 | 30.23 | 20.83 | 30.30 | 28.92 | 53.00 | 37.41 | 50.55 | 36.33 | 62.91 | 50.75 | 50.14 |
| ReflecTool (Iterative Refinement, k=2) Qwen2-7b | Agent | 49.37 | 50.12 | 65.47 | 76.37 | 63.20 | 63.79 | 53.88 | 45.71 | 82.90 | 60.83 | 24.68 | 24.20 | 16.92 | 22.08 | 21.97 | 43.69 | 50.27 | 61.00 | 51.65 | 54.44 | 37.99 | 56.73 | 45.25 | 48.60 |
| ReflecTool (Candidates Selection, k=2) Qwen2-7b | Agent | 49.08 | 50.81 | 64.84 | 72.98 | 62.60 | 62.81 | 56.76 | 49.76 | 79.20 | 61.91 | 24.64 | 29.87 | 25.13 | 27.48 | 26.78 | 59.21 | 43.40 | 54.00 | 52.20 | 53.44 | 34.21 | 55.97 | 23.25 | 41.72 |
| COT Qwen2-72b | Agent | 50.58 | 72.51 | 85.45 | 81.07 | 37.40 | 69.11 | - | - | - | - | 29.89 | 18.73 | 23.55 | 25.72 | 24.47 | 24.43 | 52.09 | 81.00 | 52.51 | 57.11 | - | 40.79 | 70.75 | 56.22 |
| ReAct Qwen2-72b | Agent | 53.31 | 72.43 | 85.67 | 85.38 | 62.40 | 76.47 | 50.09 | 46.01 | 73.00 | 56.37 | 26.46 | 35.97 | 31.45 | 31.87 | 31.44 | 42.11 | 55.02 | 62.75 | 53.29 | 56.89 | 24.75 | 55.78 | 58.50 | 48.98 |
| CRITIC Qwen2-72b | Agent | 52.35 | 71.85 | 85.31 | 87.86 | 51.00 | 74.01 | 40.58 | 50.80 | 73.50 | 54.96 | 22.44 | 35.48 | 32.47 | 33.28 | 30.92 | 33.60 | 56.36 | 75.50 | 55.15 | 51.77 | 24.22 | 52.03 | 58.75 | 46.69 |
| Reflexion Qwen2-72b | Agent | 56.37 | 70.78 | 84.30 | 87.06 | 65.00 | 76.79 | 54.99 | 50.05 | 77.80 | 60.95 | 22.45 | 40.14 | 31.94 | 33.42 | 31.99 | 52.37 | 58.73 | 64.00 | 58.37 | 59.00 | 33.00 | 59.00 | 64.00 | 53.75 |
| ReflecTool (Iterative Refinement, k=2) Qwen2-72b | Agent | 59.43 | 73.30 | 84.11 | 84.63 | 65.20 | 76.81 | 57.21 | 48.82 | 85.20 | 63.74 | 36.01 | 47.43 | 33.96 | 36.39 | 38.45 | 54.57 | 67.96 | 68.00 | 63.51 | 59.55 | 38.21 | 57.82 | 63.00 | 54.65 |
| ReflecTool (Candidates Selection, k=2) Qwen2-72b | Agent | 59.66 | 71.37 | 84.00 | 83.50 | 66.20 | 76.27 | 57.66 | 48.54 | 81.90 | 62.70 | 36.77 | 49.89 | 31.20 | 34.38 | 38.06 | 60.51 | 66.61 | 66.50 | 64.54 | 60.77 | 39.63 | 59.78 | 66.75 | 56.73 |

💫 News

  • 🔥 [2024/10/24] We release the ClinicalAgent Bench, comprising 18 tasks across five capability dimensions!
  • 🔥 [2024/10/24] We release the code implementation of ReflecTool!

💡 Installation

Environment

conda create -n reflectool python=3.10.8
conda activate reflectool
# install torch
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
# install java
conda install -c conda-forge openjdk=21.0.4

Dependency

git clone https://github.com/BlueZeros/ReflecTool.git
cd ReflecTool
pip install -e .

ToolBox

You can find details of the tool installation here.

📔 Dataset

Download the ClinicalAgentBench dataset from here and put it into the temp folder. ClinicalAgentBench has the structure below:

ReflecTool/
β”œβ”€β”€ ClinicalAgentBench/
β”‚   β”œβ”€β”€ ablations/ # subset used for analysis
β”‚   β”œβ”€β”€ train/ # dataset used for optimization stage
β”‚   β”œβ”€β”€ test/ # dataset used for evaluation
β”‚   β”œβ”€β”€ memory/ # few-shot samples, long-term memory and tool experience
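After downloading, a small helper along these lines (hypothetical, not shipped with the repo) can sanity-check that the expected sub-folders from the tree above are in place:

```python
from pathlib import Path

# Sub-folders expected under ClinicalAgentBench/, per the layout above.
EXPECTED = ["ablations", "train", "test", "memory"]


def missing_subdirs(root: str) -> list:
    """Return the expected ClinicalAgentBench sub-folders absent under `root`."""
    base = Path(root)
    return [name for name in EXPECTED if not (base / name).is_dir()]
```

For example, `missing_subdirs("./ClinicalAgentBench")` should return an empty list once the dataset is unpacked correctly.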

🔬 Supported Models

The supported model list can be found in model_config. Add the corresponding model folder path before using a model. The types of supported models are shown below:

  • Local Models: Hugging Face models loaded with AutoModelForCausalLM. vLLM acceleration is also supported for these models (used automatically if vllm is installed).
    • Llama3
    • Llama3.1
    • Qwen1.5
    • Qwen2
    • Qwen2.5
  • OpenAI Models: OpenAI Model list.
    • GPT-3.5-turbo
    • GPT-4o-mini
    • GPT-4o
    • Note: you need to set the environment variable OPENAI_API_KEY={your_openai_api_key}.
  • Multimodal LLMs: HuatuoGPT-Vision, MiniCPM, and InternVL.
    • HuatuoGPT-Vision
    • MiniCPM-V-2.6
    • InternVL-Chat-V1.5
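The three model categories above imply a simple dispatch: API-key check for OpenAI models, a multimodal path for vision models, and local loading for the rest. The helper below is a hypothetical illustration of that routing (`resolve_backend` and the prefix lists are not part of the package):

```python
import os

OPENAI_MODELS = {"gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o"}
MULTIMODAL_PREFIXES = ("huatuogpt-vision", "minicpm", "internvl")
LOCAL_PREFIXES = ("llama3", "qwen")  # covers Llama3/3.1 and Qwen1.5/2/2.5 variants


def resolve_backend(model: str) -> str:
    """Map a model name to the backend category described in the README."""
    name = model.lower()
    if name in OPENAI_MODELS:
        if "OPENAI_API_KEY" not in os.environ:
            raise EnvironmentError("Set OPENAI_API_KEY before using OpenAI models.")
        return "openai"
    if any(name.startswith(p) for p in MULTIMODAL_PREFIXES):
        return "multimodal"
    if any(name.startswith(p) for p in LOCAL_PREFIXES):
        return "local"
    raise ValueError(f"Unsupported model: {model}")
```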

πŸ“ Evaluation

Baselines

Models
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test

OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory

DATASET=medqa
MODEL=qwen2-7b

EXP_NAME=${MODEL}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}

python run.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK_PATH} \
    --test-split ${DATASET} \
    --exp-name ${EXP_NAME} \
    --model ${MODEL} \
    --log-print \
    --prompt-debug \
    --resume
Reflexion Agent
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test

OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory

DATASET="medqa"
MODEL=qwen2-7b
ACTIONS="all_wo_mm"
FEWSHOT=1

EXP_NAME=reflexion_${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}

python run.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK_PATH} \
    --test-split ${DATASET} \
    --exp-name ${EXP_NAME} \
    --agent "reflexion" \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --memory-path ${MEMORY_PATH} \
    --memory-type "reflexion_standard" \
    --few-shot ${FEWSHOT} \
    --resume
CRITIC Agent
DATA_PATH=./ClinicalAgentBench
TASK_PATH=test

OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/memory

DATASET="medqa"
MODEL=qwen2-72b-int4

ACTIONS="all_wo_mm"
FEWSHOT=1

EXP_NAME=critic_${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${DATASET}/${EXP_NAME}

python run.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK_PATH} \
    --test-split ${DATASET} \
    --exp-name ${EXP_NAME} \
    --agent "critic" \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --memory-path ${MEMORY_PATH} \
    --memory-type "critic_standard" \
    --few-shot ${FEWSHOT} \
    --resume \
    --log-print \
    --prompt-debug

ReflecTool

Optimization

Optimization on the medqa dataset (Knowledge task example)
DATA_PATH=./ClinicalAgentBench
TASK_PATH=train

OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/my_memory

domain=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool type; note that the `all` and `mm` action sets require 2 GPUs
FEWSHOT=0 # number of few-shot demonstrations

EXP_NAME=train-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${domain}/${EXP_NAME}

python train.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK_PATH} \
    --test-split ${domain} \
    --exp-name ${EXP_NAME} \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --memory-path ${MEMORY_PATH} \
    --max-exec-steps 15 \
    --few-shot ${FEWSHOT} \
    --write-memory \
    --memory-type task \
    --log-print \
    --resume
Optimization on the slake dataset (MultiModal task example)
DATA_PATH=./ClinicalAgentBench
TASK_PATH=train

OUTPUT_PATH=./results/${TASK_PATH}
MEMORY_PATH=${DATA_PATH}/my_memory

domain=slake
MODEL=qwen2-7b
ACTIONS="all" # tool type; note that the `all` and `mm` action sets require 2 GPUs
FEWSHOT=0 # number of few-shot demonstrations

EXP_NAME=train-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}
mkdir -p ${OUTPUT_PATH}/${domain}/${EXP_NAME}

python train.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK_PATH} \
    --test-split ${domain} \
    --exp-name ${EXP_NAME} \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --memory-path ${MEMORY_PATH} \
    --max-exec-steps 15 \
    --few-shot ${FEWSHOT} \
    --write-memory \
    --memory-type task \
    --log-print \
    --resume
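Conceptually, the `--write-memory` flag in the optimization scripts accumulates successful trajectories into per-task long-term memory for later retrieval. A simplified sketch of that bookkeeping is below; the file layout and field names are illustrative only, not the repository's actual memory format:

```python
import json
from pathlib import Path


def write_memory(memory_dir: str, task: str, question: str,
                 trajectory: list, success: bool) -> None:
    """Append a solved example to a per-task long-term memory file."""
    if not success:
        return  # only successful demonstrations are stored as guidance
    path = Path(memory_dir) / f"{task}.json"
    records = json.loads(path.read_text()) if path.exists() else []
    records.append({"question": question, "trajectory": trajectory})
    path.write_text(json.dumps(records, indent=2))
```

At inference time, the agent would then search these records for demonstrations similar to the current question to guide tool selection.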

Inference

Evaluation on medqa dataset with ReflecTool (Iterative Refinement)
DATA_PATH=./ClinicalAgentBench
TASK=test

OUTPUT_PATH=./results/${TASK}

DATASET=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool type; note that the `all` and `mm` action sets require 2 GPUs
FEWSHOT=1 # number of few-shot demonstrations
MEMORY_PATH=${DATA_PATH}/memory/task/long_term_memory/${DATASET}

EXP_NAME=reflectool_refine-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}-task_memory
mkdir -p ${OUTPUT_PATH}/${TASK}/${DATASET}/${EXP_NAME}

python run.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK} \
    --test-split ${DATASET} \
    --exp-name ${EXP_NAME} \
    --agent "reflectool" \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --action-guide-path ${DATA_PATH}/memory/task/tool_experience/${DATASET}.json \
    --memory-path ${MEMORY_PATH} \
    --force-action \
    --action-search "refine" \
    --few-shot ${FEWSHOT} \
    --log-print \
    --prompt-debug \
    --resume
Evaluation on medqa dataset with ReflecTool (Candidate Selection)
DATA_PATH=./ClinicalAgentBench
TASK=test

OUTPUT_PATH=./results/${TASK}

DATASET=medqa
MODEL=qwen2-7b
ACTIONS="all_wo_mm" # tool type; note that the `all` and `mm` action sets require 2 GPUs
FEWSHOT=1 # number of few-shot demonstrations
MEMORY_PATH=${DATA_PATH}/memory/task/long_term_memory/${DATASET}

EXP_NAME=reflectool_select-${MODEL}-few_shot_${FEWSHOT}-${ACTIONS}-task_memory
mkdir -p ${OUTPUT_PATH}/${TASK}/${DATASET}/${EXP_NAME}

python run.py \
    --data-path ${DATA_PATH} \
    --output-path ${OUTPUT_PATH} \
    --task-name ${TASK} \
    --test-split ${DATASET} \
    --exp-name ${EXP_NAME} \
    --agent "reflectool" \
    --model ${MODEL} \
    --actions ${ACTIONS} \
    --action-guide-path ${DATA_PATH}/memory/task/tool_experience/${DATASET}.json \
    --memory-path ${MEMORY_PATH} \
    --force-action \
    --action-search "select" \
    --few-shot ${FEWSHOT} \
    --log-print \
    --prompt-debug \
    --resume
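When sweeping several datasets or models, the `run.py` invocations above can be assembled programmatically. The flags in this sketch simply mirror the evaluation scripts in this section; `run.py`'s actual CLI remains the authoritative reference:

```python
def build_run_cmd(data_path: str, output_path: str, task: str, dataset: str,
                  model: str, actions: str, search: str, few_shot: int = 1) -> list:
    """Assemble the argv for a ReflecTool evaluation run (flags mirror the README scripts)."""
    exp_name = f"reflectool_{search}-{model}-few_shot_{few_shot}-{actions}-task_memory"
    return [
        "python", "run.py",
        "--data-path", data_path,
        "--output-path", output_path,
        "--task-name", task,
        "--test-split", dataset,
        "--exp-name", exp_name,
        "--agent", "reflectool",
        "--model", model,
        "--actions", actions,
        "--action-guide-path", f"{data_path}/memory/task/tool_experience/{dataset}.json",
        "--memory-path", f"{data_path}/memory/task/long_term_memory/{dataset}",
        "--force-action",
        "--action-search", search,  # "refine" or "select"
        "--few-shot", str(few_shot),
        "--resume",
    ]
```

The returned list can be passed to `subprocess.run` once per dataset in a sweep.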

📚 Quick Start

ReAct Agent

  • Question Answering
from reflectool.agents.TaskAgent import TaskAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args


args = parse_args()
args.model = "gpt-4o-mini"
Agent = TaskAgent(args)

task = TaskPackage(
    inputs="Can you explain fever to me?",
    instruction="",
)
print(Agent(task))
  • Visual Question Answering
from reflectool.agents.TaskAgent import TaskAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args


args = parse_args()
args.model = "gpt-4o-mini"
Agent = TaskAgent(args)

task = TaskPackage(
    inputs="Can you describe the image?",
    instruction="",
    multimodal_inputs={
      "image": "{path to the image}",
    }
)
print(Agent(task))

ReflecTool Agent

  • Question Answering
from reflectool.agents.ReflecToolAgent import ReflecToolAgent
from reflectool.commons.TaskPackage import TaskPackage
from reflectool.utilities import parse_args


args = parse_args()
args.model = "gpt-4o-mini"
args.action_search = "refine" # verification method: "refine" or "select"
Agent = ReflecToolAgent(args)

task = TaskPackage(
    inputs="Can you explain fever to me?",
    instruction="",
)
print(Agent(task))

🪶 Acknowledgements

Thanks to the codebases we built upon.

Citation

@misc{liao2024reflectoolreflectionawaretoolaugmentedclinical,
      title={ReflecTool: Towards Reflection-Aware Tool-Augmented Clinical Agents}, 
      author={Yusheng Liao and Shuyang Jiang and Yanfeng Wang and Yu Wang},
      year={2024},
      eprint={2410.17657},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.17657}, 
}
