From ef74613da57e041f56c3b1ba545d997f96c2f09d Mon Sep 17 00:00:00 2001
From: Kye
Date: Fri, 9 Feb 2024 12:41:24 -0800
Subject: [PATCH] [V]

---
 pyproject.toml                            |  2 +-
 zeta/nn/attention/multiquery_attention.py |  9 ++++++---
 zeta/nn/modules/sig_lip.py                |  8 ++++----
 zeta/nn/modules/xmoe/moe_layer.py         | 12 +++++++++---
 4 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3de58e21..c84c246e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.0.7"
+version = "2.0.8"
 description = "Transformers at zeta scales"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/zeta/nn/attention/multiquery_attention.py b/zeta/nn/attention/multiquery_attention.py
index 37808373..c9be52f9 100644
--- a/zeta/nn/attention/multiquery_attention.py
+++ b/zeta/nn/attention/multiquery_attention.py
@@ -300,9 +300,12 @@ def flash_attn_fn(
         key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
     query_padding_mask = key_padding_mask[:, -query.size(1) :]
 
-    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = (
-        bert_padding.unpad_input(query, query_padding_mask)
-    )
+    (
+        query_unpad,
+        indices_q,
+        cu_seqlens_q,
+        max_seqlen_q,
+    ) = bert_padding.unpad_input(query, query_padding_mask)
     query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=heads)
 
     key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
diff --git a/zeta/nn/modules/sig_lip.py b/zeta/nn/modules/sig_lip.py
index 17050242..609bf037 100644
--- a/zeta/nn/modules/sig_lip.py
+++ b/zeta/nn/modules/sig_lip.py
@@ -4,7 +4,6 @@
 
 try:
     import torch.distributed.nn
-    from torch import distributed as dist
 
     has_distributed = True
 except ImportError:
@@ -257,9 +256,10 @@ def forward(
                             logit_bias,
                             negative_only=True,
                         )
-                    text_features_to_left, text_features_to_right = (
-                        text_features_recv
-                    )
+                    (
+                        text_features_to_left,
+                        text_features_to_right,
+                    ) = text_features_recv
 
                 if remainder:
                     text_features_recv = neighbour_exchange_with_grad(
diff --git a/zeta/nn/modules/xmoe/moe_layer.py b/zeta/nn/modules/xmoe/moe_layer.py
index deed5f57..67f70cfb 100644
--- a/zeta/nn/modules/xmoe/moe_layer.py
+++ b/zeta/nn/modules/xmoe/moe_layer.py
@@ -219,9 +219,15 @@ def forward(
             reshaped_input_padding_mask = padded_input_padding_mask
 
         if has_tutel:
-            l_aux, self.metadata, C, E, indices_, locations_, gates_ = (
-                self.gate(reshaped_input, reshaped_input_padding_mask)
-            )
+            (
+                l_aux,
+                self.metadata,
+                C,
+                E,
+                indices_,
+                locations_,
+                gates_,
+            ) = self.gate(reshaped_input, reshaped_input_padding_mask)
             S, M = reshaped_input.size(0), reshaped_input.size(1)
 
             if not hasattr(self, "_tutel_dispatcher"):