Merge Search & State into new "Generator" type #39
Changes from all commits: 4b00707, 0080d99, c1e5947, 151781d, d5e3477, aead6c9, 97b54e4
```diff
@@ -10,48 +10,45 @@
 # model=og.Model("../../test_models/llama2-7b-fp32-cpu", device_type)
 #model=og.Llama_Model("../../test_models/llama2-7b-fp16-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx", device_type)
 #model=og.Llama_Model("../../test_models/llama2-7b-int4-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_int4.onnx", device_type)
-model=og.Model("../../test_models/llama2-7b-chat-int4-gpu", device_type)
+model=og.Model("../test_models/llama2-7b-chat-int4-gpu", device_type)
 print("Model loaded")
-# tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
-tokenizer=model.CreateTokenizer()
+tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
+# tokenizer=model.CreateTokenizer()
 print("Tokenizer created")

 # Keep asking for input prompts in a loop
 while True:
     text = input("Input:")
-    # input_tokens = tokenizer.encode(text, return_tensors='np')
-    input_tokens = tokenizer.encode(text)
+    input_tokens = tokenizer.encode(text, return_tensors='np')
+    # input_tokens = tokenizer.encode(text)

     params=og.SearchParams(model)
     params.max_length = 128
     params.input_ids = input_tokens

-    search=params.CreateSearch()
-    state=model.CreateState(search.GetSequenceLengths(), params)
+    generator=og.Generator(model, params)

     print("Output:")

     print(text, end='', flush=True)
-    while not search.IsDone():
-        search.SetLogits(state.Run(search.GetSequenceLength(), search.GetNextTokens()))
+    while not generator.IsDone():
+        generator.ComputeLogits()

         # search.Apply_MinLength(1)
         # search.Apply_RepetitionPenalty(1.0)
```
Comment on lines 37 to 38: could you please update this also
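The commented-out `Apply_RepetitionPenalty(1.0)` call refers to a standard decoding trick: logits of tokens that already appear in the generated sequence are scaled down so the model is less likely to repeat itself. A minimal sketch of the common CTRL-style rule, assuming that behavior; the function name and details here are illustrative, not the library's implementation:

```python
# Hypothetical sketch of a repetition penalty (cf. Apply_RepetitionPenalty).
# Tokens already generated get their logits scaled so re-sampling them is
# less likely; a penalty of 1.0 is a no-op.
def apply_repetition_penalty(logits, generated_ids, penalty):
    out = list(logits)
    for i in set(generated_ids):
        # Positive logits shrink; negative logits grow more negative.
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out
```

With `penalty=2.0`, a previously generated token with logit 2.0 drops to 1.0, while one with logit -2.0 drops to -4.0; unseen tokens are untouched.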
```diff

-        search.SampleTopP(0.7, 0.6)
+        generator.AppendNextToken_TopP(0.7, 0.6)
```
Review discussion on `AppendNextToken_TopP`:
- Reviewer: "Append doesn't look like a good name to me. The generator actually generates the next token. How about `GenerateNextToken`?"
- Reviewer: "And can we make TopP a parameter instead of part of the name?"
- Author: "I named it that since you first compute the logits, then append the next token based on the logits."
- Reviewer: "It is a process of generating the next token. Add or Append doesn't reflect the action, I think."
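The call under discussion performs top-p (nucleus) sampling: keep only the smallest set of highest-probability tokens whose cumulative probability reaches p, then sample from that set. A rough pure-Python sketch, assuming the first argument is the cumulative-probability cutoff and the second a temperature (the library's actual argument order and semantics may differ):

```python
# Illustrative top-p (nucleus) sampler; not the library's implementation.
import math
import random

def sample_top_p(logits, p, temperature, rng=random):
    # Temperature-scaled softmax over the logits.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest prefix of tokens, sorted by descending probability,
    # whose cumulative probability reaches p (the "nucleus").
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, cumulative = [], 0.0
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break

    # Renormalize over the nucleus and draw one token id from it.
    mass = sum(probs[i] for i in nucleus)
    r = rng.random() * mass
    for i in nucleus:
        r -= probs[i]
        if r <= 0:
            return i
    return nucleus[-1]
```

With a strongly peaked distribution the nucleus collapses to a single token, so sampling becomes deterministic; with a flat distribution and p close to 1 it behaves like plain sampling.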
```diff

-        print(tokenizer.decode([search.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
+        '''
+        # print(tokenizer.decode([generator.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
         # Print each token as we compute it, we have to do some work to get newlines & spaces to appear properly:
-        word=tokenizer.convert_ids_to_tokens([search.GetNextTokens().GetArray()[0]])[0]
+        word=tokenizer.convert_ids_to_tokens([generator.GetNextTokens().GetArray()[0]])[0]
```
yufenglee marked this conversation as resolved.
```diff

         if word=='<0x0A>':
             word = '\n'
         if word.startswith('▁'):
             word = ' ' + word[1:]
         print(word, end='', flush=True)
+        '''

     # Print sequence all at once vs as it's decoded:
-    print(tokenizer.decode(search.GetSequence(0).GetArray()))
+    print(tokenizer.decode(generator.GetSequence(0).GetArray()))
     print()
     print()
```
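The per-token printing block in the diff relies on SentencePiece conventions: a leading space is encoded as the `▁` marker on the piece, and a newline arrives as the byte-fallback token `<0x0A>`, so streaming output has to translate both before printing. A minimal sketch of that translation step, extracted from the logic above:

```python
# Minimal sketch of the piece-to-text step used when streaming tokens.
# Mirrors the diff's handling of SentencePiece pieces: '<0x0A>' is the
# byte-fallback newline token, and a leading '▁' marks a word boundary.
def piece_to_text(piece):
    if piece == '<0x0A>':
        return '\n'
    if piece.startswith('▁'):
        return ' ' + piece[1:]
    return piece
```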
Comment: could you please also change SearchParams to GeneratorParams?