
Feature Request: Enhance Attention Mechanism for Multi-GPU Support #24

Open
viai957 opened this issue May 27, 2024 · 0 comments

Is your feature request related to a problem? Please describe.
Yes. The current implementation of the DilatedAttention and FlashAttention modules in the Zeta repository does not support multi-GPU configurations effectively; in particular, it lacks model-parallelism and data-parallelism capabilities. FlashAttention is optimized for A100 GPUs, but I am equipped with 8 A10 GPUs and would like to leverage all available resources efficiently. This limitation restricts the scalability and speed of my deep learning tasks, particularly large-scale sequence processing and attention computation.

Describe the solution you'd like
I propose enhancing the DilatedAttention and FlashAttention classes to support both model parallelism and data parallelism (a rough sketch of the data-parallel piece follows the list below). This update should include:

  • Automatic detection and utilization of multiple GPU architectures (beyond A100).
  • Implementation of data parallelism to distribute data across multiple GPUs, improving throughput and efficiency.
  • Integration of model parallelism where the model can be split across multiple GPUs to manage large models or balance load more effectively.
  • Support for distributed computing across multiple nodes, starting with a straightforward implementation and gradually scaling to more complex distributed setups.
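
To make the request concrete, here is a minimal sketch of the data-parallel piece, assuming a plain PyTorch setup launched with torchrun. `SimpleAttention` is a placeholder standing in for Zeta's DilatedAttention, and `supports_flash_attention` is a hypothetical helper; the compute-capability threshold in it is an assumption, not an exact FlashAttention requirement.

```python
# Rough sketch only: wraps a stand-in attention module in DistributedDataParallel
# and checks whether the local GPU is Ampere-class before enabling FlashAttention-style paths.
# SimpleAttention is a placeholder, not the actual Zeta DilatedAttention class.
import os

import torch
import torch.distributed as dist
import torch.nn as nn


class SimpleAttention(nn.Module):
    """Placeholder attention block standing in for Zeta's DilatedAttention."""

    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return out


def supports_flash_attention() -> bool:
    # FlashAttention kernels generally need compute capability >= 8.0 (A100/H100).
    # A10 GPUs report (8, 6), so a production check may need to be finer-grained.
    major, _ = torch.cuda.get_device_capability()
    return major >= 8


def main() -> None:
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE in the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # One replica per GPU; gradients are synchronized over NCCL.
    model = SimpleAttention(dim=512).cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    x = torch.randn(4, 128, 512, device=f"cuda:{local_rank}")
    out = model(x)
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}, "
          f"flash-capable={supports_flash_attention()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with something like `torchrun --nproc_per_node=8 ddp_attention_sketch.py`, each of the 8 A10s would hold a replica and gradients would be all-reduced over NCCL; the same pattern could be offered as a built-in option on the attention classes.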

Describe alternatives you've considered
An alternative could be manual partitioning of tasks and managing CUDA devices at the application level, but this approach is neither as efficient nor as scalable (a minimal sketch of what it looks like is included below). Utilizing existing libraries like NVIDIA's NCCL for communication in parallel processing might be considered if native support in the framework proves too complex to implement in the initial stages.
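
For reference, a minimal sketch of the manual alternative, assuming a hand-written two-GPU pipeline split; `TwoStageAttention` and its dimensions are illustrative and not taken from the Zeta codebase.

```python
# Rough sketch of the manual alternative: pipeline-style model parallelism done by hand,
# placing each attention block on a different GPU and shuttling activations between them.
# Requires at least two visible GPUs; names and sizes are illustrative only.
import torch
import torch.nn as nn


class TwoStageAttention(nn.Module):
    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        # Stage 0 lives on cuda:0, stage 1 on cuda:1.
        self.stage0 = nn.MultiheadAttention(dim, heads, batch_first=True).to("cuda:0")
        self.stage1 = nn.MultiheadAttention(dim, heads, batch_first=True).to("cuda:1")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to("cuda:0")
        x, _ = self.stage0(x, x, x)
        # The device hand-off must be managed explicitly, which is the main drawback
        # of this approach compared with framework-level parallelism.
        x = x.to("cuda:1")
        x, _ = self.stage1(x, x, x)
        return x


if __name__ == "__main__":
    model = TwoStageAttention()
    out = model(torch.randn(2, 64, 512))
    print(out.shape, out.device)  # torch.Size([2, 64, 512]) cuda:1
```

Every device hand-off has to be written and synchronized by the user, which is why built-in data/model parallelism in the attention classes would be preferable.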
