the code of Differential Transformers training #1663

Open
mucunxie opened this issue Dec 2, 2024 · 6 comments

Comments

@mucunxie

mucunxie commented Dec 2, 2024

Model I am using (UniLM, MiniLM, LayoutLM ...): Differential Transformers
Thank you for your work! Could you provide your code for training Differential Transformers?

@YTianZHU
Contributor

YTianZHU commented Dec 3, 2024

Hi @mucunxie, the basic training code for DIFF is similar to the code provided at https://aka.ms/yoco; you can make a few changes and merge the DIFF code into it.

You can also use other open-source frameworks and plug DIFF into them by changing a few lines.
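As a rough illustration of what plugging DIFF into another framework involves, here is a minimal sketch of a causal differential attention layer in plain PyTorch. It is not the repository's implementation: it assumes PyTorch 2.x for `F.scaled_dot_product_attention`, omits RoPE, KV caching, and GQA, and the names `DiffAttentionSketch` and `SimpleRMSNorm` are hypothetical. The `lambda_init_fn` schedule is assumed to follow the Differential Transformer paper. It uses the layout the maintainers recommend later in this thread: `num_heads` is half the baseline head count, each q/k group keeps the baseline head dimension, and v uses twice that.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def lambda_init_fn(layer_idx: int) -> float:
    # Assumed depth-dependent schedule; layer_idx is counted from 0.
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm with a learnable per-channel scale (hypothetical stand-in)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x_f.type_as(x) * self.weight


class DiffAttentionSketch(nn.Module):
    """Causal differential attention; no RoPE, KV cache, or GQA.

    num_heads is the Diff head count (half the baseline Transformer's). Each head
    has two q/k groups of size head_dim and one shared value of size 2 * head_dim.
    """

    def __init__(self, embed_dim: int, num_heads: int, layer_idx: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads // 2
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.lambda_init = lambda_init_fn(layer_idx)
        # Vectors that re-parameterize lambda, initialized from normal(0, 0.1).
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.subln = SimpleRMSNorm(2 * self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, embed_dim = x.shape
        h, d = self.num_heads, self.head_dim
        # q, k: (bsz, h, 2, seq_len, d); index 0/1 selects the first/second softmax group.
        q = self.q_proj(x).view(bsz, seq_len, h, 2, d).permute(0, 2, 3, 1, 4)
        k = self.k_proj(x).view(bsz, seq_len, h, 2, d).permute(0, 2, 3, 1, 4)
        # v: (bsz, h, seq_len, 2 * d), shared by both softmax maps.
        v = self.v_proj(x).view(bsz, seq_len, h, 2 * d).transpose(1, 2)

        attn1 = F.scaled_dot_product_attention(q[:, :, 0], k[:, :, 0], v, is_causal=True)
        attn2 = F.scaled_dot_product_attention(q[:, :, 1], k[:, :, 1], v, is_causal=True)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # (softmax1 - lambda * softmax2) @ v == softmax1 @ v - lambda * (softmax2 @ v)
        attn = attn1 - lambda_full.type_as(attn1) * attn2
        attn = self.subln(attn) * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, seq_len, embed_dim)
        return self.out_proj(attn)


torch.manual_seed(0)
layer = DiffAttentionSketch(embed_dim=256, num_heads=2, layer_idx=3)
x = torch.randn(2, 10, 256)
out = layer(x)
print(out.shape)  # torch.Size([2, 10, 256])
```

Both softmax maps share the same `2 * head_dim` value tensor, so this sketch needs only two attention calls; `F.scaled_dot_product_attention` accepts a value dimension that differs from the query/key dimension.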

@Adamyangs

Adamyangs commented Jan 7, 2025

Hello @YTianZHU, I tried implementing a Diff Transformer using Hugging Face, but I can't reproduce your experimental results: in my experiment, the ordinary Transformer performed better than the Diff Transformer. I trained on about 10B tokens of data, and adding RMSNorm doesn't seem to bring much benefit. I see that Kaiming initialization was used in the repository history, whereas Hugging Face modeling code mostly initializes weights with normal_. I'm not sure whether the gap is caused by initialization. Is there anything I have overlooked that could explain the poor performance? Here is the code I modified:

class DiffAttention(nn.Module):
    def __init__(self, config: DiffV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_dim = self.head_dim // 2
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()
        self.lambda_init = lambda_init_fn(self.layer_idx)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.v_norm = RMSNorm(self.head_dim)


class DiffFlashAttention2(DiffAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        if "padding_mask" in kwargs:
            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.rope_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.rope_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads * 2, self.rope_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )

        if past_key_value is not None:
            # Slice the cache only if the config defines a `sliding_window` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, layer norms are usually cast to float32 for training stability, so the
        # input hidden states may have been silently cast to float32. Cast them back to the
        # compute dtype so that Flash Attention receives half-precision inputs.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

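        # Reshape to the layout expected by Flash Attention: (bsz, seq_len, num_heads, 2, rope_dim).
        query_states = query_states.transpose(1, 2).view(bsz, q_len, self.num_heads, 2, self.rope_dim)
        # repeat_kv places the num_key_value_groups copies of each half-head next to each other,
        # so regroup (kv_head, half, copy) -> (kv_head, copy, half) to keep the two halves of
        # every head paired under GQA. The key length uses -1 because it exceeds q_len with a KV cache.
        kv_groups_shape = (bsz, -1, self.num_key_value_heads, 2, self.num_key_value_groups, self.rope_dim)
        key_states = key_states.transpose(1, 2).reshape(kv_groups_shape).transpose(3, 4)
        key_states = key_states.reshape(bsz, -1, self.num_heads, 2, self.rope_dim)
        value_states = value_states.transpose(1, 2).reshape(kv_groups_shape).transpose(3, 4)
        value_states = value_states.reshape(bsz, -1, self.num_heads, 2, self.rope_dim)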

        q1, q2 = query_states[:, :, :, 0], query_states[:, :, :, 1]
        k1, k2 = key_states[:, :, :, 0], key_states[:, :, :, 1]
        v1, v2 = value_states[:, :, :, 0], value_states[:, :, :, 1]

        attn11 = self._flash_attention_forward(
            q1,
            k1,
            v1,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        attn12 = self._flash_attention_forward(
            q1,
            k1,
            v2,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        attn1 = torch.cat([attn11, attn12], dim=-1)
        
        attn21 = self._flash_attention_forward(
            q2,
            k2,
            v1,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        attn22 = self._flash_attention_forward(
            q2,
            k2,
            v2,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        attn2 = torch.cat([attn21, attn22], dim=-1)

        lambda_1 = torch.exp(torch.sum(5 * self.lambda_q1 * 5 * self.lambda_k1)).type_as(q1)
        lambda_2 = torch.exp(torch.sum(5 * self.lambda_q2 * 5 * self.lambda_k2)).type_as(q1)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn_output = attn1 - lambda_full * attn2
        # print(f"Layer {self.layer_idx} \t self.lambda_init \t {self.lambda_init} lambda_full\t{lambda_full} \t attn_output \t {attn_output.shape} {attn_output[0,0,0,0]}")
        attn_output = self.v_norm(attn_output)
        attn_output = attn_output * (1 - self.lambda_init)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash Attention does not materialize the attention matrix, so no weights can be returned.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

@YTianZHU
Contributor

YTianZHU commented Jan 7, 2025

@Adamyangs Hi, it seems you use half the head dimension for Diff. We suggest keeping the same head dimension but using half the number of heads. For example, if a Transformer has 16 heads with a head dimension of 128, then the corresponding Diff Transformer has 8 heads, with a head dimension of 128 for q/k and 256 for v.
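The two options can be compared with a few lines of arithmetic. The helper below is hypothetical and only derives dimensions; it assumes a hidden size of 2048, which is what 16 heads of dimension 128 imply.

```python
def diff_layout(embed_dim: int, baseline_heads: int, split_head_dim: bool) -> dict:
    """Per-head dimensions of a Diff Transformer derived from a baseline (hypothetical helper)."""
    base_head_dim = embed_dim // baseline_heads
    if split_head_dim:
        # Keep the baseline head count; each q/k group gets half the baseline head dimension.
        heads = baseline_heads
        qk_dim = base_head_dim // 2
    else:
        # Halve the head count; each q/k group keeps the full baseline head dimension.
        heads = baseline_heads // 2
        qk_dim = base_head_dim
    v_dim = 2 * qk_dim  # one value per head, twice the q/k group size
    return {"heads": heads, "qk_dim": qk_dim, "v_dim": v_dim, "proj_width": heads * v_dim}


print(diff_layout(2048, 16, split_head_dim=True))
# {'heads': 16, 'qk_dim': 64, 'v_dim': 128, 'proj_width': 2048}
print(diff_layout(2048, 16, split_head_dim=False))
# {'heads': 8, 'qk_dim': 128, 'v_dim': 256, 'proj_width': 2048}
```

In both cases the projection width equals the hidden size, so the parameter counts match; what changes is whether each q/k group is 64- or 128-dimensional.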

@Adamyangs

Adamyangs commented Jan 8, 2025

Thank you for your response. My setting seems like a fair way to compare the two models. If the Diff model only surpasses a regular Transformer under the condition of having the same number of heads, could it be that this hyperparameter setting is simply not well suited to the regular Transformer and favors the Diff Transformer? I would be interested to hear your thoughts on this.

Additionally, I have a question regarding the learning of lambda. You’ve designed a complex lambda update rule, but it appears that updating lambda doesn’t significantly affect the loss. Have you conducted any ablation experiments to investigate this?

I also found an issue when implementing the setting you mentioned: with GQA, if the number of KV heads stays the same, the Diff Transformer has a larger KV cache. Should the number of KV heads therefore be halved in the experiment?
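A quick count makes the cache concern concrete. This is a sketch under assumed numbers: a hypothetical baseline with 4 KV heads of dimension 128, and the Diff layout suggested above, where each KV head caches two 128-d key groups (k1, k2) and one 256-d value.

```python
def kv_cache_elems_per_token(num_kv_heads: int, k_dim: int, v_dim: int) -> int:
    # Elements cached per token per layer: one key and one value vector per KV head.
    return num_kv_heads * (k_dim + v_dim)


head_dim = 128  # baseline head dimension from the example in this thread
baseline = kv_cache_elems_per_token(num_kv_heads=4, k_dim=head_dim, v_dim=head_dim)

# A Diff KV head caches two key groups (k1, k2) of head_dim each and one value of 2 * head_dim.
diff_same_kv = kv_cache_elems_per_token(num_kv_heads=4, k_dim=2 * head_dim, v_dim=2 * head_dim)
diff_half_kv = kv_cache_elems_per_token(num_kv_heads=2, k_dim=2 * head_dim, v_dim=2 * head_dim)

print(baseline, diff_same_kv, diff_half_kv)  # 1024 2048 1024
```

Keeping the KV head count doubles the cache, while halving it restores the baseline size.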

@YTianZHU
Contributor

YTianZHU commented Jan 8, 2025

@Adamyangs, Hi,

  1. We again take a Transformer with 16 heads and a head dimension of 128 as an example. Your setting seems to use a head dimension of 64 for q1, k1, q2, k2 and 128 for v. If we count heads by the number of v heads, your setting has 16 heads, the same as the baseline. In our recommended setting, we use a head dimension of 128 for q1, k1, q2, k2 and 256 for v, so the number of heads is 8. We recommend this setting because in most configurations there is more redundancy in the number of heads than in the head dimension.
    In the example above, reducing the q/k dimension to 64 would have a strong negative impact on performance, but reducing the number of heads from 16 to 8 would not. Of course, if the corresponding Transformer already has a large head dimension, such as 256, then your setting (splitting the head dimension) probably works better. It depends on the specific setting.
  2. For the re-parameterization of lambda, we found in earlier experiments that the exponential re-parameterization works better than directly multiplying by lambda. The re-parameterization speeds up the learning of lambda and keeps it aligned with the other parameters.
  3. For GQA, as we comment here https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_1.py#L50, the number of KV heads should be set to half that of the baseline Transformer (just as we do for MHA), so the KV cache does not grow. The head dimension of v is doubled.
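The re-parameterization in point 2 (computed in the posted code as `lambda_1 - lambda_2 + self.lambda_init`) can be written out as a small standalone sketch. The thread does not define `lambda_init_fn`; the schedule `0.8 - 0.6 * exp(-0.3 * layer_idx)` used below is assumed from the Differential Transformer paper, and `reparam_lambda` is a hypothetical helper. It omits the `5 *` factors that appear in the posted code.

```python
import math

import torch


def lambda_init_fn(layer_idx: int) -> float:
    # Assumed depth schedule, with layer_idx counted from 0.
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


def reparam_lambda(lq1, lk1, lq2, lk2, lambda_init: float) -> torch.Tensor:
    # lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init
    return torch.exp(torch.dot(lq1, lk1)) - torch.exp(torch.dot(lq2, lk2)) + lambda_init


for layer_idx in (0, 4, 12):
    print(layer_idx, round(lambda_init_fn(layer_idx), 3))
# 0 0.2
# 4 0.619
# 12 0.784

head_dim = 128
# With all-zero vectors both exponentials equal 1 and cancel, so lambda equals lambda_init.
zeros = torch.zeros(head_dim)
print(reparam_lambda(zeros, zeros, zeros, zeros, lambda_init_fn(4)).item())

# With the normal(0, 0.1) init, both dot products start small, so each exponential starts near 1.
torch.manual_seed(0)
lq1, lk1, lq2, lk2 = (torch.randn(head_dim) * 0.1 for _ in range(4))
lam = reparam_lambda(lq1, lk1, lq2, lk2, lambda_init_fn(4))
print(round(lam.item(), 3))
```

Deeper layers start from a larger `lambda_init`, and the learned vectors then shift lambda away from that starting point.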

@Adamyangs

Hi, @YTianZHU
Thank you so much for your thoughtful and detailed responses to my questions. I truly appreciate the time and effort you took to assist me. Your insights have been incredibly helpful! By adjusting the settings as you suggested, I was able to successfully replicate the results of your experiment.
