From 637e50d804187cf2d74e43e190693927212d30b7 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Thu, 25 Jul 2024 14:09:37 +0800
Subject: [PATCH] Add a simple token_healing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/fastllm_pytools/llm.py | 87 +++++++++++++++++++++++++++++++++++-
 1 file changed, 85 insertions(+), 2 deletions(-)

diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 851dd323..740dff9f 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -1,10 +1,10 @@
-import ctypes;
+import ctypes
 import math
-import os;
+import os
 import threading
 import asyncio
 import copy
 import json
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
 import platform
 
@@ -90,6 +90,17 @@
 fastllm_lib.get_max_input_len_llm_model.argtypes = [ctypes.c_int]
 fastllm_lib.get_max_input_len_llm_model.restype = ctypes.c_int
 
+def softmax(a):  # in-place softmax over a python list of logits
+    max_value = a[0]
+    for i in a:
+        max_value = max(max_value, i)
+    sum = 0.0
+    for i in range(len(a)):
+        a[i] = math.exp(a[i] - max_value)
+        sum += a[i]
+    for i in range(len(a)):
+        a[i] /= sum
+
 def set_cpu_threads(threads: int):
     fastllm_lib.set_cpu_threads(threads);
 
@@ -618,6 +629,78 @@ def get_input_token_len(self, conversation: List[Dict[str, str]], add_generation
         prompt = self.apply_chat_template(conversation)
         return len(self.encode(prompt))
 
+    def token_healing(self,
+                      prompt: str):  # returns the prompt with its trailing partial token healed
+        if (self.hf_tokenizer is None):
+            raise Exception("Error: token_healing requires a transformers tokenizer (hf_tokenizer)")
+        tokenizer = self.hf_tokenizer
+        vocab_size = len(tokenizer.get_vocab())
+        if (not hasattr(self, "vocab_cache")):
+            self.vocab_cache = []
+            self.max_vocab_len = 0
+            for i in range(vocab_size):
+                self.vocab_cache.append(tokenizer.decode([i]))
+                self.max_vocab_len = max(self.max_vocab_len, len(self.vocab_cache[-1]))
+        cur_set = set(range(vocab_size))
+        logits = list(range(vocab_size))
+        array = (ctypes.c_float * (vocab_size * 4))(*logits)  # buffer filled with logits by fetch_response_logits_llm_model
+
+        common_token = tokenizer.encode(prompt)  # trimmed below to the token prefix shared with every truncation
+        for i in range(1, len(prompt) - 1):
+            cur_prompt = prompt[:-i]
+            right_prompt = prompt[len(cur_prompt):]
+            if right_prompt != "":
+                for id in list(cur_set):
+                    if self.vocab_cache[id].find(right_prompt) == -1:
+                        cur_set.remove(id)
+                if (len(cur_set) == 0):
+                    break
+            cur_token = tokenizer.encode(cur_prompt)
+            for l in range(len(common_token)):
+                if (l >= len(cur_token) or cur_token[l] != common_token[l]):
+                    common_token = common_token[:l]
+                    break
+
+        last_ret = ""
+        max_prob = -1e5  # best log-probability found so far
+        cur_set = set(range(vocab_size))
+        for i in range(len(prompt) - 1):
+            cur_prompt = prompt[:-i] if i > 0 else prompt
+            right_prompt = prompt[len(cur_prompt):]
+            if right_prompt != "":
+                for id in list(cur_set):
+                    if self.vocab_cache[id].find(right_prompt) == -1:
+                        cur_set.remove(id)
+                if (len(cur_set) == 0):
+                    break
+
+            cur_prob = 0.0
+            real_input = tokenizer.encode(cur_prompt)
+            for idx in range(len(common_token), len(real_input) + 1):
+                input = real_input[ : idx]
+                stop_token_len, stop_token_list = self.stop_token_ctypes(None)
+                handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
+                                                               1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list)
+                ret = fastllm_lib.fetch_response_logits_llm_model(self.model, handle, array)
+                out = list(array)[:vocab_size]
+                softmax(out)
+                if (idx < len(real_input)):
+                    cur_prob += math.log(out[real_input[idx]])
+            max_id = -1  # id of the most probable vocabulary token that starts with the removed suffix
+            for i in cur_set:
+                if max_id == -1 or out[i] > out[max_id]:
+                    if (self.vocab_cache[i].startswith(right_prompt)):
+                        max_id = i
+            if max_id != -1:
+                input.append(max_id)
+                cur_prob += math.log(out[max_id])
+            if cur_prob > max_prob:
+                max_prob = cur_prob
+                last_ret = tokenizer.decode(input)
+            #print(math.exp(cur_prob))
+            #print(tokenizer.decode(input))
+        return last_ret
+
     def response_logits(self,
                         query: str,
                         history: List[Tuple[str, str]] = None,
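
A minimal usage sketch of the new method, not part of the patch itself. It assumes a fastllm model object whose hf_tokenizer attribute already holds a transformers tokenizer (token_healing raises otherwise); the model path, tokenizer name, and prompt below are placeholders, and assigning hf_tokenizer directly is done here purely for illustration.

    from fastllm_pytools import llm
    from transformers import AutoTokenizer

    model = llm.model("model.flm")  # placeholder path to an exported fastllm model
    model.hf_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # placeholder tokenizer

    # The prompt ends mid-word, so its last token is likely a truncated fragment.
    # token_healing backs the prompt up to a shared token prefix, scores candidate
    # boundary tokens with the model's logits, and returns the prompt re-decoded
    # with the most probable full token that starts with the removed suffix.
    prompt = "The capital of France is Pari"
    healed = model.token_healing(prompt)
    print(healed)  # continue generation from the healed prompt with the usual response call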