Merge pull request #2 from kyegomez/master
Catching up 20240209 1710
evelynmitchell authored Feb 10, 2024
2 parents 40a1a19 + 02f8569 commit 7701213
Showing 4 changed files with 26 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.0.7"
+version = "2.0.8"
 description = "Transformers at zeta scales"
 authors = ["Zeta Team <[email protected]>"]
 license = "MIT"
9 changes: 6 additions & 3 deletions zeta/nn/attention/multiquery_attention.py
@@ -300,9 +300,12 @@ def flash_attn_fn(
         key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
     query_padding_mask = key_padding_mask[:, -query.size(1) :]

-    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = (
-        bert_padding.unpad_input(query, query_padding_mask)
-    )
+    (
+        query_unpad,
+        indices_q,
+        cu_seqlens_q,
+        max_seqlen_q,
+    ) = bert_padding.unpad_input(query, query_padding_mask)
     query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=heads)

     key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
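The reformatted block above unpacks the four values returned by bert_padding.unpad_input, as the variable names indicate: the flattened non-padded tokens, their flat indices, the cumulative sequence lengths, and the longest sequence in the batch. The sketch below reproduces that unpadding step in plain PyTorch so it runs without the flash-attn dependency; unpad_input_sketch and the toy shapes are illustrative stand-ins, not zeta or flash-attn code.

import torch
import torch.nn.functional as F
from einops import rearrange


def unpad_input_sketch(x: torch.Tensor, padding_mask: torch.Tensor):
    """Drop padded positions from (batch, seqlen, dim) to (nnz, dim), mimicking bert_padding.unpad_input."""
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                     # real tokens per sequence
    indices = padding_mask.flatten().nonzero(as_tuple=False).flatten()        # flat positions to keep
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).to(torch.int32)  # [0, len_0, len_0 + len_1, ...]
    max_seqlen = int(seqlens.max())
    x_unpad = rearrange(x, "b s d -> (b s) d")[indices]                       # (nnz, dim)
    return x_unpad, indices, cu_seqlens, max_seqlen


heads, head_dim = 4, 8
query = torch.randn(2, 5, heads * head_dim)
key_padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1]], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1) :]

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input_sketch(query, query_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=heads)
print(query_unpad.shape, cu_seqlens_q.tolist(), max_seqlen_q)  # torch.Size([8, 4, 8]) [0, 3, 8] 5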
8 changes: 4 additions & 4 deletions zeta/nn/modules/sig_lip.py
@@ -4,7 +4,6 @@

 try:
     import torch.distributed.nn
-    from torch import distributed as dist

     has_distributed = True
 except ImportError:
@@ -257,9 +256,10 @@ def forward(
                             logit_bias,
                             negative_only=True,
                         )
-                    text_features_to_left, text_features_to_right = (
-                        text_features_recv
-                    )
+                    (
+                        text_features_to_left,
+                        text_features_to_right,
+                    ) = text_features_recv

                 if remainder:
                     text_features_recv = neighbour_exchange_with_grad(
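The first sig_lip.py hunk removes an unused import from the optional-dependency guard at the top of the file: torch.distributed.nn is imported inside a try block and a has_distributed flag records whether it is available, so the loss can still run in a single process. Below is a minimal sketch of that guard with a hypothetical world_size helper added for illustration; the helper, and the dist alias it needs (which this commit drops from sig_lip.py as unused), is not part of zeta.

try:
    import torch.distributed.nn            # differentiable collectives; optional at runtime
    from torch import distributed as dist  # kept here only because the helper below uses it
    has_distributed = True
except ImportError:
    has_distributed = False


def world_size() -> int:
    """Process-group size when distributed training is initialized, else 1 (single-process fallback)."""
    if has_distributed and dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


print(world_size())  # 1 when run outside a torch.distributed job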
24 changes: 15 additions & 9 deletions zeta/nn/modules/xmoe/moe_layer.py
@@ -170,9 +170,9 @@ def forward(
                 device=input.device,
             )
             if input_padding_mask is not None:
-                padded_input_padding_mask[: input_shape[0], :] = (
-                    input_padding_mask
-                )
+                padded_input_padding_mask[
+                    : input_shape[0], :
+                ] = input_padding_mask
             else:
                 padded_input_padding_mask[: input_shape[0], :] = False
             input_padding_mask = padded_input_padding_mask
@@ -211,17 +211,23 @@ def forward(
                 (expected_dim,), dtype=torch.bool, device=padded_input.device
             )
             if reshaped_input_padding_mask is not None:
-                padded_input_padding_mask[: reshaped_input_shape[0]] = (
-                    reshaped_input_padding_mask
-                )
+                padded_input_padding_mask[
+                    : reshaped_input_shape[0]
+                ] = reshaped_input_padding_mask
             else:
                 padded_input_padding_mask[: reshaped_input_shape[0]] = False
             reshaped_input_padding_mask = padded_input_padding_mask

         if has_tutel:
-            l_aux, self.metadata, C, E, indices_, locations_, gates_ = (
-                self.gate(reshaped_input, reshaped_input_padding_mask)
-            )
+            (
+                l_aux,
+                self.metadata,
+                C,
+                E,
+                indices_,
+                locations_,
+                gates_,
+            ) = self.gate(reshaped_input, reshaped_input_padding_mask)
             S, M = reshaped_input.size(0), reshaped_input.size(1)

             if not hasattr(self, "_tutel_dispatcher"):
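The moe_layer.py hunks mostly reformat one padding pattern: when the local batch is smaller than the static size the MoE layer expects, the real tokens are copied into a zero-padded buffer and a boolean mask marks the extra rows as padding so the gate can ignore them (the remaining change only wraps the seven values returned by the tutel gate onto separate lines). A toy reproduction of the padding step, with illustrative shapes rather than the actual MOELayer.forward:

import torch

expected_dim = 8                                   # static number of tokens the layer expects
reshaped_input = torch.randn(5, 16)                # only 5 real tokens this step, model dim 16
reshaped_input_shape = reshaped_input.shape
reshaped_input_padding_mask = torch.zeros(5, dtype=torch.bool)  # none of the real tokens are padding

padded_input = reshaped_input.new_zeros(expected_dim, reshaped_input.size(1))
padded_input[: reshaped_input_shape[0]] = reshaped_input

padded_input_padding_mask = torch.ones(expected_dim, dtype=torch.bool)  # start with every row marked as padding
padded_input_padding_mask[
    : reshaped_input_shape[0]
] = reshaped_input_padding_mask                                         # real rows keep their original mask

print(padded_input.shape, padded_input_padding_mask.tolist())
# torch.Size([8, 16]) [False, False, False, False, False, True, True, True]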
