Adding support for Gemma2 #16

Open · wants to merge 1 commit into main
scripts/config.sh (4 additions, 1 deletion)
@@ -21,6 +21,7 @@ export MODEL_PRECISION=bf16
 # export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
 export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct"
 # export PANZA_GENERATIVE_MODEL="microsoft/Phi-3-mini-4k-instruct"
+# export PANZA_GENERATIVE_MODEL="google/gemma-2-2b-it"

 lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]')
 if [[ ${lowercased} == *llama* ]]; then
@@ -29,8 +30,10 @@ elif [[ ${lowercased} == *mistral* ]]; then
     export MODEL_TYPE=mistralv2
 elif [[ ${lowercased} == *phi* ]]; then
     export MODEL_TYPE=phi3
+elif [[ ${lowercased} == *gemma* ]]; then
+    export MODEL_TYPE=gemma2
 else
-    echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral and Llama3 models. Exiting."
+    echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral, Phi, Gemma2, and Llama3 models. Exiting."
     exit
 fi

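For readers skimming the diff: the dispatch is a plain substring match on the lowercased model id, so any Hugging Face id containing "gemma" (e.g. "google/gemma-2-2b-it") now resolves to MODEL_TYPE=gemma2. Below is a minimal Python mirror of the shell logic above, for illustration only (the repository does this in scripts/config.sh, not in Python):

def resolve_model_type(model_id: str) -> str:
    """Mirror of the substring dispatch in scripts/config.sh; same match order."""
    lowered = model_id.lower()
    for needle, model_type in [
        ("llama", "llama3"),
        ("mistral", "mistralv2"),
        ("phi", "phi3"),
        ("gemma", "gemma2"),
    ]:
        if needle in lowered:
            return model_type
    raise ValueError(f"Model type {model_id} not recognized!")

assert resolve_model_type("google/gemma-2-2b-it") == "gemma2"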
scripts/train_rosa.sh (3 additions, 0 deletions)
@@ -28,6 +28,9 @@ elif [[ ${MODEL_TYPE} == mistralv2 ]]; then
 elif [[ ${MODEL_TYPE} == phi3 ]]; then
     export LR=1e-5 # learning rate
     export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters
+elif [[ ${MODEL_TYPE} == gemma2 ]]; then
+    export LR=1e-5 # learning rate
+    export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters
 else
     echo "Model type ${MODEL_TYPE} not recognized! Panza only works with mistralv2, llama3 and phi3 models. Exiting."
     exit
src/panza/data_preparation/summarize_emails.py (5 additions, 1 deletion)
@@ -17,6 +17,7 @@
 sys.path.pop(0)

 MDL = os.environ.get("PANZA_GENERATIVE_MODEL")
+MDL_TYPE = os.environ.get("MODEL_TYPE")
 TEMP = 0.7
 TOP_P = 0.7
 TOP_K = 50
@@ -43,7 +44,10 @@ def __init__(self, model, dtype, temperature, top_k, top_p, summarization_prompt
             model, model_max_length=self.model.config.max_position_embeddings, trust_remote_code=True
         )
         self.tokenizer.padding_side = "left"
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        if MDL_TYPE == "gemma2":
+            self.tokenizer.pad_token = self.tokenizer.bos_token
+        else:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
         self.summarization_prompt = summarization_prompt

         _, self.prompt_end_wrapper, _, self.response_end_wrapper = (
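A note on the pad-token change: the summarizer left-pads batches for generation, which requires a pad token, and it previously reused <eos> for every model; this PR reuses <bos> instead when MODEL_TYPE is gemma2. A self-contained sketch of the resulting tokenizer setup (the model id and function name are illustrative, not from the PR):

from transformers import AutoTokenizer

def load_summarizer_tokenizer(model_name: str, model_type: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.padding_side = "left"  # batched generation pads on the left
    if model_type == "gemma2":
        # Gemma2 branch from this PR: reuse <bos> as the padding token.
        tokenizer.pad_token = tokenizer.bos_token
    else:
        # All other supported models keep the old behavior: pad with <eos>.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

tok = load_summarizer_tokenizer("google/gemma-2-2b-it", "gemma2")
print(tok.pad_token)  # "<bos>"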
src/panza/finetuning/train.py (4 additions, 2 deletions)
@@ -166,15 +166,17 @@ def build_composer_peft_model(
     with init_empty_weights(include_buffers=False):
         model = AutoModelForCausalLM.from_pretrained(
             model_config.pretrained_model_name_or_path,
-            device_map='cpu' if quant_config is None else 'auto',
+            # device_map='cpu' if quant_config is None else 'auto',
             torch_dtype=compute_dtype,
             # load_in_4bit=weight_bias_dtype == '4bit',
             quantization_config=quant_config,
             trust_remote_code=True,
             use_auth_token=True,
             use_cache=False,
-            attn_implementation='eager'
+            attn_implementation='eager',
+            low_cpu_mem_usage=True,
         )
+    model.tie_weights()

     print('Model built!')
     if rosa_config is not None:
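Context for the loading change: the model is built under accelerate's init_empty_weights, and low_cpu_mem_usage=True lets from_pretrained materialize weights as checkpoint shards are read instead of first allocating a full CPU copy. Gemma2 ties its input embeddings to the output head, and that sharing can be lost when weights are materialized lazily, hence the explicit model.tie_weights() after loading. A standalone sketch under those assumptions (the model id is illustrative; train.py reads it from model_config):

import torch
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

with init_empty_weights(include_buffers=False):
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b-it",       # illustrative; the real path comes from the config
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # eager attention, as in the diff above
        low_cpu_mem_usage=True,       # materialize weights shard by shard
        use_cache=False,
    )
model.tie_weights()  # re-tie the shared input/output embedding tensors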
src/panza/utils/prompting.py (10 additions, 0 deletions)
@@ -17,6 +17,11 @@
 PHI3_RESPONSE_START_WRAPPER = ""
 PHI3_RESPONSE_END_WRAPPER = "<|end|>"

+GEMMA2_PROMPT_START_WRAPPER = "<bos><start_of_turn>user\n"
+GEMMA2_PROMPT_END_WRAPPER = "<end_of_turn>\n<start_of_turn>model\n"
+GEMMA2_RESPONSE_START_WRAPPER = ""
+GEMMA2_RESPONSE_END_WRAPPER = "<end_of_turn>"
+
 def create_prompt(
     user_input: Text,
     system_preamble: Text,
@@ -126,6 +131,11 @@ def get_model_special_tokens(model_name):
         prompt_end_wrapper = PHI3_PROMPT_END_WRAPPER
         response_start_wrapper = PHI3_RESPONSE_START_WRAPPER
         response_end_wrapper = PHI3_RESPONSE_END_WRAPPER
+    elif "gemma" in model_name.lower():
+        prompt_start_wrapper = GEMMA2_PROMPT_START_WRAPPER
+        prompt_end_wrapper = GEMMA2_PROMPT_END_WRAPPER
+        response_start_wrapper = GEMMA2_RESPONSE_START_WRAPPER
+        response_end_wrapper = GEMMA2_RESPONSE_END_WRAPPER
     else:
         raise ValueError(f"Presets missing for prompting model {model_name}")

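To make the new wrappers concrete: Panza brackets the user's instruction with the prompt start/end wrappers and expects generation to terminate with the response end wrapper. Assembling the Gemma2 constants by hand (the instruction text is made up):

GEMMA2_PROMPT_START_WRAPPER = "<bos><start_of_turn>user\n"
GEMMA2_PROMPT_END_WRAPPER = "<end_of_turn>\n<start_of_turn>model\n"

prompt = (
    GEMMA2_PROMPT_START_WRAPPER
    + "Write an email saying I will be late to the meeting."
    + GEMMA2_PROMPT_END_WRAPPER
)
# The assembled prompt follows Gemma's chat format:
#   <bos><start_of_turn>user
#   Write an email saying I will be late to the meeting.
#   <end_of_turn>
#   <start_of_turn>model
# and the model's reply is expected to end with "<end_of_turn>"
# (GEMMA2_RESPONSE_END_WRAPPER).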