This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit ce580f3
merge June 2020 updates of transformer.edge, bump support to pytorch 1.5
liuqiuhui2015 committed Jun 12, 2020
1 parent d6546ed commit ce580f3
Showing 21 changed files with 97 additions and 71 deletions.
2 changes: 1 addition & 1 deletion cnfg/ihyp.py
@@ -8,7 +8,7 @@
 from utils.fmt.base import parse_none, parse_double_value_tuple
 
-enable_residual_bias_default = not ease_optimization
+enable_prev_ln_bias_default = enable_proj_bias_default = not ease_optimization
 
 enable_ln_parameters = True
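Note: this hunk splits the old enable_residual_bias_default switch into enable_prev_ln_bias_default and enable_proj_bias_default, so the bias terms of the attention/FFN input projections and of the remaining output layers can be toggled separately. A minimal sketch of how the two flags are typically consumed (the wrapper module below is illustrative, not code from this repository; ease_optimization is assumed to be True):

import torch.nn as nn

# illustrative defaults mirroring cnfg/ihyp.py
ease_optimization = True
enable_prev_ln_bias_default = enable_proj_bias_default = not ease_optimization

class TinyAttnProjection(nn.Module):
    # hypothetical module: input projections follow enable_proj_bias, the output layer follows enable_bias
    def __init__(self, isize, hsize, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default):
        super(TinyAttnProjection, self).__init__()
        self.query_adaptor = nn.Linear(isize, hsize, bias=enable_proj_bias)
        self.outer = nn.Linear(hsize, isize, bias=enable_bias)

    def forward(self, x):
        return self.outer(self.query_adaptor(x))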
4 changes: 2 additions & 2 deletions modules/TA.py
@@ -9,9 +9,9 @@ class PositionwiseFF(PositionwiseFFBase):
     # isize: input dimension
     # hsize: hidden dimension
 
-    def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default):
+    def __init__(self, isize, norm_residual=norm_residual_default, **kwargs):
 
-        super(PositionwiseFF, self).__init__(isize, hsize, dropout, False, use_GeLU)
+        super(PositionwiseFF, self).__init__(isize, norm_residual=False, **kwargs)
 
     def forward(self, x):
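This constructor is simplified by forwarding keyword arguments to the base class instead of re-listing every default. A small illustrative sketch of that **kwargs-forwarding pattern with stand-in classes (not the repository's PositionwiseFFBase):

class FFBase:
    # stand-in base class with several tunable defaults
    def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=True, use_GeLU=True):
        self.isize = isize
        self.hsize = isize * 4 if hsize is None else hsize
        self.dropout, self.norm_residual, self.use_GeLU = dropout, norm_residual, use_GeLU

class FFNoResidualNorm(FFBase):
    # the subclass only pins norm_residual; any new keyword added to FFBase later passes through untouched
    def __init__(self, isize, **kwargs):
        super(FFNoResidualNorm, self).__init__(isize, norm_residual=False, **kwargs)

ff = FFNoResidualNorm(512, hsize=2048, dropout=0.1)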
22 changes: 11 additions & 11 deletions modules/base.py
@@ -19,7 +19,7 @@ class PositionwiseFF(nn.Module):
     # isize: input dimension
     # hsize: hidden dimension
 
-    def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, use_GeLU=use_adv_act_default, enable_bias=enable_residual_bias_default):
+    def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default):
 
         super(PositionwiseFF, self).__init__()
 
@@ -115,18 +115,18 @@ class MultiHeadAttn(nn.Module):
     # sparsenorm: using sparse normer or standard softmax
     # bind_qk: query and key can share a same linear transformation for the Reformer: The Efficient Transformer(https://arxiv.org/abs/2001.04451) paper.
 
-    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_residual_bias_default, k_rel_pos=0, sparsenorm=False, bind_qk=False, xseql=cache_len_default):
+    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=0, sparsenorm=False, bind_qk=False, xseql=cache_len_default):
 
         super(MultiHeadAttn, self).__init__()
 
         self.attn_dim = hsize // num_head
         self.hsize = self.attn_dim * num_head
         self.num_head = num_head
 
-        self.query_adaptor = Linear(isize, self.hsize, bias=enable_bias)
+        self.query_adaptor = Linear(isize, self.hsize, bias=enable_proj_bias)
         _k_isize = isize if k_isize is None else k_isize
-        self.key_adaptor = self.query_adaptor if bind_qk and isize == _k_isize else Linear(_k_isize, self.hsize, bias=enable_bias)
-        self.value_adaptor = Linear(_k_isize if v_isize is None else v_isize, self.hsize, bias=enable_bias)
+        self.key_adaptor = self.query_adaptor if bind_qk and isize == _k_isize else Linear(_k_isize, self.hsize, bias=enable_proj_bias)
+        self.value_adaptor = Linear(_k_isize if v_isize is None else v_isize, self.hsize, bias=enable_proj_bias)
 
         self.outer = Linear(self.hsize, osize, bias=enable_bias)
 
@@ -255,15 +255,15 @@ def get_ext(self, npos):
 # Accelerated MultiHeadAttn for self attention, use when Q == K == V
 class SelfAttn(nn.Module):
 
-    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_residual_bias_default, k_rel_pos=use_k_relative_position, sparsenorm=False, xseql=cache_len_default):
+    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, sparsenorm=False, xseql=cache_len_default):
 
         super(SelfAttn, self).__init__()
 
         self.attn_dim = hsize // num_head
         self.hsize = self.attn_dim * num_head
         self.num_head = num_head
 
-        self.adaptor = Linear(isize, self.hsize * 3, bias=enable_bias)
+        self.adaptor = Linear(isize, self.hsize * 3, bias=enable_proj_bias)
 
         self.outer = Linear(self.hsize, osize, bias=enable_bias)
 
@@ -340,17 +340,17 @@ def get_rel_pos(self, length):
 # Accelerated MultiHeadAttn for cross attention, use when K == V
 class CrossAttn(nn.Module):
 
-    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, enable_bias=enable_residual_bias_default, sparsenorm=False):
+    def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, sparsenorm=False):
 
         super(CrossAttn, self).__init__()
 
         self.attn_dim = hsize // num_head
         self.hsize = self.attn_dim * num_head
         self.num_head = num_head
 
-        self.query_adaptor = Linear(isize, self.hsize, bias=enable_bias)
+        self.query_adaptor = Linear(isize, self.hsize, bias=enable_proj_bias)
 
-        self.kv_adaptor = Linear(isize if k_isize is None else k_isize, self.hsize * 2, bias=enable_bias)
+        self.kv_adaptor = Linear(isize if k_isize is None else k_isize, self.hsize * 2, bias=enable_proj_bias)
 
         self.outer = Linear(self.hsize, osize, bias=enable_bias)
 
@@ -390,7 +390,7 @@ class ResidueCombiner(nn.Module):
 
     # isize: input size of Feed-forward NN
 
-    def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_residual_bias_default):
+    def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default):
 
         super(ResidueCombiner, self).__init__()
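SelfAttn computes Q, K and V with a single fused projection (Linear(isize, self.hsize * 3)), and MultiHeadAttn can tie the query and key projections when bind_qk is set, following the Reformer paper linked in the comment above. A minimal sketch of the fused-projection step alone (tensor sizes and the reshape order are illustrative; the repository's SelfAttn additionally handles masking, relative positions and decoding caches):

import torch
import torch.nn as nn

isize, hsize, num_head = 512, 512, 8
adim = hsize // num_head
adaptor = nn.Linear(isize, hsize * 3, bias=False)   # one matmul produces queries, keys and values together

x = torch.randn(2, 7, isize)                        # (bsize, seql, isize)
bsize, seql, _ = x.size()

q, k, v = adaptor(x).view(bsize, seql, 3, num_head, adim).unbind(2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)          # (bsize, num_head, seql, adim)
scores = torch.matmul(q, k.transpose(-1, -2)) / (adim ** 0.5)
out = torch.matmul(scores.softmax(-1), v).transpose(1, 2).contiguous().view(bsize, seql, hsize)

When bind_qk is set (and the query/key input sizes match), key_adaptor is simply the same Linear object as query_adaptor, so queries and keys share one weight matrix as in the Reformer.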
4 changes: 2 additions & 2 deletions modules/noise.py
@@ -88,9 +88,9 @@ def forward(self, inpute, mask=None):
 
 class PositionwiseFF(PositionwiseFFBase):
 
-    def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, use_GeLU=use_adv_act_default, power=None):
+    def __init__(self, isize, power=None, **kwargs):
 
-        super(PositionwiseFF, self).__init__(isize, hsize, dropout, norm_residual, use_GeLU)
+        super(PositionwiseFF, self).__init__(isize, **kwargs)
 
         self.noiser = None if power is None else Noiser(power)
4 changes: 2 additions & 2 deletions modules/rnncells.py
@@ -19,7 +19,7 @@ class LSTMCell4RNMT(nn.Module):
     # isize: input size of Feed-forward NN
     # dropout: dropout over hidden units, disabling it and applying dropout to outputs (_out) in most cases
 
-    def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_residual_bias_default):
+    def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default):
 
         super(LSTMCell4RNMT, self).__init__()
 
@@ -57,7 +57,7 @@ class GRUCell4RNMT(nn.Module):
 
     # isize: input size of Feed-forward NN
 
-    def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_residual_bias_default):
+    def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default):
 
         super(GRUCell4RNMT, self).__init__()
74 changes: 51 additions & 23 deletions parallel/base.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.cuda.comm as comm
+from utils.comm import secure_broadcast_coalesced
 
 from torch.jit import ScriptModule
 from torch._C import ScriptMethod
@@ -119,12 +120,24 @@ def zero_grad(self):
                 net.zero_grad()
         self.ngradev = 0
 
-    def zero_replicas_grad(self):
+    def collect_gradients_func(self, func):
+
+        if self.ngradev > 1:
+            grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(func(net).parameters())] for net in self.nets[:self.ngradev]], self.output_device)
+            for mp, grad in zip(filter_para_grad(func(self.module).parameters()), grads):
+                mp.grad = grad
+
+    def zero_replicas_grad(self, func=None):
 
         if self.nets is not None and self.ngradev > 1:
-            for net in self.nets[1:self.ngradev]:
-                for para in filter_para_grad(net.parameters()):
-                    para.grad = None
+            if func is None:
+                for net in self.nets[1:self.ngradev]:
+                    for para in filter_para_grad(net.parameters()):
+                        para.grad = None
+            else:
+                for net in self.nets[1:self.ngradev]:
+                    for para in filter_para_grad(func(net).parameters()):
+                        para.grad = None
 
     def reset_grad(self):
@@ -152,12 +165,13 @@ def forward(self, inputs, *targets, **kwargs):
         # input should be already scatterd
         # scattering the targets instead
         if not self.device_ids:
-            return self.module(inputs, *targets, **kwargs)
+            return self.module(inputs[0], *targets, **kwargs)
         targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
         targets = clean_list(targets)
         ngpu = len(targets)
-        if (len(self.device_ids) == 1) or (ngpu == 1):
-            return self.module(inputs[0], *targets[0], **kwargs[0])
+        if ngpu == 1:
+            _fwd_m = self.module if self.nets is None else self.nets[0]
+            return _fwd_m(inputs[0], *targets[0], **kwargs[0])
         devices = self.device_ids[:ngpu]
         replicas = self.replicate(self.module, devices) if self.nets is None else self.nets[:ngpu]
         outputs = criterion_parallel_apply(replicas, inputs, targets, devices, kwargs)
Expand All @@ -179,8 +193,19 @@ def clear_gradient(para):
param_copies = comm.broadcast_coalesced(params, devices)

buffers = list(network.buffers())
buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
buffer_copies = comm.broadcast_coalesced(buffers, devices)
buffers_rg = []
buffers_not_rg = []
for buf in buffers:
if buf.requires_grad and not detach:
buffers_rg.append(clear_gradient(buf) if no_gradient else buf)
else:
buffers_not_rg.append(buf)

buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

buffer_copies_rg = secure_broadcast_coalesced(buffers_rg, devices)
buffer_copies_not_rg = secure_broadcast_coalesced(buffers_not_rg, devices)

modules = list(network.modules())
module_copies = [[] for device in devices]
@@ -193,7 +218,9 @@ def clear_gradient(para):
         if isinstance(module, ScriptModule):
             # we have to initialize ScriptModule properly so that
             # it works with pybind11
-            replica = ScriptModule()
+            replica = module._replicate_for_data_parallel()
+            replica._former_parameters = OrderedDict()
+            '''replica = ScriptModule()
             attribute_names = set(entry[0] for entry in module._c._get_attributes())
@@ -203,7 +230,7 @@ def clear_gradient(para):
                 replica.__dict__[key] = module.__dict__[key]
             for name, the_type, value in module._c._get_attributes():
                 if not name in module._buffers.keys():
-                    replica._c._register_attribute(name, the_type, value)
+                    replica._c._register_attribute(name, the_type, value)'''
         else:
             replica = module.__new__(type(module))
             replica.__dict__ = module.__dict__.copy()
@@ -217,33 +244,34 @@ def clear_gradient(para):
         for key, child in module._modules.items():
             if child is None:
                 for j in range(num_replicas):
-                    replica = module_copies[j][i]
-                    replica._modules[key] = None
+                    module_copies[j][i]._modules[key] = None
             else:
                 module_idx = module_indices[child]
                 for j in range(num_replicas):
-                    replica = module_copies[j][i]
-                    replica._modules[key] = module_copies[j][module_idx]
+                    module_copies[j][i]._modules[key] = module_copies[j][module_idx]
         for key, param in module._parameters.items():
             if param is None:
                 for j in range(num_replicas):
-                    replica = module_copies[j][i]
-                    replica._parameters[key] = None
+                    module_copies[j][i]._parameters[key] = None
             else:
-                param_idx = param_indices[param]
+                param_idx, _p_require_grad = param_indices[param], param.requires_grad
                 for j in range(num_replicas):
-                    replica = module_copies[j][i]
-                    replica._parameters[key] = param_copies[j][param_idx].requires_grad_(param.requires_grad)
+                    module_copies[j][i]._parameters[key] = param_copies[j][param_idx].requires_grad_(_p_require_grad)
         for key, buf in module._buffers.items():
             if buf is None:
                 for j in range(num_replicas):
                     replica = module_copies[j][i]
                     replica._buffers[key] = None
             else:
-                buffer_idx = buffer_indices[buf]
+                _p_require_grad = buf.requires_grad
+                if _p_require_grad:
+                    buffer_copies = buffer_copies_rg
+                    buffer_idx = buffer_indices_rg[buf]
+                else:
+                    buffer_copies = buffer_copies_not_rg
+                    buffer_idx = buffer_indices_not_rg[buf]
                 for j in range(num_replicas):
-                    replica = module_copies[j][i]
-                    replica._buffers[key] = buffer_copies[j][buffer_idx]
+                    module_copies[j][i]._buffers[key] = buffer_copies[j][buffer_idx].requires_grad_(_p_require_grad)
 
     for j in range(num_replicas):
         for i, module in enumerate(modules):
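The new collect_gradients_func gathers the gradients accumulated on the replica copies back onto the master module with torch.cuda.comm.reduce_add_coalesced. A standalone sketch of that reduction step, assuming two CUDA devices (the helper and module names here are illustrative, not the repository's):

import torch
import torch.nn as nn
import torch.cuda.comm as comm

def collect_gradients(master, replicas, output_device=0):
    # sum the .grad of every replica parameter onto the master's parameters;
    # assumes backward() has already run on each replica
    grads = comm.reduce_add_coalesced(
        [[p.grad for p in rep.parameters() if p.grad is not None] for rep in replicas],
        output_device)
    for mp, grad in zip((p for p in master.parameters() if p.requires_grad), grads):
        mp.grad = grad

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    master = nn.Linear(4, 4).cuda(0)
    replica = nn.Linear(4, 4).cuda(1)
    replica.load_state_dict(master.state_dict())
    for dev, net in enumerate((master, replica)):
        net(torch.randn(2, 4, device=dev)).sum().backward()
    collect_gradients(master, (master, replica))

The related buffer change broadcasts buffers that require gradients separately from plain buffers, so the replicated copies keep their requires_grad flag.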
6 changes: 3 additions & 3 deletions requirements.opt.txt
@@ -1,5 +1,5 @@
-Cython>=0.29.14
+Cython>=0.29.20
 subword-nmt>=0.3.7
-sacremoses>=0.0.38
-Flask>=1.1.1
+sacremoses>=0.0.43
+Flask>=1.1.2
 PyNLPIR>=0.6.0
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,3 +1,3 @@
-tqdm>=4.41.0
-torch>=1.1.0
+tqdm>=4.46.1
+torch>=1.5.0
 h5py>=2.10.0
7 changes: 2 additions & 5 deletions tools/average_model.py
@@ -8,20 +8,17 @@
 
 import torch
 
-from utils.base import mask_tensor_type
+from utils.base import secure_type_map
 from utils.h5serial import h5save, h5load
 
 from cnfg.ihyp import *
 
 def handle(srcfl, rsf):
 
-    type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64}
-    type_map[mask_tensor_type] = torch.int64
-
     rsm = h5load(srcfl[0])
 
     src_type = [para.dtype for para in rsm]
-    map_type = [type_map[para.dtype] if para.dtype in type_map else None for para in rsm]
+    map_type = [secure_type_map[para.dtype] if para.dtype in secure_type_map else None for para in rsm]
     sec_rsm = [para if typ is None else para.to(typ) for para, typ in zip(rsm, map_type)]
 
     nmodel = 1
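The shared utils.base.secure_type_map replaces the inline dtype table: low-precision parameters are upcast before accumulation and cast back to their original dtype afterwards. An illustrative sketch of that averaging pattern on plain tensors (the dict below is a stand-in mirroring the removed inline table, not the actual secure_type_map):

import torch

secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64,
                   torch.uint8: torch.int64, torch.int8: torch.int64,
                   torch.int16: torch.int64, torch.int32: torch.int64}

def average_params(models):
    # models: list of parameter lists with matching shapes/dtypes; returns their element-wise average
    src_type = [p.dtype for p in models[0]]
    acc = [p.to(secure_type_map.get(p.dtype, p.dtype)) for p in models[0]]
    for m in models[1:]:
        acc = [a + p.to(a.dtype) for a, p in zip(acc, m)]
    return [(a / len(models)).to(t) for a, t in zip(acc, src_type)]

avg = average_params([[torch.randn(3, 3, dtype=torch.float16)] for _ in range(4)])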
2 changes: 1 addition & 1 deletion transformer/AvgDecoder.py
@@ -351,7 +351,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
         # thus the fore path of the top-k candidate is pointed out
         # _inds: indexes for the top-k candidate (bsize, beam_size)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         # select the corresponding translation history for the top-k candidate and update translation records
         # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)
11 changes: 6 additions & 5 deletions transformer/Decoder.py
@@ -27,10 +27,10 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a
         _ahsize = isize if ahsize is None else ahsize
         _fhsize = _ahsize * 4 if fhsize is None else fhsize
 
-        self.self_attn = SelfAttn(isize, _ahsize, isize, num_head, dropout=attn_drop, k_rel_pos=k_rel_pos)
-        self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop)
+        self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos)
+        self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop)
 
-        self.ff = PositionwiseFF(isize, _fhsize, dropout, norm_residual)
+        self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual)
 
         self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)
         self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)
@@ -403,7 +403,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
         # thus the fore path of the top-k candidate is pointed out
         # _inds: indexes for the top-k candidate (bsize, beam_size)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         # select the corresponding translation history for the top-k candidate and update translation records
         # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)
@@ -716,7 +716,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512,
         # thus the fore path of the top-k candidate is pointed out
         # _inds: indexes for the top-k candidate (bsize, beam_size)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        # using "_inds / beam_size" in case old pytorch does not support "//" operation
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         # select the corresponding translation history for the top-k candidate and update translation records
         # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)
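The / to // change keeps the parent-beam lookup an integer operation: under PyTorch 1.5 dividing an integer tensor with / is deprecated (and returns a float in later releases), which would break its later use as an index. A small worked illustration of what this line computes, with made-up values (bsize sentences, beam_size hypotheses each, _inds taken over the beam_size * beam_size expansions per sentence):

import torch

bsize, beam_size = 2, 3
real_bsize = bsize * beam_size

_inds = torch.tensor([[5, 1, 7],
                      [2, 8, 3]])       # (bsize, beam_size), values in [0, beam_size * beam_size)

parents = (_inds // beam_size
           + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype).unsqueeze(1).expand_as(_inds)).view(real_bsize)
print(parents)                           # tensor([1, 0, 2, 3, 5, 4])

trans = torch.arange(real_bsize).unsqueeze(1)       # stand-in translation history, one row per live hypothesis
print(trans.index_select(0, parents).squeeze(1))    # rows carried forward for the next decoding step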
2 changes: 1 addition & 1 deletion transformer/Doc/Para/Base/Decoder.py
@@ -291,7 +291,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam
 
         wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), 0) if fill_pad else wds), 1)
 
4 changes: 2 additions & 2 deletions transformer/Encoder.py
@@ -34,9 +34,9 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a
         _ahsize = isize if ahsize is None else ahsize
         _fhsize = _ahsize * 4 if fhsize is None else fhsize
 
-        self.attn = SelfAttn(isize, _ahsize, isize, num_head, dropout=attn_drop, k_rel_pos=k_rel_pos)
+        self.attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos)
 
-        self.ff = PositionwiseFF(isize, _fhsize, dropout, norm_residual)
+        self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual)
 
         self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)
 
2 changes: 1 addition & 1 deletion transformer/EnsembleAvgDecoder.py
@@ -267,7 +267,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
         # thus the fore path of the top-k candidate is pointed out
         # _inds: indexes for the top-k candidate (bsize, beam_size)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         # select the corresponding translation history for the top-k candidate and update translation records
         # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)
2 changes: 1 addition & 1 deletion transformer/EnsembleDecoder.py
@@ -286,7 +286,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
         # thus the fore path of the top-k candidate is pointed out
         # _inds: indexes for the top-k candidate (bsize, beam_size)
 
-        _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+        _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
 
         # select the corresponding translation history for the top-k candidate and update translation records
         # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1)