Runtime command:
/code/tensorrt_llm/examples/llama# python ../run.py --max_output_len=20 --tokenizer_dir /code/tensorrt_llm/Llama-2-7b-hf/ --engine_dir /tmp/ws/models/compiled/engine/fp16/1-gpu --input_text "In Bash, how do I list all text files?" --use_py_session --output_logits_npy logits.npy --temperature 1.2 --repetition_penalty 1.4
Issue:
temperature not reflecting in output logits
We assumed that when we export output_logits, the logits come after all the sampling and penalty layers have been applied. But we notice that, with temperature alone in the above command, there is no change in the output_logits (with vs. without temperature).
Yet when I add logs in penaltyKernel.cu, I clearly see the expected difference (i.e., multiplication by 1/temp) between inLogitsPtr[index] and outLogitsPtr[index].
inconsistency with repetition_penalty
When we pass repetition_penalty, the final output_logits do differ, but the difference does not look correct. When I log the intermediate values in penaltyKernel.cu, however, they are as expected, and I can see that both repetition_penalty and temperature are properly applied.
For the given runtime command, I attached the full changes and logs to this ticket.
If you refer to step = 0,
--------------------numOccurences > 0 ------ at 1 input logit: -7.039062
--------------------numOccurences > 0 ------ at 1 before rp logit: -5.865880
--------------------numOccurences > 0 ------ at 1 after rp logit: -8.212233
The math is right here: -7.039062 * (1/1.2) * 1.4 ≈ -8.212233
But when I look at the output logits, I still see the raw value -7.0390625.
Note: I am using py_session and logged the step info in decode_regular to see which step each log belongs to.
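For reference, here is a minimal Python sketch of the math I believe the kernel applies; the order (temperature first, then a CTRL-style repetition penalty) is inferred from the logged values above, and the function name is mine:

def apply_penalties(logit: float, temperature: float, repetition_penalty: float) -> float:
    # Temperature scaling: every logit is multiplied by 1/temp.
    logit = logit / temperature
    # Repetition penalty on previously generated tokens: push the logit away
    # from selection, so negative logits are multiplied and positive divided.
    return logit * repetition_penalty if logit < 0 else logit / repetition_penalty

print(apply_penalties(-7.039062, 1.2, 1.4))  # ~-8.21224, matching the kernel log up to float precision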
The difference in output_logits is as below:
(temp_rp_change_gen is with --temperature 1.2 --repetition_penalty 1.4)
import numpy as np

# temp_rp_change_gen and original_gen are the arrays dumped via --output_logits_npy
for i in range(20):
    different_indices = np.where(temp_rp_change_gen[0, 0, i, :] != original_gen[0, 0, i, :])[0]
    print(f"at step {i} indices of count {different_indices.shape[0]} where the arrays differ: {different_indices}")
Output (vocab_size is 32000):
at step 0 indices of count 0 where the arrays differ: []
at step 1 indices of count 0 where the arrays differ: []
at step 2 indices of count 31998 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 3 indices of count 31996 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 4 indices of count 31997 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 5 indices of count 31994 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 6 indices of count 31992 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 7 indices of count 32000 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 8 indices of count 31998 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 9 indices of count 31991 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 10 indices of count 31999 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 11 indices of count 31991 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 12 indices of count 31999 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 13 indices of count 31995 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 14 indices of count 31986 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 15 indices of count 32000 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 16 indices of count 31996 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 17 indices of count 31997 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 18 indices of count 31995 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
at step 19 indices of count 31991 where the arrays differ: [ 0 1 2 ... 31997 31998 31999]
I am not sure why there is no difference in the logits at steps 0 and 1, while after that almost all vocab indices differ.
Expected behavior
I would like to see the final logits match what is logged in the kernel.
Actual behavior
Unable to relate the final output logits to the computation logged in penaltyKernel.cu.
Hi @buddhapuneeth, thank you for taking the time to report the issue. You are right: the output logits are not modified by any of the penalties or temperature. This is intentional behavior -- in context/generation logits we output only the raw logits before sampling. It is possible to get the log-prob value after temperature/penalties are applied, but only for the selected token. Check out returnLogProbs in the executor OutputConfig and the way to expose it from run.py.
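For reference, a hedged sketch of requesting per-token log-probs through the executor bindings; the exact module path and keyword names vary across TensorRT-LLM versions, so treat this as an assumption to check against your release:

from tensorrt_llm.bindings import executor as trtllm

# Ask the executor to return the log-prob of each selected token
# (computed after temperature/penalties), rather than the raw logits.
output_config = trtllm.OutputConfig(return_log_probs=True)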
@nekorobov but if you look at the initial issue, I see different logits when I modify the repetition penalty; only with temperature do I see raw logits. The changes to the logits happen in the same kernel, so what is the reason for the different behavior?
Yes: gen_logits with and without RP changed because the output (generated) tokens differ between the two cases from the second token onward, so from step 2 we see completely different logits.
With temp = 0.1 or 1.2, the final output stays the same, and hence the logits remain the same as with no temperature. It is not clear to me why the output does not change even with such extreme temperature values (like 0.1).
Actually, I got my answer for temperature as well: since this is greedy decoding, the picked token does not depend on temperature, because scaling all logits by a positive constant does not change the argmax. That is why we see the same output every time regardless of the temperature value.
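A quick illustration of that invariance (the logit values here are made up):

import numpy as np

logits = np.array([-7.04, -5.87, -8.21, -3.15])
for temp in (0.1, 1.2, 2.0):
    # Dividing by a positive constant preserves the ordering of the logits,
    # so greedy decoding (argmax) selects the same token at any temperature.
    assert np.argmax(logits / temp) == np.argmax(logits)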
Conclusions:
We see only raw logits when we expose output_logits.
In greedy decoding we see different results with and without RP, so the logits differ because the generated sequence is different.
In greedy decoding we see the same results with and without temperature, so the logits remain the same because the generated sequence is the same.