Add a simple token_healing
黄宇扬 committed Jul 25, 2024
1 parent 7490802 commit 637e50d
Showing 1 changed file with 85 additions and 2 deletions.
87 changes: 85 additions & 2 deletions tools/fastllm_pytools/llm.py
@@ -1,10 +1,11 @@
-import ctypes;
-import os;
+import ctypes
+import math
+import os
import threading
import asyncio
import copy
import json
import math
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;

import platform
@@ -90,6 +91,17 @@
fastllm_lib.get_max_input_len_llm_model.argtypes = [ctypes.c_int]
fastllm_lib.get_max_input_len_llm_model.restype = ctypes.c_int

def softmax(a):
    # In-place softmax over a list of logits, numerically stabilized by subtracting the max;
    # e.g. softmax([0.0, math.log(3.0)]) turns the list into [0.25, 0.75].
    max_value = max(a)
    total = 0.0
    for i in range(len(a)):
        a[i] = math.exp(a[i] - max_value)
        total += a[i]
    for i in range(len(a)):
        a[i] /= total

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads);

@@ -618,6 +630,77 @@ def get_input_token_len(self, conversation: List[Dict[str, str]], add_generation
    prompt = self.apply_chat_template(conversation)
    return len(self.encode(prompt))

def token_healing(self,
                  prompt: str):
    if (self.hf_tokenizer is None):
        raise Exception("Error: token_healing needs the transformers tokenizer (hf_tokenizer)")
    tokenizer = self.hf_tokenizer
    vocab_size = len(tokenizer.get_vocab())
    if (not hasattr(self, "vocab_cache")):
        # Cache the decoded text of every vocab id so the suffix checks below are cheap.
        self.vocab_cache = []
        self.max_vocab_len = 0
        for i in range(vocab_size):
            self.vocab_cache.append(tokenizer.decode([i]))
            self.max_vocab_len = max(self.max_vocab_len, len(self.vocab_cache[-1]))
    cur_set = set(range(vocab_size))
    logits = list(range(vocab_size))
    array = (ctypes.c_float * (vocab_size * 4))(*logits)

    # Find the longest token-id prefix shared by the tokenizations of the prompt and of
    # every truncated prompt that can still be healed by some vocabulary token.
    common_token = tokenizer.encode(prompt)
    for i in range(1, len(prompt) - 1):
        cur_prompt = prompt[:-i]
        right_prompt = prompt[len(cur_prompt):]
        if right_prompt != "":
            for token_id in list(cur_set):
                if self.vocab_cache[token_id].find(right_prompt) == -1:
                    cur_set.remove(token_id)
        if (len(cur_set) == 0):
            break
        cur_token = tokenizer.encode(cur_prompt)
        for l in range(len(common_token)):
            if (l >= len(cur_token) or cur_token[l] != common_token[l]):
                common_token = common_token[:l]
                break

    last_ret = ""
    max_prob = -1e5
    cur_set = set(range(vocab_size))
    for i in range(len(prompt) - 1):
        # Drop the last i characters and try to heal them with a single vocabulary token.
        cur_prompt = prompt[:-i] if i > 0 else prompt
        right_prompt = prompt[len(cur_prompt):]
        if right_prompt != "":
            for token_id in list(cur_set):
                if self.vocab_cache[token_id].find(right_prompt) == -1:
                    cur_set.remove(token_id)
        if (len(cur_set) == 0):
            break

        cur_prob = 0.0
        real_input = tokenizer.encode(cur_prompt)
        for idx in range(len(common_token), len(real_input) + 1):
            input = real_input[ : idx]
            stop_token_len, stop_token_list = self.stop_token_ctypes(None)
            handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                           1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list)
            ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array)
            out = list(array)[:vocab_size]
            softmax(out)
            if (idx < len(real_input)):
                # Accumulate the probability of the tokens forced by the truncated prompt.
                cur_prob += math.log(out[real_input[idx]])

        # `out` now holds the distribution after the full truncated prompt; pick the most
        # likely candidate token that starts with the dropped characters.
        max_id = -1
        for token_id in cur_set:
            if max_id == -1 or out[token_id] > out[max_id]:
                if (self.vocab_cache[token_id].startswith(right_prompt)):
                    max_id = token_id
        if max_id != -1:
            input.append(max_id)
            cur_prob += math.log(out[max_id])
        if cur_prob > max_prob:
            max_prob = cur_prob
            last_ret = tokenizer.decode(input)
        #print(math.exp(cur_prob))
        #print(tokenizer.decode(input))
    return last_ret

def response_logits(self,
                    query: str,
                    history: List[Tuple[str, str]] = None,

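For context, a minimal usage sketch of the new method (not part of the diff above): token_healing re-encodes the tail of a prompt and returns the prompt with its trailing characters replaced by the most probable vocabulary token that starts with them. The import path, the model file name, and the way hf_tokenizer gets attached are assumptions here, not something this commit specifies.

from fastllm_pytools import llm

# Assumption: the model is loaded so that model.hf_tokenizer is set
# (token_healing raises if no transformers tokenizer is attached).
model = llm.model("qwen2-7b-int4.flm")  # hypothetical model file

healed = model.token_healing("The capital of France is Par")
print(healed)  # e.g. ends with a full token such as "Paris" instead of the partial "Par"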