option to check seq len at runtime
cathalobrien committed Oct 25, 2024
1 parent 845bfc7 commit 2fba0fd
Showing 1 changed file with 46 additions and 17 deletions.
src/anemoi/models/layers/attention.py
@@ -72,25 +72,43 @@ def __init__(
self.attention = attn_func

self.resolution = resolution

self.use_document_masking = False
self.is_attn_compiled = False
self.compile_at_runtime = False
self.use_flex_Attn = False

if _FLEX_ATTENTION_AVAILABLE and os.environ.get("FLEX_ATTN", "") != "":
LOGGER.info("Using Flex attn")
#LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")
self.use_flex_Attn = True

if window_size is not None:
def sliding_window(b, h, q_idx, kv_idx):
return abs(q_idx - kv_idx) <= window_size

seq_len = calculate_seq_len(resolution=self.resolution)
LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")

# B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
#TODO check if B != 1, does it have to be set?
self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # Cache the block mask (see the flex attention blog post)
else:
self.attention = flex_attention
self.attention = compile(self.attention)  # Must be compiled, otherwise the entire seq_len^2 array is materialised in memory -> OOM
if self.use_flex_Attn:
if not self.compile_at_runtime:
LOGGER.info("Using Flex attn")
#LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")

if self.use_document_masking:
# NOTE: placeholder branch - document_id is not defined in this scope and no block mask is built from this mask_mod yet
def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = document_id[q_idx] == document_id[kv_idx]
return causal_mask & document_mask

elif window_size is not None:
def sliding_window(b, h, q_idx, kv_idx):
return abs(q_idx - kv_idx) <= window_size

seq_len = calculate_seq_len(resolution=self.resolution)
LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")

# B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
#TODO check if B != 1, does it have to be set?
self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # Cache the block mask (see the flex attention blog post)
else:
self.attention = flex_attention
self.attention = compile(self.attention)  # Must be compiled, otherwise the entire seq_len^2 array is materialised in memory -> OOM
self.is_attn_compiled = True

if self.is_causal:
LOGGER.error("Causal not yet supported when using flex_attn (but this would be an easy add). Please rerun with 'is_causal = False'")
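
Aside (not part of the commit): a minimal, self-contained sketch of the pattern used above, building a sliding-window block mask once and reusing it with a compiled flex_attention. It assumes PyTorch >= 2.5 with torch.nn.attention.flex_attention available and a CUDA device; the sequence length, window and tensor shapes are illustrative values, not the ones anemoi uses.

import functools
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 16, 4096, 64  # batch, heads, sequence length (grid points), head dim -- example values
window = 512                  # hypothetical window size in tokens

def sliding_window(b, h, q_idx, kv_idx):
    # keep only query/key pairs that are within `window` positions of each other
    return abs(q_idx - kv_idx) <= window

# B=None, H=None: the mask is identical for every batch and head, so it is broadcast to those dims
block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=S, KV_LEN=S, _compile=True)

# cache the block mask in a partial and compile, so the full S x S mask is never materialised
attention = torch.compile(functools.partial(flex_attention, block_mask=block_mask))

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention(q, k, v)  # shape (B, H, S, D)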
@@ -121,7 +139,18 @@ def forward(
)
for t in (query, key, value)
)
if self.use_flex_Attn and not self.is_attn_compiled:
LOGGER.info("Compiling Flex Attn at runtime")
seq_len = x.shape[0]
def sliding_window(b, h, q_idx, kv_idx):
return abs(q_idx - kv_idx) <= self.window_size[0]
self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # Cache the block mask (see the flex attention blog post)
self.attention = compile(self.attention)  # Must be compiled, otherwise the entire seq_len^2 array is materialised in memory -> OOM
self.is_attn_compiled = True

query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
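Aside (not from the commit): the forward-pass branch above amounts to a lazy-initialisation pattern, deferring block-mask creation and compilation until the first call, when the true sequence length can be read off the input. Below is a stripped-down sketch of that pattern, with hypothetical names (LazyFlexAttention is not an anemoi class) and the same PyTorch >= 2.5 / CUDA assumptions as the earlier example.

import functools
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

class LazyFlexAttention(torch.nn.Module):
    """Builds and compiles a sliding-window flex attention on first use."""

    def __init__(self, window_size: int):
        super().__init__()
        self.window_size = window_size
        self.attention = None  # set on the first forward pass

    def forward(self, q, k, v):  # q, k, v: (B, H, S, D)
        if self.attention is None:
            seq_len = q.shape[-2]  # sequence length checked at runtime

            def sliding_window(b, h, q_idx, kv_idx):
                return abs(q_idx - kv_idx) <= self.window_size

            block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
            self.attention = torch.compile(functools.partial(flex_attention, block_mask=block_mask))
        return self.attention(q, k, v)

As in the diff (via is_attn_compiled), the compiled attention is cached after the first call, so this sketch assumes the sequence length stays fixed across subsequent forward passes.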
