
Inconsistency with penaltyKernels.cu #2486

Closed
2 of 4 tasks
buddhapuneeth opened this issue Nov 22, 2024 · 4 comments
Assignees
Labels
bug Something isn't working triaged Issue has been triaged by maintainers

Comments

@buddhapuneeth

System Info

HW

  • EC2: g6e.12x
  • GPU: L40S
  • CUDA: 550.127.05
  • Branch: commit 201135e (HEAD, tag: v0.13.0)

Data

  • Model: Llama-2-7b-hf

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Setup:

Compilation commands:

python convert_checkpoint.py --model_dir /code/tensorrt_llm/Llama-2-7b-hf/ --output_dir /tmp/ws/models/compiled/chpt --dtype float16

trtllm-build --checkpoint_dir   /tmp/ws/models/compiled/chpt --output_dir /tmp/ws/models/compiled/engine/fp16/1-gpu --gemm_plugin float16 --gather_all_token_logits

Runtime command:

/code/tensorrt_llm/examples/llama# python ../run.py --max_output_len=20 --tokenizer_dir /code/tensorrt_llm/Llama-2-7b-hf/ --engine_dir /tmp/ws/models/compiled/engine/fp16/1-gpu --input_text "In Bash, how do I list all text files?" --use_py_session --output_logits_npy logits.npy --temperature 1.2 --repetition_penalty 1.4

Issue:

  1. Temperature not reflected in output logits
    We assumed that when we export output_logits, the logits are post all sampling/penalty layers. But we notice that, with temperature alone in the above command, there is no change (with vs. without temperature) in the output_logits.

But when I add logs in penaltyKernels.cu, I clearly see the expected difference (i.e. * 1/temp) between inLogitsPtr[index] and outLogitsPtr[index].

  2. Inconsistency with repetition_penalty
    When we pass repetition_penalty, the final output_logits do differ, but the difference does not look correct. When I log the intermediate values in penaltyKernels.cu they are as expected, and I can see both repetition_penalty and temperature properly applied.

For the given runtime command, I attached the full changes and logs to this ticket.
Looking at step = 0:

--------------------numOccurences > 0 ------ at 1 input logit: -7.039062
--------------------numOccurences > 0 ------ at 1 before rp logit: -5.865880
--------------------numOccurences > 0 ------ at 1 after rp logit: -8.212233

The math is right here: -7.039062 * 1/1.2 * 1.4 = -8.212233
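As a sanity check, the kernel's order of operations (temperature scaling first, then the repetition penalty's multiply-if-negative/divide-if-positive rule) can be reproduced in a few lines of NumPy; the numbers are taken from the log lines above:

```python
import numpy as np

# Reproduce the penalty math from the kernel logs in float32, as the kernel does.
temperature = np.float32(1.2)
repetition_penalty = np.float32(1.4)

logit = np.float32(-7.039062)  # raw input logit at index 1
logit = logit / temperature    # temperature scaling -> roughly -5.865880

# Repetition penalty: negative logits are multiplied, positive ones divided,
# so repeated tokens always become less likely.
logit = logit * repetition_penalty if logit < 0 else logit / repetition_penalty

print(logit)  # close to the logged -8.212233, up to float32 rounding
```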

But when I inspect the exported output logits, I still see -7.0390625.

Note: I am using the Python session (--use_py_session) so I could log the step number in decode_regular and tell which step each log line belongs to.

The difference in output_logits is as below:
(temp_rp_change_gen is with --temperature 1.2 --repetition_penalty 1.4)

import numpy as np

# temp_rp_change_gen / original_gen: the generation logits loaded from the two
# --output_logits_npy files (with and without penalties)
for i in range(20):
    different_indices = np.where(temp_rp_change_gen[0, 0, i, :] != original_gen[0, 0, i, :])[0]
    print(f" at step {i} indices of count {different_indices.shape[0]} where the arrays differ: {different_indices}")

Output (vocab_size is 32000):

 at step 0 indices of count 0 where the arrays differ: []
 at step 1 indices of count 0 where the arrays differ: []
 at step 2 indices of count 31998 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 3 indices of count 31996 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 4 indices of count 31997 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 5 indices of count 31994 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 6 indices of count 31992 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 7 indices of count 32000 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 8 indices of count 31998 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 9 indices of count 31991 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 10 indices of count 31999 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 11 indices of count 31991 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 12 indices of count 31999 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 13 indices of count 31995 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 14 indices of count 31986 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 15 indices of count 32000 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 16 indices of count 31996 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 17 indices of count 31997 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 18 indices of count 31995 where the arrays differ: [    0     1     2 ... 31997 31998 31999]
 at step 19 indices of count 31991 where the arrays differ: [    0     1     2 ... 31997 31998 31999]

I am not sure why there is no difference in the logits at steps 0 and 1, while after that almost all vocab indices differ.

Expected behavior

I expect the exported final logits to match the post-penalty values logged in the kernel.

Actual behavior

The exported final logits cannot be reconciled with the penalty computation logged in the kernel.

additional notes

Git diffs:


diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.cu b/cpp/tensorrt_llm/kernels/penaltyKernels.cu
index 84b1a66..579db5b 100644
--- a/cpp/tensorrt_llm/kernels/penaltyKernels.cu
+++ b/cpp/tensorrt_llm/kernels/penaltyKernels.cu
@@ -171,11 +171,14 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
                 SizeType32 numOccurences = penaltyWorkspace[index];
                 if (numOccurences > 0)
                 {
+                    printf("--------------------numOccurences > 0 ------ at %d input logit: %f \n", index, static_cast<float>(inLogitsPtr[index]));
+                    printf("--------------------numOccurences > 0 ------ at %d before rp logit: %f \n", index, logit);
                     // Repetition
                     if (repetitionPenalties != nullptr)
                     {
                         logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty;
                     }
+                    printf("--------------------numOccurences > 0 ------ at %d after rp logit: %f \n", index, logit);
                     // Presence
                     if (presencePenalties != nullptr)
                     {
@@ -197,6 +200,11 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
             {
                 logit = static_cast<float>(MASK_VAL);
             }
+            // if (index < 5)
+            // {
+            //     printf("in logit: %f\n",  static_cast<float>(inLogitsPtr[index])); 
+            //     printf("out logit: %f\n",  logit); 
+            // }
             outLogitsPtr[index] = logit;
         }
         else
diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py
index f843772..3f8d41c 100755
--- a/tensorrt_llm/runtime/generation.py
+++ b/tensorrt_llm/runtime/generation.py
@@ -3500,7 +3500,7 @@ class GenerationSession(object):
 
         next_step_tensors = None
         for step in range(0, self.max_new_tokens):
-
+            print(f"at step -- {step} -------------------------------------")
             should_stop, next_step_tensors, tasks, context_lengths, host_context_lengths, attention_mask, context_logits, generation_logits, encoder_input_lengths = self.handle_per_step(
                 cache_indirections, step, batch_size, max_context_length,
                 beam_width, input_ids, hidden_states, scfg,
@@ -3618,7 +3618,7 @@ class GenerationSession(object):
 
         next_step_tensors = None
         for step in range(0, self.max_new_tokens):
-
+            print(f"at step -- {step} -------------------------------------")
             should_stop, next_step_tensors, tasks, context_lengths, host_context_lengths, attention_mask, context_logits, generation_logits, encoder_input_lengths = self.handle_per_step(
                 cache_indirections, step, batch_size, max_context_length,
                 beam_width, input_ids, hidden_states, scfg,

Full logs:

https://drive.google.com/file/d/1PqZ097uEbnamMgCE7DAG8toFwI-7RFJo/view?usp=sharing
@buddhapuneeth buddhapuneeth added the bug Something isn't working label Nov 22, 2024
@hello-11 hello-11 added the triaged Issue has been triaged by maintainers label Nov 25, 2024
@nekorobov
Collaborator

Hi @buddhapuneeth, thank you for taking the time to report this issue. You are right: the output logits are not modified by any of the penalties or temperature. This is intentional behavior -- in context/generation logits we output only the raw logits before sampling. There is a way to get the log-prob value after temperature/penalties are applied, but only for the selected token. Check out returnLogProbs in the executor OutputConfig and the way to expose it from run.py.
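To illustrate the distinction described above with a toy sketch (plain NumPy, not TRT-LLM code; the vocabulary and values are made up): the exported generation logits stay raw, while the returnLogProbs-style value is a log-softmax taken after temperature/penalties, and only for the picked token.

```python
import numpy as np

raw_logits = np.array([2.0, 1.0, -7.039062], dtype=np.float32)
temperature = 1.2

# What --output_logits_npy contains: the raw logits, untouched by penalties.
exported = raw_logits.copy()

# What a post-penalty log-prob looks like: log-softmax AFTER temperature
# scaling, reported only for the selected token.
scaled = raw_logits / temperature
log_probs = scaled - np.log(np.sum(np.exp(scaled)))
picked = int(np.argmax(scaled))

print(exported)                    # identical to raw_logits
print(picked, log_probs[picked])   # picked token and its post-temperature log-prob
```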

@nekorobov nekorobov self-assigned this Nov 25, 2024
@buddhapuneeth
Author

@nekorobov But per the initial issue, I see different logits when I modify the repetition penalty; only with temperature alone do I see raw logits. The changes to the logits happen in the same kernel, so what is the reason for the different behavior?

@buddhapuneeth
Author

Actually, I did a further deep-dive and noticed:

  • Yes, gen_logits with and without RP differ because the generated tokens differ between the two cases from the second token on, so from step 2 the logits are completely different.
  • With temp = 0.1 or 1.2, the final output remains the same, and hence the logits remain the same as with no temperature. It was not clear to me why the output does not change even with such extreme temperature values (like 0.1).

@buddhapuneeth
Author

Actually, I got my answer for temperature as well: since this is greedy decoding, the picked token does not depend on the temperature. That is why we see the same output every time irrespective of the temperature value.

Conclusions:

  • We see only raw logits when we expose output_logits.
  • In greedy decoding we see different results with and without RP, so the logits differ because the generated sequence differs.
  • In greedy decoding we see the same result with and without temperature, so the logits remain the same because the generated sequence is the same.
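The greedy-decoding argument in the conclusions can be checked directly: dividing logits by any positive temperature is a monotonic rescaling and never changes the argmax, while a repetition penalty on the current top token can. A small NumPy sketch with made-up values:

```python
import numpy as np

logits = np.array([3.0, 2.5, -1.0, 0.5], dtype=np.float32)

# Temperature preserves the ordering, so the greedy pick never changes.
for temp in (0.1, 1.0, 1.2, 10.0):
    assert np.argmax(logits / temp) == np.argmax(logits)

# A repetition penalty applied to the current top token (positive -> divide,
# negative -> multiply) can move the greedy pick to another token.
penalized = logits.copy()
penalized[0] = penalized[0] / 1.4 if penalized[0] > 0 else penalized[0] * 1.4

print(np.argmax(logits), np.argmax(penalized))  # prints "0 1"
```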
