automatically use SyncBatchNorm if doing distributed training
lucidrains committed Nov 1, 2023
1 parent dce5709 commit 243151b
Showing 2 changed files with 12 additions and 3 deletions.
enformer_pytorch/modeling_enformer.py (11 additions, 2 deletions)
@@ -4,6 +4,7 @@
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.utils.checkpoint import checkpoint_sequential

 from einops import rearrange, reduce
@@ -53,6 +54,12 @@ def _round(x):
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))

+# maybe sync batchnorm, for distributed training
+
+def MaybeSyncBatchnorm(is_distributed = None):
+    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
+    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
+
 # losses and metrics

 def poisson_loss(pred, target):
@@ -204,9 +211,11 @@ def forward(self, x):

         return x[:, -trim:trim]

-def ConvBlock(dim, dim_out = None, kernel_size = 1):
+def ConvBlock(dim, dim_out = None, kernel_size = 1, is_distributed = None):
+    batchnorm_klass = MaybeSyncBatchnorm(is_distributed = is_distributed)
+
     return nn.Sequential(
-        nn.BatchNorm1d(dim),
+        batchnorm_klass(dim),
         GELU(),
         nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding = kernel_size // 2)
     )
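
As a quick illustration (not part of the commit), here is a minimal usage sketch of the new is_distributed flag, assuming this version of the package is installed and the script runs as a single, non-distributed process; the tensor sizes (768 channels, sequence length 128) are arbitrary example values.

import torch
from torch import nn

from enformer_pytorch.modeling_enformer import ConvBlock, MaybeSyncBatchnorm

# single process, no process group initialized -> plain BatchNorm1d is chosen
assert MaybeSyncBatchnorm() is nn.BatchNorm1d

# forcing the flag returns the synchronized variant without touching torch.distributed
assert MaybeSyncBatchnorm(is_distributed = True) is nn.SyncBatchNorm

# ConvBlock threads the same flag through to its first layer
block = ConvBlock(dim = 768, kernel_size = 5, is_distributed = True)
print(type(block[0]))         # nn.SyncBatchNorm

# to keep this runnable without a process group, do the forward pass
# with the default (non-distributed) block
block = ConvBlock(dim = 768, kernel_size = 5)
x = torch.randn(2, 768, 128)  # (batch, channels, sequence length)
print(block(x).shape)         # torch.Size([2, 768, 128])
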
setup.py (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.8.5',
+  version = '0.8.6',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
