Skip to content

Commit

Permalink
code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Oct 13, 2023
1 parent b8f0346 commit 253e7a6
Show file tree
Hide file tree
Showing 11 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion zeta/nn/architecture/attn_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def groupby_prefix_and_trim(prefix, d):
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs

Expand Down
2 changes: 1 addition & 1 deletion zeta/nn/architecture/auto_regressive_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def generate(
return out
else:
for _ in range(seq_len):
x = out[:, -self.max_seq_len :]
x = out[:, -self.max_seq_len:]

logits = self.net(x, **kwargs)[:, -1]

Expand Down
6 changes: 3 additions & 3 deletions zeta/nn/architecture/hierarchical_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def prophet(self, h, ids):
prophet_logits = rearrange(prophet_logits, "b n (c d) -> (b c) d n", c=c)

prophet_ids = F.pad(ids, (-1, c), value=self.ignore_index)
prophet_ids = tuple(prophet_ids[:, i : (seq_len + i)] for i in range(c))
prophet_ids = tuple(prophet_ids[:, i: (seq_len + i)] for i in range(c))
prophet_ids = torch.stack(prophet_ids, dim=1)
prophet_ids = rearrange(prophet_ids, "b c n -> (b c) n")

Expand All @@ -255,7 +255,7 @@ def recon(self, h, ids):
recon_logits = rearrange(recon_logits, "b n (c d) -> (b c) d n", c=c)

recon_ids = F.pad(ids, (c - 1, 0), value=self.ignore_index)
recon_ids = tuple(recon_ids[:, i : (seq_len + i)] for i in range(c))
recon_ids = tuple(recon_ids[:, i: (seq_len + i)] for i in range(c))
recon_ids = torch.stack(recon_ids, dim=1)
recon_ids = rearrange(recon_ids, "b c n -> (b c) n")

Expand Down Expand Up @@ -688,7 +688,7 @@ def generate(self, prompt, seq_len, temperature=1.0, filter_thres=0.9, **kwargs)
out = prompt

for _ in range(seq_len):
logits = self.forward(out[:, -self.seq_len :], **kwargs)[:, -1]
logits = self.forward(out[:, -self.seq_len:], **kwargs)[:, -1]
filtered_logits = top_k(logits, thres=filter_thres)
sample = gumbel_sample(filtered_logits, temperature=temperature)
sample = rearrange(sample, "b -> b 1")
Expand Down
2 changes: 1 addition & 1 deletion zeta/nn/architecture/local_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def generate(self, prime, seq_len, temperature=1.0, filter_thres=0.9, **kwargs):
out = prime

for _ in range(seq_len):
logits = self.forward(out[:, -self.max_seq_len :], **kwargs)
logits = self.forward(out[:, -self.max_seq_len:], **kwargs)
filtered_logits = top_k(logits[:, -1], thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sampled = torch.multinomial(probs, 1)
Expand Down
6 changes: 3 additions & 3 deletions zeta/nn/architecture/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def groupby_prefix_and_trim(prefix, d):
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs

Expand Down Expand Up @@ -1670,7 +1670,7 @@ def forward(
mask = pad_at_dim(mask, (num_mem, 0), dim=-1, value=True)

if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down:]
mems = [*mems_r, *mems_l]

if return_hiddens:
Expand Down Expand Up @@ -1709,7 +1709,7 @@ def forward(
else hiddens
)
new_mems = list(
map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)
)
return out, new_mems

Expand Down
2 changes: 1 addition & 1 deletion zeta/nn/attention/multiquery_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def flash_attn_fn(

if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
query_padding_mask = key_padding_mask[:, -query.size(1) :]
query_padding_mask = key_padding_mask[:, -query.size(1):]

query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, query_padding_mask
Expand Down
2 changes: 1 addition & 1 deletion zeta/ops/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def multi_dim_cat(split_tensors: List[Tensor], num_splits: List[int]) -> Tensor:
for dim, split in reversed(list(enumerate(num_splits))):
if split > 0:
merged_tensor = [
torch.cat(merged_tensor[i : i + split], dim=dim)
torch.cat(merged_tensor[i: i + split], dim=dim)
for i in range(0, len(merged_tensor), split)
]
assert len(merged_tensor) == 1
Expand Down
4 changes: 1 addition & 3 deletions zeta/rl/ppo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
Expand Down Expand Up @@ -56,9 +57,6 @@ def ppo_step(
optimizer_policy.step()


import torch
import numpy as np

# Define the environment parameters
num_inputs = 4
num_outputs = 2
Expand Down
2 changes: 1 addition & 1 deletion zeta/training/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def group_texts(examples):
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
Expand Down
4 changes: 2 additions & 2 deletions zeta/utils/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def groupby_prefix_and_trim(prefix, d):
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs

Expand Down Expand Up @@ -748,7 +748,7 @@ def look_around(x, backward=1, forward=0, pad_value=-1, dim=2):
padded_x = F.pad(x, (*dims, backward, forward), value=pad_value)

tensors = [
padded_x[:, ind : (ind + t), ...] for ind in range(forward + backward + 1)
padded_x[:, ind: (ind + t), ...] for ind in range(forward + backward + 1)
]
return torch.cat(tensors, dim=dim)

Expand Down
8 changes: 4 additions & 4 deletions zeta/zeta.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
zeta = """
_____________________________________
\____ /\_ _____/\__ ___/ _ \
/ / | __)_ | | / /_\ \
_____________________________________
\____ /\_ _____/\__ ___/ _ \
/ / | __)_ | | / /_\ \
/ /_ | \ | |/ | \
/_______ \/_______ / |____|\____|__ /
\/ \/ \/
\/ \/ \/
"""

0 comments on commit 253e7a6

Please sign in to comment.