QA_ort.py
import onnxruntime as ort
import argparse
import numpy as np
from transformers import AutoTokenizer
from huggingface_hub import login
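# Multi-turn question answering with a statically shaped ONNX decoder.
# Each question is prefilled in a single run of max_sequence_length tokens;
# the answer is then generated greedily one token at a time, carrying the
# KV cache across turns so follow-up questions can refer to earlier answers.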
def padding_input(x, max_sequence_length):
    """Right-pad a (1, seq_len) array with zeros up to max_sequence_length."""
    if x.shape[1] >= max_sequence_length:
        return x
    return np.concatenate([x, np.zeros(shape=(1, max_sequence_length - x.shape[1]), dtype=x.dtype)], axis=1)
def padding_input_reverse(x, max_sequence_length):
    """Left-pad a (1, seq_len) array with zeros up to max_sequence_length."""
    if x.shape[1] >= max_sequence_length:
        return x
    return np.concatenate([np.zeros(shape=(1, max_sequence_length - x.shape[1]), dtype=x.dtype), x], axis=1)
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, help="The model to run", default='microsoft/Phi-3-mini-4k-instruct')
parser.add_argument('-q', '--quantize', type=int, choices=[0, 4, 8], help='quantization of the ONNX model: 0 (none), 4 (INT4), or 8 (INT8)', default=0)
args = parser.parse_args()
# Derive a filesystem-friendly model name, e.g. 'microsoft/Phi-3-mini-4k-instruct' -> 'Phi_3_mini_4k_instruct'.
try:
    model = args.model.split('/')[1].replace('-', '_')
except IndexError:
    model = args.model.replace('-', '_')
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Map the quantization choice to the suffix used in the exported file name.
if args.quantize == 0:
    q_str = ''
elif args.quantize == 4:
    q_str = '_INT4'
elif args.quantize == 8:
    q_str = '_INT8'
model_path = 'logs/models/pad/' + model + '/done/' + model + '_decoder_static_one_lm' + q_str + '_ex_v21_INT4_QDQ.onnx'
session = ort.InferenceSession(model_path)
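# Record the model's input and output tensor names (not used below; handy for debugging).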
input_names = [i.name for i in session.get_inputs()]
output_names = [o.name for o in session.get_outputs()]
# Static KV-cache dimensions per model: layer count, number of KV heads, and head size
# (past_key_values tensors are shaped (1, num_kv_heads, seq, head_dim)).
if args.model == "microsoft/Phi-3-mini-4k-instruct":
    num_layers = 32
    num_kv_heads = 32
    head_dim = 96
elif args.model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
    num_layers = 22
    num_kv_heads = 4
    head_dim = 64
elif args.model == "Qwen/Qwen2-0.5B-Instruct":
    num_layers = 24
    num_kv_heads = 2
    head_dim = 64
else:
    raise ValueError(f"No KV-cache dimensions known for model {args.model}")
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What is the result of one plus one?"},
]
prompt_list = []
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt_list.append(prompt)
# Follow-up questions are appended as raw text; the KV cache carries the conversation context.
prompt_list.append('What is the result of multiplying the last result with two?')
prompt_list.append('What is the latest result minus three?')
prompt_list.append('Is the final result equal to one?')
max_sequence_length = 128
start_len = 0  # number of valid tokens already held in the KV cache
print('start_len', start_len)
print("Prompt is: ", prompt)
tokens = []
inputs = {}
# Pre-allocate zero-filled, statically shaped KV-cache tensors for every layer.
kv_cache = {f'past_key_values.{i}.decoder.key': np.zeros((1, num_kv_heads, max_sequence_length, head_dim), dtype=np.float32) for i in range(num_layers)}
kv_cache.update({f'past_key_values.{i}.decoder.value': np.zeros((1, num_kv_heads, max_sequence_length, head_dim), dtype=np.float32) for i in range(num_layers)})
inputs.update(kv_cache)
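# For each turn: prefill the question in one run, then greedily decode the answer.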
for question in prompt_list:
    print('This question is: ', question)
    input_data = tokenizer(question, return_tensors="pt")
    input_ids = input_data.input_ids.numpy().astype(np.int32)
    input_len = input_ids.shape[1]
    position_ids = np.array([range(start_len, start_len + max_sequence_length)], dtype=np.int32)
    # The attention mask spans 2*max_sequence_length positions: the first half
    # masks the cached past, the second half masks the tokens fed in this run.
    if start_len == 0:
        attention_mask = np.zeros((1, max_sequence_length), dtype=np.int32)
    else:
        attention_mask = np.ones((1, start_len), dtype=np.int32)
    inputs["attention_mask"] = padding_input(padding_input_reverse(attention_mask, max_sequence_length), 2 * max_sequence_length)
    inputs["attention_mask"][0][max_sequence_length:max_sequence_length + input_len] = 1
    inputs["input_ids"] = padding_input(input_ids, max_sequence_length)
    inputs['position_ids'] = position_ids
    # Prefill: run the whole padded question through the decoder in one step.
    outputs = session.run(None, inputs)
    # outputs[0] holds the logits; outputs[2*layer+1] / outputs[2*layer+2] hold
    # the updated key/value cache for each layer.
    for layer in range(num_layers):
        kv_cache[f'past_key_values.{layer}.decoder.key'] = outputs[2 * layer + 1]
        kv_cache[f'past_key_values.{layer}.decoder.value'] = outputs[2 * layer + 2]
    start_len += input_len
    start_len = min(start_len, max_sequence_length)
    inputs.update(kv_cache)
    logits = outputs[0]
    # Greedy decoding: pick the argmax at the last real (unpadded) token position.
    new_token = np.argmax(logits[0, input_len - 1, :])
    tokens.append(new_token)
    print(tokenizer.decode(tokens, skip_special_tokens=False))
    # Decode: generate up to max_sequence_length tokens, one per run, until EOS.
    for i in range(max_sequence_length):
        input_ids = np.zeros((1, 1), dtype=inputs['input_ids'].dtype)
        input_ids[0][0] = new_token
        inputs['input_ids'] = padding_input(input_ids, max_sequence_length)
        attention_mask = np.ones((1, start_len), dtype=inputs['attention_mask'].dtype)
        inputs["attention_mask"] = padding_input(padding_input_reverse(attention_mask, max_sequence_length), 2 * max_sequence_length)
        inputs['attention_mask'][0][max_sequence_length] = 1  # only the single new token is live in this run
        position_ids = np.array([range(start_len, start_len + max_sequence_length)], dtype=np.int32)
        inputs["position_ids"] = position_ids
        outputs = session.run(None, inputs)
        for layer in range(num_layers):
            kv_cache[f'past_key_values.{layer}.decoder.key'] = outputs[2 * layer + 1]
            kv_cache[f'past_key_values.{layer}.decoder.value'] = outputs[2 * layer + 2]
        inputs.update(kv_cache)
        start_len += 1
        start_len = min(start_len, max_sequence_length)
        logits = outputs[0]
        new_token = np.argmax(logits[0, 0, :])
        tokens.append(new_token)
        print(tokenizer.decode(tokens, skip_special_tokens=False))
        if new_token == tokenizer.eos_token_id:
            break
print('--------------------Done.-------------------------')
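# Example invocation, assuming the exported ONNX file exists at the model_path
# pattern above (hypothetical local setup):
#   python QA_ort.py --model microsoft/Phi-3-mini-4k-instruct --quantize 4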