Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Description
You can run the code below to reproduce the prefill key-value caching problem with Minitron models.
I used the "nvidia/Minitron-8B-Base" model.
Code
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the Minitron model and tokenizer from Hugging Face
model_name = "nvidia/Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the model to evaluation mode
model.eval()

# Sample input text
input_text = "Hello, how are you?"

# Tokenize the input
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# First forward pass (prefill phase)
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)  # use_cache=True should return past_key_values
    logits = outputs.logits
    past_key_values = outputs.past_key_values

# Check the output
print("Logits shape:", logits.shape)
print("Number of layers in past_key_values:", len(past_key_values))
print("Shape of keys and values in the first layer:")
print("Key shape:", past_key_values[0][0].shape)
print("Value shape:", past_key_values[0][1].shape)

# Add new input to test cache utilization
new_input_text = " What about you?"
new_input_ids = tokenizer(new_input_text, return_tensors="pt").input_ids

# Pass the new input along with the previous key-value cache
with torch.no_grad():
    outputs_with_cache = model(new_input_ids, past_key_values=past_key_values, use_cache=True)

# Check results after caching
new_logits = outputs_with_cache.logits
new_past_key_values = outputs_with_cache.past_key_values
print("New logits shape:", new_logits.shape)
print("Number of layers in new past_key_values:", len(new_past_key_values))
Expected behavior
As-Is
past_key_values is None (a NoneType), i.e. no key/value cache is returned from the prefill forward pass, so the second call that tries to reuse the cache cannot work.
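As a sanity check, a minimal sketch assuming the DynamicCache API shipped with transformers 4.47: passing an explicit cache object to the forward call and inspecting it afterwards shows whether the model fills the cache at all, independently of what outputs.past_key_values returns.

# Minimal sketch, assuming transformers' DynamicCache API.
from transformers import DynamicCache

cache = DynamicCache()
with torch.no_grad():
    model(input_ids, past_key_values=cache, use_cache=True)
# On a model where prefill caching works, this prints the prompt length.
print("Cached sequence length:", cache.get_seq_length())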
System Info
- huggingface-hub 0.26.2
- tokenizers 0.20.3
- transformers 4.47.0.dev0
- Python 3.10.12
- Driver Version: 535.129.03
- CUDA Version: 12.3
Who can help?
No response