vLLM multi gpu tests adjustments (#65)
Co-authored-by: Jacky <[email protected]>
oandreeva-nv and kthui authored Sep 24, 2024
1 parent 0df1013 commit b71088a
Showing 2 changed files with 101 additions and 6 deletions.
103 changes: 98 additions & 5 deletions ci/L0_multi_gpu/multi_lora/test.sh
@@ -52,15 +52,16 @@ cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_llama_multi_lora
 
 export SERVER_ENABLE_LORA=true
 
+# Check boolean flag value for `enable_lora`
 model_json=$(cat <<EOF
 {
 "model":"./weights/backbone/gemma-2b",
-"disable_log_requests": "true",
+"disable_log_requests": true,
 "gpu_memory_utilization": 0.7,
 "tensor_parallel_size": 2,
 "block_size": 16,
-"enforce_eager": "true",
-"enable_lora": "true",
+"enforce_eager": true,
+"enable_lora": true,
 "max_lora_rank": 32,
 "lora_extra_vocab_size": 256,
 "distributed_executor_backend":"ray"
@@ -110,16 +110,108 @@ set -e
 kill $SERVER_PID
 wait $SERVER_PID
 
+# Check string flag value for `enable_lora`
+model_json=$(cat <<EOF
+{
+"model":"./weights/backbone/gemma-2b",
+"disable_log_requests": true,
+"gpu_memory_utilization": 0.7,
+"tensor_parallel_size": 2,
+"block_size": 16,
+"enforce_eager": true,
+"enable_lora": "true",
+"max_lora_rank": 32,
+"lora_extra_vocab_size": 256,
+"distributed_executor_backend":"ray"
+}
+EOF
+)
+echo "$model_json" > models/vllm_llama_multi_lora/1/model.json
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+cat $SERVER_LOG
+echo -e "\n***\n*** Failed to start $SERVER\n***"
+exit 1
+fi
+
+set +e
+python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
+
+if [ $? -ne 0 ]; then
+cat $CLIENT_LOG
+echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
+RET=1
+else
+check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+if [ $? -ne 0 ]; then
+cat $CLIENT_LOG
+echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+RET=1
+fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+
+# disable lora
+export SERVER_ENABLE_LORA=false
+# check bool flag value for `enable_lora`
+model_json=$(cat <<EOF
+{
+"model":"./weights/backbone/gemma-2b",
+"disable_log_requests": true,
+"gpu_memory_utilization": 0.8,
+"tensor_parallel_size": 2,
+"block_size": 16,
+"enforce_eager": true,
+"enable_lora": false,
+"lora_extra_vocab_size": 256,
+"distributed_executor_backend":"ray"
+}
+EOF
+)
+echo "$model_json" > models/vllm_llama_multi_lora/1/model.json
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+cat $SERVER_LOG
+echo -e "\n***\n*** Failed to start $SERVER\n***"
+exit 1
+fi
+
+set +e
+python3 $CLIENT_PY -v >> $CLIENT_LOG 2>&1
+
+if [ $? -ne 0 ]; then
+cat $CLIENT_LOG
+echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
+RET=1
+else
+check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+if [ $? -ne 0 ]; then
+cat $CLIENT_LOG
+echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+RET=1
+fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+
 # disable lora
 export SERVER_ENABLE_LORA=false
+# check string flag value for `enable_lora`
 model_json=$(cat <<EOF
 {
 "model":"./weights/backbone/gemma-2b",
-"disable_log_requests": "true",
+"disable_log_requests": true,
 "gpu_memory_utilization": 0.8,
 "tensor_parallel_size": 2,
 "block_size": 16,
-"enforce_eager": "true",
+"enforce_eager": true,
 "enable_lora": "false",
 "lora_extra_vocab_size": 256,
 "distributed_executor_backend":"ray"
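Note: the new test blocks write `enable_lora` into `model.json` both as a JSON boolean and as a JSON string. A minimal standalone sketch (not part of the commit) of why the two spellings need different handling on the backend side: `json.loads` turns the literal `true` into a Python bool, while `"true"` stays a str, which is what the `src/model.py` change below accounts for.

import json

# Hypothetical check, not taken from the repository: the two spellings
# of the flag parse to different Python types.
bool_variant = json.loads('{"enable_lora": true}')["enable_lora"]    # -> True (bool)
str_variant = json.loads('{"enable_lora": "true"}')["enable_lora"]   # -> "true" (str)
print(type(bool_variant).__name__, type(str_variant).__name__)       # prints: bool str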
4 changes: 3 additions & 1 deletion src/model.py
@@ -189,9 +189,11 @@ def init_engine(self):
     def setup_lora(self):
         self.enable_lora = False
 
+        # Check if `enable_lora` field is in the `model.json`,
+        # and if it is, read its contents, which can be string or bool.
         if (
             "enable_lora" in self.vllm_engine_config.keys()
-            and self.vllm_engine_config["enable_lora"].lower() == "true"
+            and str(self.vllm_engine_config["enable_lora"]).lower() == "true"
         ):
             # create Triton LoRA weights repository
             multi_lora_args_filepath = os.path.join(
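For illustration only (not code from the commit), here is what the updated condition in `setup_lora` evaluates to for the four `enable_lora` values the tests above exercise. Wrapping the value in `str()` makes the comparison work for both types, whereas calling `.lower()` directly on a bool raises `AttributeError`.

# Sketch of the normalization used by the updated check.
for value in (True, "true", False, "false"):
    enable_lora = str(value).lower() == "true"
    print(repr(value), "->", enable_lora)
# True -> True, 'true' -> True, False -> False, 'false' -> False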
