diff --git a/groma/eval/run_groma.py b/groma/eval/run_groma.py index c0022dd..d8e360a 100644 --- a/groma/eval/run_groma.py +++ b/groma/eval/run_groma.py @@ -55,7 +55,10 @@ def eval_model(model_name, quant_type, image_file, query,): ) kwargs = {'quantization_config': int4_quant_cfg} - model = GromaModel.from_pretrained(model_name, **kwargs).cuda() + if quant_type == '8bit' or quant_type == '4bit': + model = GromaModel.from_pretrained(model_name, **kwargs) + else: + model = GromaModel.from_pretrained(model_name, **kwargs).cuda() model.init_special_token_id(tokenizer) conversations = []