From d3a95b6c2657dcb5c05f326da8895f521cf60009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E9=80=B4=E5=87=A1?= Date: Fri, 7 Jun 2024 14:51:07 +0800 Subject: [PATCH] bug fix --- groma/eval/run_groma.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/groma/eval/run_groma.py b/groma/eval/run_groma.py index c0022dd..d8e360a 100644 --- a/groma/eval/run_groma.py +++ b/groma/eval/run_groma.py @@ -55,7 +55,10 @@ def eval_model(model_name, quant_type, image_file, query,): ) kwargs = {'quantization_config': int4_quant_cfg} - model = GromaModel.from_pretrained(model_name, **kwargs).cuda() + if quant_type == '8bit' or quant_type == '4bit': + model = GromaModel.from_pretrained(model_name, **kwargs) + else: + model = GromaModel.from_pretrained(model_name, **kwargs).cuda() model.init_special_token_id(tokenizer) conversations = []