Skip to content

Commit

Permalink
init everything
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Aug 22, 2023
1 parent 2466f75 commit b035183
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 21 deletions.
12 changes: 6 additions & 6 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ nav:
- Checklist: "checklist.md"
- Hiring: "hiring.md"
- Zeta:
- Overview: "zeta/index.md"
- nn:
- modules:
xpos: "zeta/workers/index.md"
- attention:
- FlashAttention: "zeta/nn/attention/flash_attention.md"
- Overview: "zeta/index.md"
- nn:
- modules:
xpos: "zeta/workers/index.md"
- attention:
- FlashAttention: "zeta/nn/attention/flash_attention.md"
- Examples:
- Overview: "examples/index.md"
- Agents:
Expand Down
16 changes: 4 additions & 12 deletions zeta/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
# Copyright (c) 2022 Agora
# Licensed under The MIT License [see LICENSE for details]









# ENCODER/ DECODER


######### Attention
from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn import *
from zeta.models import *
from zeta.training import *
6 changes: 3 additions & 3 deletions zeta/nn/architecture/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,10 +773,10 @@ def forward(
max_attend_past_mask = dist > self.max_attend_past
masks.append(max_attend_past_mask)

if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim = -1)
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: # noqa: F821
top, _ = dots.topk(self.sparse_topk, dim = -1) # noqa: F821
vk = rearrange(top[..., -1], '... -> ... 1')
sparse_topk_mask = dots < vk
sparse_topk_mask = dots < vk # noqa: F821
masks.append(sparse_topk_mask)

if len(masks) > 0:
Expand Down

0 comments on commit b035183

Please sign in to comment.