From b0351834b1754292430a12b79a5847b588588c0c Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 22 Aug 2023 19:10:19 -0400 Subject: [PATCH] init everything --- mkdocs.yml | 12 ++++++------ zeta/nn/__init__.py | 16 ++++------------ zeta/nn/architecture/transformer.py | 6 +++--- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 338ba994..233afa32 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,12 +75,12 @@ nav: - Checklist: "checklist.md" - Hiring: "hiring.md" - Zeta: - - Overview: "zeta/index.md" - - nn: - - modules: - xpos: "zeta/workers/index.md" - - attention: - - FlashAttention: "zeta/nn/attention/flash_attention.md" + - Overview: "zeta/index.md" + - nn: + - modules: + xpos: "zeta/workers/index.md" + - attention: + - FlashAttention: "zeta/nn/attention/flash_attention.md" - Examples: - Overview: "examples/index.md" - Agents: diff --git a/zeta/nn/__init__.py b/zeta/nn/__init__.py index 7510df1f..0ca5d9d6 100644 --- a/zeta/nn/__init__.py +++ b/zeta/nn/__init__.py @@ -1,16 +1,8 @@ # Copyright (c) 2022 Agora # Licensed under The MIT License [see LICENSE for details] - - - - - - - - -# ENCODER/ DECODER - - ######### Attention -from zeta.nn.attention.multiquery_attention import MultiQueryAttention \ No newline at end of file +from zeta.nn.attention.multiquery_attention import MultiQueryAttention +from zeta.nn import * +from zeta.models import * +from zeta.training import * diff --git a/zeta/nn/architecture/transformer.py b/zeta/nn/architecture/transformer.py index 566b34a4..7b907991 100644 --- a/zeta/nn/architecture/transformer.py +++ b/zeta/nn/architecture/transformer.py @@ -773,10 +773,10 @@ def forward( max_attend_past_mask = dist > self.max_attend_past masks.append(max_attend_past_mask) - if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: - top, _ = dots.topk(self.sparse_topk, dim = -1) + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: # noqa: F821 + top, _ = dots.topk(self.sparse_topk, dim = -1) # noqa: F821 vk = rearrange(top[..., -1], '... -> ... 1') - sparse_topk_mask = dots < vk + sparse_topk_mask = dots < vk # noqa: F821 masks.append(sparse_topk_mask) if len(masks) > 0: