From 2fba0fd54c612895d9a4e4d92decc1ec0044fb45 Mon Sep 17 00:00:00 2001
From: Cathal OBrien
Date: Fri, 25 Oct 2024 08:52:42 +0000
Subject: [PATCH] option to check seq len at runtime

---
 src/anemoi/models/layers/attention.py | 63 +++++++++++++++++++--------
 1 file changed, 46 insertions(+), 17 deletions(-)

diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index b5f78df..b7c15b9 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -72,25 +72,43 @@ def __init__(
         self.attention = attn_func
 
         self.resolution = resolution
+
+        self.use_document_masking = False
+        self.is_attn_compiled = False
+        self.compile_at_runtime = False
+        self.use_flex_attn = False
 
         if _FLEX_ATTENTION_AVAILABLE and (os.environ.get("FLEX_ATTN", "") != "" ):
-            LOGGER.info("Using Flex attn")
-            #LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")
+            self.use_flex_attn = True
 
-            if window_size != None:
-                def sliding_window(b, h, q_idx, kv_idx):
-                    return abs(q_idx - kv_idx) <= window_size
-
-                seq_len=calculate_seq_len(resolution=self.resolution)
-                LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
-
-                # B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
-                #TODO check if B != 1, does it have to be set?
-                self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,_compile=True)
-                self.attention = functools.partial(flex_attention, block_mask=self.block_mask) #Cache the block mask (attn blog post)
-            else:
-                self.attention = flex_attention
-            self.attention = compile(self.attention) #Must be compiled, otherwise entire seq_len^2 aray is materilised in memory -> OOM
+        if self.use_flex_attn:
+            if not self.compile_at_runtime:
+                LOGGER.info("Using Flex attn")
+                # LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")
+
+                if self.use_document_masking:  # placeholder: needs a per-point document_id tensor, not defined yet
+
+                    def document_causal_mask(b, h, q_idx, kv_idx):
+                        causal_mask = q_idx >= kv_idx
+                        document_mask = document_id[q_idx] == document_id[kv_idx]
+                        return causal_mask & document_mask
+
+
+                elif window_size is not None:
+                    def sliding_window(b, h, q_idx, kv_idx):
+                        return abs(q_idx - kv_idx) <= window_size
+
+                    seq_len = calculate_seq_len(resolution=self.resolution)
+                    LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
+
+                    # B and H can be None here because the mask is uniform across batch and heads, so the block mask can just be broadcast to these dims
+                    # TODO: if B != 1, does it have to be set?
+                    self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
+                    self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # cache the block mask (attn blog post)
+                else:
+                    self.attention = flex_attention
+                self.attention = compile(self.attention)  # must be compiled, otherwise the entire seq_len^2 array is materialised in memory -> OOM
+                self.is_attn_compiled = True
 
             if (self.is_causal):
                 LOGGER.error("Causal not yet supported when using flex_attn (but this would be an easy add). Please rerun with 'is_causal = False'")
@@ -121,7 +139,18 @@ def forward(
             )
             for t in (query, key, value)
         )
-
+
+
+        if self.use_flex_attn and not self.is_attn_compiled:
+            LOGGER.info("Compiling Flex Attn at runtime")
+            seq_len = x.shape[0]
+            def sliding_window(b, h, q_idx, kv_idx):
+                return abs(q_idx - kv_idx) <= self.window_size[0]
+            self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
+            self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # cache the block mask (attn blog post)
+            self.attention = compile(self.attention)
+            self.is_attn_compiled = True
+
         query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
         key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
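
Note on the runtime path: the block added to `forward` boils down to the pattern below: read the sequence length off the first input batch, build the sliding-window block mask once, cache it, and compile `flex_attention`. This is a minimal standalone sketch assuming PyTorch >= 2.5's `torch.nn.attention.flex_attention` module; the helper name `make_sliding_window_attention` and the example shapes are illustrative, not part of the patch.

    import functools

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention


    def make_sliding_window_attention(window_size: int, seq_len: int, device: str = "cuda"):
        """Return a compiled flex_attention callable with a cached sliding-window block mask."""

        def sliding_window(b, h, q_idx, kv_idx):
            # Each point attends only to points within +/- window_size of itself.
            return abs(q_idx - kv_idx) <= window_size

        # B=None and H=None: the mask is identical across batch and heads,
        # so the block mask broadcasts over those dimensions.
        block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)

        # Compile so the block sparsity is actually exploited; otherwise the full
        # seq_len^2 mask is materialised in memory (the OOM noted in the patch).
        return torch.compile(functools.partial(flex_attention, block_mask=block_mask))


    # Hypothetical usage: batch 1, 16 heads, 4096 points, head_dim 64.
    q = k = v = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
    attention = make_sliding_window_attention(window_size=512, seq_len=q.shape[2])
    out = attention(q, k, v)  # same shape as q

Deferring compilation to the first `forward` call (the `compile_at_runtime` path) trades the `calculate_seq_len(resolution=...)` prediction in `__init__` for the observed `x.shape[0]`, so the mask always matches the sequence length that actually arrives.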
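
Note on document masking: the `use_document_masking` branch is a stub in this patch. It defines `document_causal_mask` but never builds a block mask from it, and `document_id` is undefined (hence the placeholder comment and the hard-coded `False`). A complete version would follow the same recipe as the sliding-window case, roughly as below; the `document_id` construction here is an assumption for illustration, in the style of the PyTorch FlexAttention examples.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    # Assumed: three "documents" of lengths 128, 64 and 64 packed into one
    # sequence, giving one document id per token/grid point.
    doc_lengths = torch.tensor([128, 64, 64], device="cuda")
    document_id = torch.repeat_interleave(torch.arange(len(doc_lengths), device="cuda"), doc_lengths)
    seq_len = int(document_id.shape[0])


    def document_causal_mask(b, h, q_idx, kv_idx):
        # Attend causally, and only within the same document.
        causal_mask = q_idx >= kv_idx
        document_mask = document_id[q_idx] == document_id[kv_idx]
        return causal_mask & document_mask


    block_mask = create_block_mask(document_causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)

The mask_mod closes over `document_id`, exactly as the stub in `__init__` would have to; whatever eventually supplies that tensor must do so before the mask is created.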