-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgeneration.py
77 lines (70 loc) · 2.45 KB
/
generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('../')
import argparse
from utils.cafie_model import ScoringAlgo

parser = argparse.ArgumentParser(description="")
# Specify model name and hyperparameters.
# type= ensures CLI-supplied values arrive as numbers (the defaults already
# were numeric, but command-line overrides used to come through as strings).
parser.add_argument(
    "--alp",
    type=float,
    default=0.99,
    help="alpha ratio: new_probs = alpha*debiased_probs + (1-alpha)*vanilla_probs",
)
parser.add_argument(
    "--lmd",
    type=int,
    default=100,
    help="lambda hyperparameter for the scoring function (higher = more debiasing)",
)
parser.add_argument(
    "--model_name",
    default="gpt2",
    help="HuggingFace model identifier to load",
)
args = parser.parse_args()

# Load the causal LM and its tokenizer; use EOS as the pad token and
# left-padding, as required for batched causal-LM generation.
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(args.model_name, add_prefix_space=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
def _read_word_list(path):
    """Return the words in *path*, one per line, trailing newlines stripped.

    Uses rstrip("\n") rather than line[:-1]: the slice would chop the last
    character of a final line that has no trailing newline.
    """
    with open(path, "r") as f:
        return [line.rstrip("\n") for line in f]


# Forward-slash paths are portable (Python's open() accepts them on Windows
# too); the original backslash paths only worked on Windows because "\w" and
# "\l" happen not to be string escapes.
list_1_words = _read_word_list("data/word_lists/list_1.txt")  # Path to the word list 1
list_2_words = _read_word_list("data/word_lists/list_2.txt")  # Path to the word list 2
list_3_words = _read_word_list("data/word_lists/list_3.txt")  # Path to the word list 3
def generator(prompt, max_new_tokens=10):
    """Generate a debiased continuation of *prompt* using CAFIE scoring.

    Parameters
    ----------
    prompt : str
        Text to continue.
    max_new_tokens : int
        Number of tokens to generate (passed to ScoringAlgo as sent_len).

    Returns
    -------
    str
        The generated sentence, or *prompt* unchanged if generation failed.
    """
    try:
        runner = ScoringAlgo(
            mdl=model,
            model_name_path='g',
            tokenizer=tokenizer,
            _do_sdb=False,  # add sdb prefix to the sentence
            ratio=0.5,  # 0.5 for avg
            scoring_function="tanh",  # other options: avg, jpdf, arctan, weight
            threshold=0,  # bias threshold for scoring
            lmbda=int(args.lmd),  # scoring hyperparameter (more debiasing, lower LM score)
            alpha_ratio=float(args.alp),  # new_probs = alpha*debiased + (1-alpha)*vanilla
            softmax_temperature=1,  # temperature for vanilla_probs in alpha mix
            prompt=prompt,
            context="",
            l1=list_1_words,
            l2=list_2_words,
            l3=list_3_words,
            sent_len=max_new_tokens,
            batch_size=1,
            max_seq_length=128,
            bias_type=None,
            words_to_ignore=[],
        )
        _, _, gen_sent, _ = runner()
    except Exception as exc:
        # Narrowed from a bare `except:` (which also swallowed KeyboardInterrupt
        # and SystemExit and hid the cause). Best-effort fallback: return the
        # prompt unchanged, but surface what went wrong.
        gen_sent = prompt
        print(f"Error in generation: {exc}")
    return gen_sent
if __name__ == "__main__":
    # Guarded entry point: importing this module no longer triggers generation.
    prompt = "Two boys start a"  # prompt for generation
    print(generator(prompt))