xla_sync_bn.py (forked from ronghanghu/moco_v3_tpu)
import torch
from torch import nn

from distributed import get_world_size, xla_all_reduce_sum_with_backward

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


class XLASyncBNTrainModeOnly(nn.Module):
    """Synchronized batch norm over all XLA devices (training mode only).

    Batch statistics are averaged across devices via an all-reduce that also
    propagates gradients; no running statistics are tracked.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable affine parameters (gamma and beta).
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, batch):
        assert isinstance(batch, torch.Tensor) and batch.ndim == 2
        # Per-device statistics over the local batch dimension.
        local_mean = torch.mean(batch, dim=0)
        local_sqr_mean = torch.mean(batch * batch, dim=0)

        # Average the local statistics across all devices; the all-reduce
        # is differentiable, so gradients flow back through it.
        scale = 1.0 / get_world_size()
        mean = xla_all_reduce_sum_with_backward(local_mean) * scale
        sqr_mean = xla_all_reduce_sum_with_backward(local_sqr_mean) * scale

        # Global variance via E[x^2] - E[x]^2, then normalize and apply
        # the affine transform.
        var = sqr_mean - mean.pow(2)
        batch = (batch - mean) / torch.sqrt(var + self.eps)
        batch = batch * self.weight + self.bias
        return batch

    def extra_repr(self) -> str:
        return "dim={}, eps={}".format(self.dim, self.eps)