path_config.py
import os

####################
CACHEDIR = "/storage/shared/janghyun"  # os.environ['TRANSFORMERS_CACHE'] # huggingface model cache_dir (e.g., LLaMA-2)
LLAMADIR = "/storage/shared/janghyun"  # LLaMA model directory (llama-7b-hf)
DATAPATH = "./dataset"  # tokenized data directory (containing folders, e.g., metaicl, soda)
SAVEPATH = "./result"  # result directory (containing folders of dataset names)
####################

# DATAPATH example
## DATAPATH
##  |- metaicl
##  |- soda

# SAVEPATH example
## SAVEPATH
##  |- all
##  |   |- llama-7b-no
##  |       |- finetune
##  |- metaicl
##  |- dialog
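
# Illustrative sketch (assumption, not part of the original file): downstream
# scripts presumably join these roots with a dataset name, e.g.
#   data_dir = os.path.join(DATAPATH, "metaicl")  # "./dataset/metaicl"
#   save_dir = os.path.join(SAVEPATH, "metaicl")  # "./result/metaicl"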


def model_path(model_name):
    """Return the local checkpoint directory or HuggingFace Hub id for a short model name."""
    if model_name == "llama-7b":
        path = os.path.join(LLAMADIR, "llama-7b-hf")
    elif model_name == "llama-13b":
        path = os.path.join(LLAMADIR, "llama-13b-hf")
    elif model_name == "llama-2-7b-chat":
        path = "meta-llama/Llama-2-7b-chat-hf"
    elif model_name == "llama-2-13b-chat":
        path = "meta-llama/Llama-2-13b-chat-hf"
    elif model_name == "llama-2-7b":
        path = "meta-llama/Llama-2-7b-hf"
    elif model_name == "llama-2-13b":
        path = "meta-llama/Llama-2-13b-hf"
    elif model_name == "mistral-7b":
        path = "mistralai/Mistral-7B-v0.1"
    elif model_name == "mistral-7b-inst":
        path = "mistralai/Mistral-7B-Instruct-v0.2"
    elif "flan-t5" in model_name:
        path = f"google/{model_name}"
    elif model_name == "llama-debug":
        path = "meta-llama/Llama-2-7b-chat-hf"
    elif model_name == "mistral-debug":
        path = "mistralai/Mistral-7B-Instruct-v0.2"
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return path


def map_config(model_name):
    """Map a model name to the config name it uses (Mistral models reuse the LLaMA configs)."""
    is_llama = "llama" in model_name.lower() or "mistral" in model_name.lower()

    if "debug" in model_name:
        config = "llama-debug"
    elif is_llama and "7b" in model_name:
        config = "llama-7b"
    elif is_llama and "13b" in model_name:
        config = "llama-13b"
    elif "flan-t5" in model_name:
        config = model_name
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return config
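

# Illustrative usage sketch (assumption, not part of the original file): resolve
# a checkpoint path and its config name for a given short model name.
if __name__ == "__main__":
    name = "llama-2-7b-chat"
    print(model_path(name))  # -> "meta-llama/Llama-2-7b-chat-hf" (HF hub id)
    print(map_config(name))  # -> "llama-7b"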