
[Bug Report] Q cannot be reshaped correctly when model is loaded in 4bit #737

Open

po13on opened this issue Sep 28, 2024 · 4 comments

Labels: bug (Something isn't working), needs-investigation (Issues that need to be recreated, or investigated before work can be done)

po13on commented Sep 28, 2024

Describe the bug
query_input has shape [batch, pos, n_heads, d_model], and the code where the error occurs is meant to reshape the projected query to [batch, pos, n_heads, d_head].
I found that the output of bnb.matmul_4bit still has shape [batch, pos, n_heads, d_model], so it cannot be reshaped to [batch, pos, n_heads, d_head].

The reason for this error may be the following code in abstract_attention.py:

if self.cfg.load_in_4bit:
    nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
    self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
    self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
else:
    self.W_Q = nn.Parameter(
        torch.empty(
            self.cfg.n_heads,
            self.cfg.d_model,
            self.cfg.d_head,
            dtype=self.cfg.dtype,
        )
    )
    self.W_O = nn.Parameter(
        torch.empty(
            self.cfg.n_heads,
            self.cfg.d_head,
            self.cfg.d_model,
            dtype=self.cfg.dtype,
        )
    )

When the model is loaded in 4-bit, W_Q has shape [(d_model * d_head * n_heads) / 2, 1], which leads to the unexpected output shape from bnb.matmul_4bit.
When the model is not loaded in 4-bit, W_Q has shape [n_heads, d_model, d_head], which does not trigger the bug described above.
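
To make the shape mismatch concrete, here is a minimal sketch with plain tensors (no bitsandbytes involved). The sizes are assumptions matching llama-7b (d_model=4096, n_heads=32, d_head=128), with batch=1 and pos=6 picked for illustration:

import torch

batch, pos, n_heads, d_model, d_head = 1, 6, 32, 4096, 128

# With use_split_qkv_input=True, query_input carries one copy of the residual stream per head.
query_input = torch.randn(batch, pos, n_heads, d_model)

# Stand-in for the dequantized full W_Q, which maps d_model -> n_heads * d_head.
full_W_Q = torch.randn(d_model, n_heads * d_head)

out = query_input @ full_W_Q   # shape [batch, pos, n_heads, n_heads * d_head]
print(out.shape, out.numel())  # torch.Size([1, 6, 32, 4096]) 786432

# The reshape in calculate_qkv_matrices then fails, because 786432 != 1 * 6 * 32 * 128:
out.reshape(batch, pos, n_heads, d_head)  # RuntimeError: invalid for input of size 786432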

Code example
Error message:

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:364, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
    339        if self.cfg.load_in_4bit:
    340            q = self.hook_q(
    341                # call bitsandbytes method to dequantize and multiply
    342                bnb.matmul_4bit(
    343                    query_input,
    344                    self.W_Q.t(),
    345                    bias=None,
    346                    quant_state=self.W_Q.quant_state,
-->347                ).reshape(
    348                    query_input.shape[0],
    349                    query_input.shape[1],
    350                    self.cfg.n_heads,
    351                    self.cfg.d_head,
    352                )
    353                + self.b_Q
RuntimeError: shape '[20, 22, 32, 128]' is invalid for input of size 57671680

Code:

with torch.inference_mode():
    with model.hooks(fwd_hooks=fwd_hooks_corrupted):
        _ = model(corrupted)

System Info
Describe the characteristics of your environment:

  • Installed from source (git clone)
  • Linux
  • Python 3.11.4

Checklist

  • [√] I have checked that there is no similar issue in the repo (required)

po13on commented Sep 28, 2024

When I load the model in 4-bit and set model.cfg.use_split_qkv_input = True, this bug is triggered.
Code Example:

model = AutoModelForCausalLM.from_pretrained(
    model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False,
    low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True, torch_dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained(
    "llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False, device="cuda",
    hf_model=model, tokenizer=tokenizer, hf_model_4bit=True, center_writing_weights=False,
)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)

Error:

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:195, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask, position_bias)
    167 def forward(
    168     self,
    169     query_input: Union[
   (...)
    186     position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
    187 ) -> Float[torch.Tensor, "batch pos d_model"]:
    188     """
    189     shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
    190     past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
    191     additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
    192     attention_mask is the attention mask for padded tokens. Defaults to None.
    193     """
--> 195     q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
    197     if past_kv_cache_entry is not None:
    198         # Appends the new keys and values to the cached values, and automatically updates the cache
    199         kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)

File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:348, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
    339 if self.cfg.load_in_4bit:
    340     print('In calculate_qkv_matrices: query_input.shape =', query_input.shape)  # XD debug
    341     q = self.hook_q(
    342         # call bitsandbytes method to dequantize and multiply
    343         bnb.matmul_4bit(
    344             query_input,
    345             self.W_Q.t(),
    346             bias=None,
    347             quant_state=self.W_Q.quant_state,
--> 348         ).reshape(
    349             query_input.shape[0],
    350             query_input.shape[1],
    351             self.cfg.n_heads,
    352             self.cfg.d_head,
    353         )
    354         + self.b_Q
    355     )
    356 else:
    357     q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))

RuntimeError: shape '[1, 6, 32, 128]' is invalid for input of size 786432
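
For comparison, here is a rough sketch (plain PyTorch with the same assumed llama-7b sizes, not the exact TransformerLens source) of the per-head projection that the non-4-bit path performs; it handles the [batch, pos, n_heads, d_model] input that use_split_qkv_input produces:

import torch

batch, pos, n_heads, d_model, d_head = 1, 6, 32, 4096, 128
query_input = torch.randn(batch, pos, n_heads, d_model)
W_Q = torch.randn(n_heads, d_model, d_head)
b_Q = torch.zeros(n_heads, d_head)

# Each head multiplies its own residual-stream copy by its own [d_model, d_head] slice of W_Q.
q = torch.einsum("bpnd,ndh->bpnh", query_input, W_Q) + b_Q
print(q.shape)  # torch.Size([1, 6, 32, 128])

The flattened 4-bit W_Q cannot be indexed per head like this, which is presumably why the 4-bit branch multiplies by the full matrix and then tries to reshape.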

bryce13950 (Collaborator) commented:

@po13on, in order to investigate this further, I am going to need to see exactly the code you used to initialize TransformerLens. This could be a wide-ranging bug, but more likely it is a specific model causing the issue. I need to see the full block of code you ran to put TransformerLens into this invalid state.

bryce13950 added the bug (Something isn't working) and needs-information (More information is needed from the issue creator before moving forward) labels on Sep 30, 2024
po13on (Author) commented Oct 8, 2024

@bryce13950 I'm sorry for providing incomplete code. The model I loaded is vicuna-7b. Below is the complete code:

model_name = 'lmsys/vicuna-7b-v1.3'
model = AutoModelForCausalLM.from_pretrained(
    model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False,
    low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True, torch_dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained(
    "llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False, device="cuda",
    hf_model=model, tokenizer=tokenizer, hf_model_4bit=True, center_writing_weights=False,
)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)
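
As a quick sanity check (a sketch, assuming the setup above loaded successfully), one can inspect the attention weights after loading to confirm the flattened 4-bit storage described earlier:

attn = model.blocks[0].attn
print(type(attn.W_Q))              # Params4bit in the 4-bit case, nn.Parameter otherwise
print(attn.W_Q.shape)              # expected: torch.Size([d_model * n_heads * d_head / 2, 1])
print(attn.W_Q.quant_state.shape)  # shape of the weight that matmul_4bit dequantizes to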

@bryce13950 bryce13950 added needs-investigation Issues that need to be recreated, or investigated before work can be done and removed needs-information More information is needed from the issue creator before moving forward. labels Oct 8, 2024
bryce13950 (Collaborator) commented:

No problem! Thanks for providing this. This should be enough for us to recreate it now.
