From 36fb882747ae529f4ad5484bc306ab7f9b6bd9a4 Mon Sep 17 00:00:00 2001
From: Preetika764
Date: Sat, 10 Aug 2024 10:37:21 +0000
Subject: [PATCH] Adding support for qwen

---
 scripts/config.sh             |  3 +++
 scripts/train_rosa.sh         |  3 +++
 src/panza/finetuning/train.py |  7 ++++---
 src/panza/utils/prompting.py  | 10 ++++++++++
 4 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/scripts/config.sh b/scripts/config.sh
index 76522f1..e922c37 100755
--- a/scripts/config.sh
+++ b/scripts/config.sh
@@ -21,6 +21,7 @@ export MODEL_PRECISION=bf16 # precision at which the base model is stored; optio
 # export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
 export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct"
 # export PANZA_GENERATIVE_MODEL="microsoft/Phi-3-mini-4k-instruct"
+# export PANZA_GENERATIVE_MODEL="Qwen/Qwen2-1.5B-Instruct"
 
 lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]')
 if [[ ${lowercased} == *llama* ]]; then
@@ -29,6 +30,8 @@ elif [[ ${lowercased} == *mistral* ]]; then
     export MODEL_TYPE=mistralv2
 elif [[ ${lowercased} == *phi* ]]; then
     export MODEL_TYPE=phi3
+elif [[ ${lowercased} == *qwen* ]]; then
+    export MODEL_TYPE=qwen
 else
     echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral and Llama3 models. Exiting."
     exit
diff --git a/scripts/train_rosa.sh b/scripts/train_rosa.sh
index d100cd2..0224728 100755
--- a/scripts/train_rosa.sh
+++ b/scripts/train_rosa.sh
@@ -28,6 +28,9 @@ elif [[ ${MODEL_TYPE} == mistralv2 ]]; then
 elif [[ ${MODEL_TYPE} == phi3 ]]; then
     export LR=1e-5 # learning rate
     export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters
+elif [[ ${MODEL_TYPE} == qwen ]]; then
+    export LR=1e-5
+    export LORA_LR=1e-5
 else
     echo "Model type ${MODEL_TYPE} not recognized! Panza only works with mistralv2, llama3 and phi3 models. Exiting."
     exit
diff --git a/src/panza/finetuning/train.py b/src/panza/finetuning/train.py
index 82c7c23..4b05764 100644
--- a/src/panza/finetuning/train.py
+++ b/src/panza/finetuning/train.py
@@ -166,16 +166,17 @@ def build_composer_peft_model(
     with init_empty_weights(include_buffers=False):
         model = AutoModelForCausalLM.from_pretrained(
             model_config.pretrained_model_name_or_path,
-            device_map='cpu' if quant_config is None else 'auto',
+            # device_map='cpu' if quant_config is None else 'auto',
             torch_dtype=compute_dtype,
             # load_in_4bit=weight_bias_dtype == '4bit',
             quantization_config=quant_config,
             trust_remote_code=True,
             use_auth_token=True,
             use_cache=False,
-            attn_implementation='eager'
+            attn_implementation='eager',
+            low_cpu_mem_usage=True,
         )
-
+    model.tie_weights()
     print('Model built!')
     if rosa_config is not None:
         print('Building RoSA config...')
diff --git a/src/panza/utils/prompting.py b/src/panza/utils/prompting.py
index 958b720..d798728 100644
--- a/src/panza/utils/prompting.py
+++ b/src/panza/utils/prompting.py
@@ -17,6 +17,11 @@ PHI3_RESPONSE_START_WRAPPER = ""
 PHI3_RESPONSE_END_WRAPPER = "<|end|>"
 
+QWEN_PROMPT_START_WRAPPER = "<|im_start|>user\n"
+QWEN_PROMPT_END_WRAPPER = "<|im_end|>\n<|im_start|>assistant\n"
+QWEN_RESPONSE_START_WRAPPER = ""
+QWEN_RESPONSE_END_WRAPPER = "<|im_end|>"
+
 
 def create_prompt(
     user_input: Text,
     system_preamble: Text,
@@ -126,6 +131,11 @@
         prompt_end_wrapper = PHI3_PROMPT_END_WRAPPER
         response_start_wrapper = PHI3_RESPONSE_START_WRAPPER
         response_end_wrapper = PHI3_RESPONSE_END_WRAPPER
+    elif "qwen" in model_name.lower():
+        prompt_start_wrapper = QWEN_PROMPT_START_WRAPPER
+        prompt_end_wrapper = QWEN_PROMPT_END_WRAPPER
+        response_start_wrapper = QWEN_RESPONSE_START_WRAPPER
+        response_end_wrapper = QWEN_RESPONSE_END_WRAPPER
     else:
         raise ValueError(f"Presets missing for prompting model {model_name}")
 
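
Note for reviewers: below is a minimal sketch of the prompt layout the new QWEN_* wrappers produce. The concatenation is illustrative only (it is not a call into Panza's own create_prompt or data-preparation code, and the example strings are invented); it simply shows that the wrappers yield the ChatML format used by Qwen2-Instruct models.

    # Illustrative sketch only: wrapper values are copied from the patch,
    # the user_input / response strings are invented for the example.
    QWEN_PROMPT_START_WRAPPER = "<|im_start|>user\n"
    QWEN_PROMPT_END_WRAPPER = "<|im_end|>\n<|im_start|>assistant\n"
    QWEN_RESPONSE_START_WRAPPER = ""
    QWEN_RESPONSE_END_WRAPPER = "<|im_end|>"

    user_input = "Write a short email declining the meeting."
    response = "Hi, unfortunately I can't make it this week. Best, P."

    # Prompt as it would be sent to the model at inference time.
    prompt = QWEN_PROMPT_START_WRAPPER + user_input + QWEN_PROMPT_END_WRAPPER
    # Full training example: golden response appended and terminated by <|im_end|>.
    training_text = prompt + QWEN_RESPONSE_START_WRAPPER + response + QWEN_RESPONSE_END_WRAPPER

    print(training_text)
    # <|im_start|>user
    # Write a short email declining the meeting.<|im_end|>
    # <|im_start|>assistant
    # Hi, unfortunately I can't make it this week. Best, P.<|im_end|>

Assuming the other presets in prompting.py are assembled the same way (the Phi-3 wrappers visible in the diff follow the same start/end pattern), these markers correspond to the <|im_start|>/<|im_end|> special tokens that the Qwen2 tokenizer already knows.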