
Optimize calculate_radial_contributions to reduce GPU memory usage #316

Merged: 4 commits into main, Nov 9, 2024

Conversation

@wiederm (Member) commented Nov 9, 2024

Pull Request Summary

This PR addresses the high GPU memory usage issue caused by the creation of a large intermediate tensor in the calculate_radial_contributions function of the AIMNet2InteractionModule. The proposed fix optimizes the computation to reduce memory consumption without affecting the model's performance.

The original implementation:

def calculate_radial_contributions(
    self,
    gs: Tensor,
    a_j: Tensor,
    number_of_atoms: int,
    idx_j: Tensor,
) -> Tensor:
    # Compute radial contributions
    avf_s = gs.unsqueeze(-1) * a_j.unsqueeze(1)  # Shape: (number_of_pairs, G, F_atom)
    avf_s = avf_s.sum(dim=1)  # Sum over G

    # Aggregate per atom; F_atom is the atom feature dimension (a_j.shape[-1])
    radial_contributions = torch.zeros(
        (number_of_atoms, F_atom),
        device=avf_s.device,
        dtype=avf_s.dtype,
    )
    radial_contributions.index_add_(0, idx_j, avf_s)

    return radial_contributions

is changed to

def calculate_radial_contributions(
    self,
    gs: Tensor,
    a_j: Tensor,
    number_of_atoms: int,
    idx_j: Tensor,
) -> Tensor:
    # Map gs to match the dimension of a_j
    mapped_gs = self.gs_to_fatom(gs)  # Linear layer mapping: (number_of_pairs, G) -> (number_of_pairs, F_atom)

    # Element-wise multiplication without expanding dimensions
    avf_s = a_j * mapped_gs  # Shape: (number_of_pairs, F_atom)

    # Aggregate per atom; F_atom is the atom feature dimension (a_j.shape[-1])
    radial_contributions = torch.zeros(
        (number_of_atoms, F_atom),
        device=avf_s.device,
        dtype=avf_s.dtype,
    )
    radial_contributions.index_add_(0, idx_j, avf_s)

    return radial_contributions
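The difference between the two paths can be sketched as follows. This is a minimal, self-contained illustration with hypothetical sizes (none of the numbers, and the standalone `gs_to_fatom` layer, are taken from the PR); note that the optimized path is not numerically identical to the original, since it uses a learned linear mapping, which the PR states does not affect model performance.

```python
import torch

# Hypothetical sizes for illustration only (not from the PR).
number_of_pairs, G, F_atom, number_of_atoms = 1000, 16, 32, 200

gs = torch.randn(number_of_pairs, G)
a_j = torch.randn(number_of_pairs, F_atom)
idx_j = torch.randint(0, number_of_atoms, (number_of_pairs,))

# Original path: broadcasting creates a (number_of_pairs, G, F_atom)
# intermediate tensor before the sum over G.
avf_s_old = (gs.unsqueeze(-1) * a_j.unsqueeze(1)).sum(dim=1)

# Optimized path: map G -> F_atom with a linear layer first, so the
# largest intermediate is only (number_of_pairs, F_atom).
gs_to_fatom = torch.nn.Linear(G, F_atom)
avf_s_new = a_j * gs_to_fatom(gs)

# Per-atom aggregation is unchanged in both versions.
radial_contributions = torch.zeros(number_of_atoms, F_atom)
radial_contributions.index_add_(0, idx_j, avf_s_new)
print(radial_contributions.shape)  # torch.Size([200, 32])
```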

Key changes

  • Modified calculate_radial_contributions to compute radial contributions without creating the large (number_of_pairs, G, F_atom) intermediate tensor.
  • Replaced the broadcasted outer product with a more memory-efficient approach: a linear layer maps the radial features to the atom feature dimension, followed by an element-wise product.
  • Updated the calculation of self.number_of_input_features to reflect the correct dimensions.
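Why the eliminated intermediate matters: its element count scales with the number of radial features G, so removing it shrinks the peak allocation by roughly a factor of G. A back-of-the-envelope estimate with hypothetical sizes (none of these numbers come from the PR):

```python
# Memory estimate for the eliminated intermediate tensor,
# using hypothetical sizes (not from the PR).
number_of_pairs, G, F_atom = 500_000, 64, 128
bytes_per_fp32 = 4

old_intermediate = number_of_pairs * G * F_atom * bytes_per_fp32  # (pairs, G, F_atom)
new_intermediate = number_of_pairs * F_atom * bytes_per_fp32      # (pairs, F_atom)

print(old_intermediate / 2**30)  # ~15.26 GiB
print(new_intermediate / 2**30)  # ~0.24 GiB, a factor of G smaller
```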

Associated Issue(s)

Pull Request Checklist

  • Issue(s) raised/addressed and linked
  • Includes appropriate unit test(s)
  • Appropriate docstring(s) added/updated
  • Appropriate .rst doc file(s) added/updated
  • PR is ready for review

…ze (nr_of_pairs, F, G), where F is the number of atom features and G the number of radial features. The generation of this internal representation can be avoided, which is addressed in this PR.
@wiederm wiederm self-assigned this Nov 9, 2024
@wiederm wiederm merged commit 426171a into main Nov 9, 2024
5 of 6 checks passed
@wiederm wiederm deleted the dev-memory-aimnet2 branch November 9, 2024 22:37
@codecov-commenter commented Nov 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.54%. Comparing base (cf5b7c3) to head (bac77c8).
Report is 5 commits behind head on main.

