
High GPU memory usage due to large intermediate tensor in calculate_radial_contributions in AimNet2 #315

Closed
Tracked by #316
wiederm opened this issue Nov 9, 2024 · 1 comment


wiederm commented Nov 9, 2024

Description:

We have identified a significant GPU memory issue in the AIMNet2InteractionModule, specifically in the calculate_radial_contributions function. The function creates a large intermediate tensor of shape (number_of_pairs, G, F_atom), which can consume substantial GPU memory for large datasets or complex models.

Steps to Reproduce:

1. Use the AimNet2Core model with a dataset.
2. Monitor GPU memory usage during the forward pass (one way to do this is sketched below).
3. Observe the spike in memory usage when calculate_radial_contributions is called.
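
A minimal sketch for the monitoring in step 2, using PyTorch's CUDA memory statistics; the model and batch lines are placeholders that depend on your setup and are not part of the modelforge API:

import torch

# model = ...   # an AimNet2Core instance (construction omitted)
# batch = ...   # a batch of inputs for the forward pass

torch.cuda.reset_peak_memory_stats()
# output = model(batch)  # forward pass
peak_gb = torch.cuda.max_memory_allocated() / 1e9
print(f"peak GPU memory during forward pass: {peak_gb:.2f} GB")
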
Expected Behavior:

The model should efficiently compute radial contributions without excessive GPU memory consumption, allowing for larger batch sizes and more complex models.

Actual Behavior:

The model consumes a large amount of GPU memory due to the creation of the intermediate tensor avf_s with shape (number_of_pairs, G, F_atom), where:

  • number_of_pairs is the total number of atomic pairs,
  • G is the number of radial basis functions, and
  • F_atom is the number of per-atom features.

This high memory usage limits the scalability of the model and may lead to CUDA out-of-memory errors.

def calculate_radial_contributions(
    self,
    gs: Tensor,            # (number_of_pairs, G) radial basis values per pair
    a_j: Tensor,           # (number_of_pairs, F_atom) features of atom j per pair
    number_of_atoms: int,
    idx_j: Tensor,         # (number_of_pairs,) index of atom j for each pair
) -> Tensor:
    # Compute radial contributions; the broadcasted product materializes
    # a tensor of shape (number_of_pairs, G, F_atom)
    avf_s = gs.unsqueeze(-1) * a_j.unsqueeze(1)
    avf_s = avf_s.sum(dim=1)  # sum over G -> (number_of_pairs, F_atom)

    # Aggregate per atom
    radial_contributions = torch.zeros(
        (number_of_atoms, avf_s.shape[-1]),  # F_atom features per atom
        device=avf_s.device,
        dtype=avf_s.dtype,
    )
    radial_contributions.index_add_(0, idx_j, avf_s)

    return radial_contributions

Analysis:

The operation gs.unsqueeze(-1) * a_j.unsqueeze(1) materializes an intermediate tensor of shape (number_of_pairs, G, F_atom). When number_of_pairs, G, and F_atom are large, this single tensor can dominate GPU memory during the forward pass.
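
A quick back-of-the-envelope estimate (the sizes below are illustrative assumptions, not measured values) shows how quickly this intermediate grows:

# Illustrative sizes, not taken from a specific run
number_of_pairs = 1_000_000  # total atomic pairs in a large batch
G = 64                       # number of radial basis functions
F_atom = 256                 # number of per-atom features
bytes_per_element = 4        # float32

intermediate_bytes = number_of_pairs * G * F_atom * bytes_per_element
print(f"{intermediate_bytes / 1e9:.1f} GB")  # ~65.5 GB for this single tensor

And because the tensor participates in autograd, it is typically kept alive until the backward pass, so peak usage during training is even higher.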

Proposed Solution:

Avoid materializing the (number_of_pairs, G, F_atom) tensor: map gs into the feature dimension of a_j first, then combine the two with a plain element-wise multiplication before the per-atom aggregation. A sketch of this idea is given below.
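
A minimal sketch of that idea (illustrative only, not necessarily the change adopted in PR #316; the projection radial_to_feature is a hypothetical module mapping G to F_atom, and tensor shapes follow the snippet above):

import torch
from torch import Tensor, nn

def calculate_radial_contributions_lowmem(
    gs: Tensor,                    # (number_of_pairs, G)
    a_j: Tensor,                   # (number_of_pairs, F_atom)
    number_of_atoms: int,
    idx_j: Tensor,                 # (number_of_pairs,)
    radial_to_feature: nn.Linear,  # hypothetical projection from G to F_atom
) -> Tensor:
    # Project gs into the feature dimension first, then combine element-wise.
    # This never materializes a (number_of_pairs, G, F_atom) tensor.
    avf_s = radial_to_feature(gs) * a_j  # (number_of_pairs, F_atom)

    # Aggregate per atom, as in the original function
    radial_contributions = torch.zeros(
        (number_of_atoms, a_j.shape[-1]), device=avf_s.device, dtype=avf_s.dtype
    )
    radial_contributions.index_add_(0, idx_j, avf_s)
    return radial_contributions

Note that for the exact code quoted above (no learned weights over G), the reduction factorizes completely: avf_s.sum(dim=1) equals gs.sum(dim=1, keepdim=True) * a_j, which requires no (number_of_pairs, G, F_atom) tensor at all.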

wiederm self-assigned this Nov 9, 2024
wiederm added the bug label Nov 9, 2024

wiederm commented Nov 10, 2024

This has been resolved in PR #316.

wiederm closed this as completed Nov 10, 2024