
Feature: Heterogeneous Normalized Attention #153

Draft · wants to merge 4 commits into main
Conversation

@JakobEliasWagner (Collaborator) commented Jul 24, 2024

Feature: Heterogeneous Normalized Attention

Description

This pull request implements the Heterogeneous Normalized Attention mechanism described by Hao et al. (2023).

The heterogeneous normalized attention block computes the attention scores in two steps:

  1. Normalize the query and key sequences:

$$\tilde{q}_i = \mathrm{Softmax}(q_i)$$

$$\tilde{k}_i = \mathrm{Softmax}(k_i)$$

  2. Compute the attention output without an explicit softmax over the scores:

$$z_t = \sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j} v_i$$

This implementation is linear in the sequence length: since $\tilde{q}_t \cdot \tilde{k}_i$ is a scalar, the sums $\sum_i \tilde{k}_i v_i^\top$ and $\sum_j \tilde{k}_j$ can be accumulated once and reused for every query position $t$.
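
For illustration, here is a minimal PyTorch sketch of this rearranged computation. The function name, tensor shapes, and variable names are assumptions for this example, not the actual continuiti API:

```python
import torch


def normalized_attention(q, k, v):
    """Sketch of heterogeneous normalized attention, linear in the sequence length.

    q, k: (batch, seq_len, dim); v: (batch, seq_len, dim_v).
    """
    q_tilde = torch.softmax(q, dim=-1)  # normalize each query vector
    k_tilde = torch.softmax(k, dim=-1)  # normalize each key vector

    # Aggregate the keys once: sum_i k_i ⊗ v_i and sum_j k_j do not depend on t,
    # so the per-query cost no longer scales with the sequence length.
    kv = torch.einsum("bik,biv->bkv", k_tilde, v)
    k_sum = k_tilde.sum(dim=1)

    numerator = torch.einsum("btk,bkv->btv", q_tilde, kv)
    denominator = torch.einsum("btk,bk->bt", q_tilde, k_sum).unsqueeze(-1)
    return numerator / denominator
```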

We added a masking mechanism on top of the vanilla implementation suggested by Hao et al.
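
A sketch of how such a mask can be applied within the same linear formulation, reusing the function above (the mask convention, True for valid key positions, is an assumption for this example):

```python
def normalized_attention_masked(q, k, v, mask):
    """Masked variant (sketch). mask: boolean tensor of shape (batch, seq_len),
    True for key positions that may be attended to."""
    q_tilde = torch.softmax(q, dim=-1)
    k_tilde = torch.softmax(k, dim=-1)

    # Zero out masked key positions so they drop out of both the numerator
    # and the denominator sums.
    k_tilde = k_tilde * mask.unsqueeze(-1)

    kv = torch.einsum("bik,biv->bkv", k_tilde, v)
    k_sum = k_tilde.sum(dim=1)

    numerator = torch.einsum("btk,bkv->btv", q_tilde, kv)
    denominator = torch.einsum("btk,bk->bt", q_tilde, k_sum).unsqueeze(-1)
    return numerator / denominator
```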

Which issue does this PR tackle?

  • Heterogeneous normalized attention is not implemented.

How does it solve the problem?

  • Implements HeterogeneousNormalizedAttention, a linear attention implementation.
  • Implements masking for HeterogeneousNormalizedAttention.

How are the changes tested?

  • Added 6 unit tests covering: initialization, shape projection, gradient flow, zero inputs, masked forwards, and correctness by masking a known tensor (a sketch of this last check is given below).
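
A hypothetical version of the masked-correctness check, written against the sketch functions above; the actual tests in the PR exercise the HeterogeneousNormalizedAttention module and may differ:

```python
import torch


def test_masked_keys_do_not_change_output():
    q = torch.rand(1, 4, 8)
    k = torch.rand(1, 4, 8)
    v = torch.rand(1, 4, 8)
    mask = torch.tensor([[True, True, False, False]])  # last two keys are padding

    masked = normalized_attention_masked(q, k, v, mask)
    # Masking the last two keys must be equivalent to dropping them entirely.
    reference = normalized_attention(q, k[:, :2], v[:, :2])

    assert torch.allclose(masked, reference, atol=1e-6)
```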

Checklist for Contributors

  • Scope: This PR tackles exactly one problem.
  • Conventions: The branch follows the feature/title-slug convention.
  • Conventions: The PR title follows the Bugfix: Title convention.
  • Coding style: The code passes all pre-commit hooks.
  • Documentation: All changes are well-documented.
  • Tests: New features are tested and all tests pass successfully.
  • Changelog: Updated CHANGELOG.md for new features or breaking changes.
  • Review: A suitable reviewer has been assigned.

Checklist for Reviewers:

  • The PR solves the issue it claims to solve and only this one.
  • Changes are tested sufficiently and all tests pass.
  • Documentation is complete and well-written.
  • Changelog has been updated, if necessary.

@JakobEliasWagner marked this pull request as draft July 24, 2024 13:54
@JakobEliasWagner self-assigned this Jul 24, 2024
@JakobEliasWagner added the enhancement label Jul 24, 2024
@samuelburbulla (Collaborator) left a comment


LGTM, you're the expert

Comment on lines +21 to +23
$$\tilde{q}_{i,j} = \mathrm{Softmax}(q_i)_j = \frac{\exp(q_{i,j})}{\sum_j\exp(q_{i,j})}$$,
$$\tilde{k}_{i,j} = \mathrm{Softmax}(k_i)_j = \frac{\exp(k_{i,j})}{\sum_j\exp(k_{i,j})}$$, and then calculating the attention without
softmax using $$z_t=\sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j}\cdot v_i$$.

Math does not render well in docs
