
Commit

longnet class
Kye committed Jul 12, 2023
1 parent 4d3ada3 commit c72e361
Showing 1 changed file with 24 additions and 17 deletions.
LongNet/model.py (41 changes: 24 additions & 17 deletions)
@@ -71,21 +71,28 @@ def forward(self, text_tokens, **kwargs):
        return self.decoder(model_input, passed_x=model_input)[0]


-# class LongNet(Module):
-#     def __init__(self):
-#         super().__init__()
-
-#         self.model = LongNet(
-#             num_tokens = 16000,       # number of tokens
-#             dim = (512, 256),         # transformer model dimension (512 for coarsest, 256 for fine in this example)
-#             max_seq_len = (1024, 4),  # sequence length for global and then local. this can be more than 2
-#             depth = (6, 4),           # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
-#             dim_head = 64,            # dimension per head
-#             heads = 8,                # number of attention heads
-#             flash_attn = True         # use flash attention
-#         )
-
-#     def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
-#         sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
-#         return sampled
+class DilatedLongNet(Module):
+    def __init__(self):
+        super().__init__()
+
+        self.model = LongNet(
+            num_tokens = 16000,          # number of tokens
+            dim = (512, 256),            # transformer model dimension (512 for coarsest, 256 for fine in this example)
+            max_seq_len = (1024, 4),     # sequence length for global and then local. this can be more than 2
+            depth = (6, 4),              # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
+            dim_head = 64,               # dimension per head
+            heads = 8,                   # number of attention heads
+            flash_attn = True,           # use flash attention
+            dilation_rate = 1,           # dilation rate for DilatedAttention
+            segment_size = 0,            # segment size for DilatedAttention
+            casual = False,              # whether to use causal attention for DilatedAttention
+            use_xpos = False,            # whether to use absolute positional embeddings for DilatedAttention
+            use_rel_pos_bias = False,    # whether to use relative positional bias for DilatedAttention
+            distributed = False          # whether to distribute attention for DilatedAttention
+        )
+
+    def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
+        sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
+        return sampled


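For reference, a minimal usage sketch of the DilatedLongNet wrapper added in this commit. It is not part of the diff: the import path is assumed from the changed file (LongNet/model.py), Module is assumed to be torch.nn.Module so that calling the instance dispatches to forward(), and the prompt shape and sampling values are illustrative only.

# Hypothetical usage sketch, not part of the commit.
import torch

from LongNet.model import DilatedLongNet  # import path assumed from the diff header

model = DilatedLongNet()

# Prompt token ids; vocabulary size matches num_tokens = 16000 in the constructor.
text_tokens = torch.randint(0, 16000, (1, 512))

# forward() delegates to self.model.generate(temperature=..., filter_thres=...).
sampled = model(text_tokens, temperature=0.8, filter_thres=0.9)

Note that, as committed, forward() does not pass text_tokens or **kwargs through to generate(), so the prompt does not influence sampling, and temperature/filter_thres are annotated as int even though sampling typically uses float values.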