Skip to content

Commit

Permalink
Update block.py added num_chunks for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
einrone authored Nov 7, 2024
1 parent e566510 commit 438899b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#

import logging
import os
from abc import ABC
from abc import abstractmethod
from typing import Optional
Expand All @@ -32,6 +33,7 @@

LOGGER = logging.getLogger(__name__)

NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))

class BaseBlock(nn.Module, ABC):
"""Base class for network blocks."""
Expand Down Expand Up @@ -492,7 +494,7 @@ def forward(
query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

# TODO: remove magic number
num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference
num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE # reduce memory for inference

if num_chunks > 1:
edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
Expand Down

0 comments on commit 438899b

Please sign in to comment.