forked from turboderp/exllama
Commit
Showing 20 changed files with 973 additions and 198 deletions.
@@ -0,0 +1,2 @@
# ignore __pycache__ folder
__pycache__/
@@ -0,0 +1,329 @@
import cuda_ext
from model import ExLlama, ExLlamaCache
from tokenizer import ExLlamaTokenizer
from lora import ExLlamaLora
import torch
import torch.nn.functional as F

MAX_CACHED_STRINGS = 100

class ExLlamaAltGenerator:

    class Settings:

        temperature = 0.95
        top_k = 40                              # consider the most probable top_k samples, 0 to disable top_k sampling
        top_p = 0.65                            # consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling
        min_p = 0.0                             # Do not consider tokens with probability less than this
        typical = 0.0                           # Locally typical sampling threshold, 0.0 to disable typical sampling

        token_repetition_penalty_max = 1.15     # Repetition penalty for most recent tokens
        token_repetition_penalty_sustain = -1   # No. of most recent tokens to apply the penalty to, -1 to apply to whole context
        token_repetition_penalty_decay = 0      # Gradually decrease penalty over this many tokens

        disallowed_tokens: list[int] = None     # List of tokens to inhibit, e.g. tokenizer.eos_token_id
        lora: ExLlamaLora = None                # LoRA to apply when generating

    # Internal state

    model: ExLlama
    cache: ExLlamaCache
    tokenizer: ExLlamaTokenizer
    tokenizer_cache = {}

    settings: Settings
    stop_strings: list = []
    stop_tokens: list = []
    held_text: str = ""
    max_stop_tokens: int = 2
    sequence_ids: torch.Tensor = None
    sequence_str: str = None
    remaining_tokens: int = 0


    def __init__(self, model: ExLlama, tokenizer: ExLlamaTokenizer, cache: ExLlamaCache):

        self.model = model
        self.tokenizer = tokenizer
        self.cache = cache
        self.settings = ExLlamaAltGenerator.Settings()


    def cached_tokenize(self, text: str, encode_special_characters = False):

        if text in self.tokenizer_cache:
            return self.tokenizer_cache[text]

        while len(self.tokenizer_cache) >= MAX_CACHED_STRINGS:
            del self.tokenizer_cache[next(iter(self.tokenizer_cache))]  # Always removes oldest entry, as of Python 3.7

        new_enc = self.tokenizer.encode(text, encode_special_characters = encode_special_characters)
        self.tokenizer_cache[text] = new_enc
        return new_enc


    def get_num_tokens(self, text: str, encode_special_characters = False):

        return self.cached_tokenize(text, encode_special_characters = encode_special_characters).shape[-1]

    # Begin generating
    #
    # prompt: The input prompt. Will be tokenized and then truncated to the model's max sequence length minus max_new_tokens
    # stop_conditions: List of strings or integer token IDs that will end the sequence
    # gen_settings: ExLlamaAltGenerator.Settings instance
    # encode_special_characters: Set to True to tokenize special tokens such as "</s>"

    def begin_stream(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        assert isinstance(prompt, str), "ExLlamaAltGenerator does not support batched generation"

        # Tokenize prompt and limit length to allow prompt and (max) new tokens within max sequence length

        max_input_tokens = self.model.config.max_seq_len - max_new_tokens
        self.remaining_tokens = max_new_tokens

        input_ids = self.cached_tokenize(prompt, encode_special_characters)
        applied_input_ids = input_ids[:, -max_input_tokens:]
        self.sequence_str = self.tokenizer.decode(applied_input_ids)[0] if applied_input_ids.shape[-1] < input_ids.shape[-1] else prompt

        # Settings

        self.stop_strings = []
        self.stop_tokens = []
        for t in stop_conditions:
            if isinstance(t, int): self.stop_tokens += [t]
            elif isinstance(t, str): self.stop_strings += [t]
            else: raise ValueError("Unsupported type in stop_conditions")

        self.held_text = ""

        self.max_stop_tokens = 2
        for ss in self.stop_strings:
            self.max_stop_tokens = max(self.max_stop_tokens, self.get_num_tokens(ss) + 2)

        self.settings = gen_settings

        # Start generation

        self.gen_begin_reuse(applied_input_ids, gen_settings)

    # Get the next chunk of text in the stream
    #
    # Returns stream_chunk: str, EOS: bool

    def stream(self):

        # Check total response length

        if self.remaining_tokens == 0:
            self.sequence_str += self.held_text
            return self.held_text, True

        self.remaining_tokens -= 1

        # Decode the current tail end of the sequence

        old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.max_stop_tokens:])[0]

        # Generate a single token and append to the sequence

        next_token = self.gen_single_token(self.settings)

        # End immediately if it was a stop token

        if next_token in self.stop_tokens:
            self.sequence_str += self.held_text
            return self.held_text, True

        # Decode the tail end of the sequence with the added token to get (actual) characters added

        new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.max_stop_tokens + 1):])[0]
        self.held_text += new_tail[len(old_tail):]

        # Hold text as long as it contains part of a stop string

        partial_ss = False
        for ss in self.stop_strings:

            # Check if held_text fully contains stop string

            position = self.held_text.find(ss)
            if position != -1:
                self.sequence_str += self.held_text[:position]
                return self.held_text[:position], True

            # Check for overlap between end of held_text and start of stop string

            overlap = 0
            for j in range(1, min(len(self.held_text), len(ss)) + 1):
                if self.held_text[-j:] == ss[:j]: overlap = j
            if overlap > 0: partial_ss = True

        # If holding text because of a partial stop condition, return nothing but also EOS = False

        if partial_ss:
            return "", False

        # No stop condition, so return whatever has been generated

        stream_text = self.held_text
        self.held_text = ""
        self.sequence_str += stream_text
        return stream_text, False


    # Generate and return a full prompt completion. See begin_stream() above

    def generate(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        self.begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings, encode_special_characters)
        response = ""
        while True:
            chunk, eos = self.stream()
            response += chunk
            if eos: break

        return response


    # Begin generation

    def gen_begin(self, in_tokens, gen_settings):

        self.sequence_ids = in_tokens.clone()
        self.cache.current_seq_len = 0
        self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, lora = gen_settings.lora)


    def gen_begin_reuse(self, in_tokens, gen_settings):

        if self.sequence_ids is None or self.cache.current_seq_len == 0:
            self.gen_begin(in_tokens, gen_settings)
            return

        reuse = 0
        while reuse < self.sequence_ids.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence_ids[0, reuse] == in_tokens[0, reuse]:
            reuse += 1

        if reuse < 2:
            self.gen_begin(in_tokens, gen_settings)
            return

        self.cache.current_seq_len = reuse - 1
        self.sequence_ids = in_tokens[:, :reuse]

        if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], gen_settings)


    def gen_feed_tokens(self, in_tokens, gen_settings):

        if self.sequence_ids is None:
            self.gen_begin(in_tokens, gen_settings)
            return

        start = self.cache.current_seq_len
        self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)

        self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, lora = gen_settings.lora)


    # Generate one token in current sequence

    def gen_single_token(self, gen_settings):

        # Simple sampling case:

        logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, lora = gen_settings.lora)
        token, _ = self.sample(logits, gen_settings)
        self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
        return token

    def sample(self, logits, gen_settings):

        cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence_ids,
                                                self.settings.token_repetition_penalty_max,
                                                self.settings.token_repetition_penalty_sustain,
                                                self.settings.token_repetition_penalty_decay,
                                                logits)

        logits[:, :, self.tokenizer.bos_token_id] = -10000.0

        if logits.dim() == 3: logits = logits[0, -1, :]
        elif logits.dim() == 2: logits = logits[-1, :]
        else: raise ValueError("Bad logits dimension")

        # Disallow tokens

        if gen_settings.disallowed_tokens is not None:
            logits[gen_settings.disallowed_tokens] = float("-inf")

        # Base probabilities

        logits /= gen_settings.temperature
        logits += 1e-8
        probs = torch.softmax(logits, dim = -1)

        # Top K

        if gen_settings.top_k == 0:
            top_probs, top_indices = torch.sort(probs, descending = True)
        else:
            top_probs, top_indices = torch.topk(probs, gen_settings.top_k)
            top_probs = F.normalize(top_probs, p = 1, dim = -1)

        # Top P

        if gen_settings.top_p > 0.0:

            num_top_p_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_top_p_probs += 1
                if num_top_p_probs == top_probs.shape[-1]: break
                if top_probs[num_top_p_probs].item() < gen_settings.min_p: break
                cum_prob += top_probs[num_top_p_probs].item()
                if cum_prob > gen_settings.top_p: break

            top_probs = top_probs[:num_top_p_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_top_p_probs]

        # Locally typical sampling

        if gen_settings.typical > 0.0:

            epsilon = 1e-10
            log_probs = (top_probs + epsilon).log()
            neg_entropy = (top_probs * log_probs).sum()
            entropy_dev = (neg_entropy - log_probs).abs()
            _, entropy_dev_order = torch.sort(entropy_dev)

            top_probs = top_probs.gather(-1, entropy_dev_order)
            top_indices = top_indices.gather(-1, entropy_dev_order)

            num_typical_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_typical_probs += 1
                if num_typical_probs == top_probs.shape[-1]: break
                cum_prob += top_probs[num_typical_probs].item()
                if cum_prob > gen_settings.typical: break

            top_probs = top_probs[:num_typical_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_typical_probs]

        # Multinomial sampling from top_probs, kept in same order as top_indices

        sampled_ind = torch.multinomial(top_probs, 1)
        sampled_tokens = top_indices[sampled_ind]
        sampled_probs = top_probs[sampled_ind]  # Return probs before second norm

        # if sampled_tokens.shape[0] > 1:
        #     sampled_tokens, ind = sampled_tokens.sort()
        #     sampled_probs = sampled_probs[ind]

        return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)
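
As a quick illustration of how the new generator is meant to be driven, here is a minimal usage sketch (not part of the commit). It assumes an ExLlama model, ExLlamaTokenizer and ExLlamaCache have already been loaded as model, tokenizer and cache, and that the file above is importable as alt_generator (module name assumed); only the API defined above is used.

# Minimal usage sketch; model, tokenizer and cache are assumed to be loaded elsewhere
from alt_generator import ExLlamaAltGenerator   # module name assumed

generator = ExLlamaAltGenerator(model, tokenizer, cache)

# Sampling settings, mirroring the Settings fields defined above
settings = ExLlamaAltGenerator.Settings()
settings.temperature = 0.8
settings.top_k = 40
settings.top_p = 0.9

# Stop conditions may mix strings and integer token IDs
stop_conditions = ["\n\n", tokenizer.eos_token_id]

# Streaming generation: stream() returns (chunk, eos) until a stop condition or the token budget is hit
generator.begin_stream("Once upon a time,", stop_conditions, max_new_tokens = 200, gen_settings = settings)
while True:
    chunk, eos = generator.stream()
    print(chunk, end = "", flush = True)
    if eos: break

# One-shot equivalent:
# response = generator.generate("Once upon a time,", stop_conditions, 200, settings)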