From 62e8185b66506cbf91f3773c123536303b343959 Mon Sep 17 00:00:00 2001
From: GuoxiaWang
Date: Thu, 10 Nov 2022 09:49:05 +0800
Subject: [PATCH] support branch parallelism for evoformer

---
 train_monomer_demo_bp.sh     |  26 +++++
 unifold/config.py            |   4 +-
 unifold/data/data_ops.py     |  24 ++--
 unifold/modules/evoformer.py | 220 +++++++++++++++++++++++++----------
 4 files changed, 197 insertions(+), 77 deletions(-)
 create mode 100755 train_monomer_demo_bp.sh

diff --git a/train_monomer_demo_bp.sh b/train_monomer_demo_bp.sh
new file mode 100755
index 0000000..57f5629
--- /dev/null
+++ b/train_monomer_demo_bp.sh
@@ -0,0 +1,26 @@
+ps -ef | grep "torch" | grep -v grep | awk '{print $2}' | xargs kill -9
+ps -ef | grep "unicore-train" | grep -v grep | awk '{print $2}' | xargs kill -9
+export MASTER_IP=10.67.228.15
+[ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1
+[ -z "${MASTER_PORT}" ] && MASTER_PORT=12345
+[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
+[ -z "${PADDLE_TRAINERS_NUM}" ] && PADDLE_TRAINERS_NUM=1
+[ -z "${PADDLE_TRAINER_ID}" ] && PADDLE_TRAINER_ID=0
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OMP_NUM_THREADS=1
+mkdir -p $1
+#n_gpu=4
+tmp_dir=`mktemp -d`
+#model_name=model_init_af2
+model_name=model_1_af2
+python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT --nnodes=$PADDLE_TRAINERS_NUM --node_rank=$PADDLE_TRAINER_ID --master_addr=$MASTER_IP $(which unicore-train) ./example_data/ --user-dir unifold \
+       --num-workers 8 --ddp-backend=no_c10d \
+       --task af2 --loss af2 --arch af2 --model-name $model_name \
+       --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
+       --lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 1000 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
+       --update-freq 1 --seed 42 --tensorboard-logdir $1/tsb/ \
+       --max-update 1000 --max-epoch 1 --log-interval 10 --log-format simple \
+       --save-interval-updates 500 --validate-interval-updates 500 --keep-interval-updates 40 --no-epoch-checkpoints \
+       --save-dir $1 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --bf16 --ema-decay 0.999 --data-buffer-size 32 --bf16-sr --bp-degree 2
+rm -rf $tmp_dir
+# --save-dir $1 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --data-buffer-size 32 --bp-degree 1

diff --git a/unifold/config.py b/unifold/config.py
index c396973..c58bf65 100644
--- a/unifold/config.py
+++ b/unifold/config.py
@@ -316,7 +316,7 @@ def base_config():
                     "pair_dropout": 0.25,
                     "inf": 1e9,
                     "eps": 1e-10,
-                    "outer_product_mean_first": False,
+                    "outer_product_mean_pos": 'end',
                 },
                 "enabled": True,
             },
@@ -336,7 +336,7 @@ def base_config():
                 "pair_dropout": 0.25,
                 "inf": 1e9,
                 "eps": 1e-10,
-                "outer_product_mean_first": False,
+                "outer_product_mean_pos": 'end',
             },
             "structure_module": {
                 "d_single": d_single,

diff --git a/unifold/data/data_ops.py b/unifold/data/data_ops.py
index 78dcd57..183506a 100644
--- a/unifold/data/data_ops.py
+++ b/unifold/data/data_ops.py
@@ -619,18 +619,18 @@ def get_pad_size(cur_size, multiplier=4):
         return max(
             multiplier, ((cur_size + multiplier - 1) // multiplier) * multiplier
         )
-    if num_res is not None:
-        input_num_res = (
-            protein["aatype"].shape[0]
-            if "aatype" in protein
-            else protein["msa_mask"].shape[1]
-        )
-        if input_num_res != num_res:
-            num_res = get_pad_size(input_num_res, 4)
-    if "extra_msa_mask" in protein:
-        input_extra_msa_size = protein["extra_msa_mask"].shape[0]
-        if input_extra_msa_size != extra_msa_size:
-            extra_msa_size = get_pad_size(input_extra_msa_size, 8)
+    # if num_res is not None:
+    #     input_num_res = (
+    #         protein["aatype"].shape[0]
+    #         if "aatype" in protein
+    #         else protein["msa_mask"].shape[1]
+    #     )
+    #     if input_num_res != num_res:
+    #         num_res = get_pad_size(input_num_res, 4)
+    # if "extra_msa_mask" in protein:
+    #     input_extra_msa_size = protein["extra_msa_mask"].shape[0]
+    #     if input_extra_msa_size != extra_msa_size:
+    #         extra_msa_size = get_pad_size(input_extra_msa_size, 8)
     pad_size_map = {
         N_RES: num_res,
         N_MSA: msa_cluster_size,

diff --git a/unifold/modules/evoformer.py b/unifold/modules/evoformer.py
index 671054d..b74f015 100644
--- a/unifold/modules/evoformer.py
+++ b/unifold/modules/evoformer.py
@@ -26,6 +26,11 @@
 from unicore.utils import checkpoint_sequential
 
+import torch.distributed as dist
+from unicore.distributed.comm_group import scg
+from unicore.distributed import bp
+
+
 class EvoformerIteration(nn.Module):
     def __init__(
         self,
@@ -40,7 +45,7 @@ def __init__(
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        outer_product_mean_first: bool,
+        outer_product_mean_pos: str,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
@@ -48,7 +53,7 @@ def __init__(
         super(EvoformerIteration, self).__init__()
         self._is_extra_msa_stack = _is_extra_msa_stack
 
-        self.outer_product_mean_first = outer_product_mean_first
+        self.outer_product_mean_pos = outer_product_mean_pos
 
         self.msa_att_row = MSARowAttentionWithPairBias(
             d_msa=d_msa,
@@ -120,80 +125,169 @@ def forward(
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         msa_row_attn_mask: torch.Tensor,
-        msa_col_attn_mask: Optional[torch.Tensor],
+        msa_col_attn_mask: torch.Tensor,
         tri_start_attn_mask: torch.Tensor,
         tri_end_attn_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.outer_product_mean_first:
-            z = residual(
-                z, self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
-            )
+        if scg.get_bp_world_size() > 1:
+
+            # Note(GuoxiaWang): adding zeros triggers stop_gradient=False within the recompute context.
+            z = z + torch.zeros_like(z)
+            m = m + torch.zeros_like(m)
+
+            # Note(GuoxiaWang): all-reduce the pair activations' gradients from the msa branch and the pair branch
+            if z.requires_grad:
+                z.register_hook(bp.all_reduce)
+
+            if scg.get_bp_rank_in_group() == 0:
+                m = bias_dropout_residual(
+                    self.msa_att_row,
+                    m,
+                    self.msa_att_row(
+                        m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size
+                    ),
+                    self.row_dropout_share_dim,
+                    self.msa_dropout,
+                    self.training,
+                )
+                if self._is_extra_msa_stack:
+                    m = residual(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+                else:
+                    m = bias_dropout_residual(
+                        self.msa_att_col,
+                        m,
+                        self.msa_att_col(m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
+                        self.col_dropout_share_dim,
+                        self.msa_dropout,
+                        self.training,
+                    )
+                m = residual(m, self.msa_transition(m, chunk_size=chunk_size))
+                if self.outer_product_mean_pos in ('middle', 'end'):
+                    outer = self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
+
+            if scg.get_bp_rank_in_group() == 1:
+                z = bias_gated_dropout_residual(
+                    self.tri_mul_out,
+                    z,
+                    self.tri_mul_out(z, mask=pair_mask),
+                    self.row_dropout_share_dim,
+                    self.pair_dropout,
+                    self.training,
+                )
+                z = bias_gated_dropout_residual(
+                    self.tri_mul_in,
+                    z,
+                    self.tri_mul_in(z, mask=pair_mask),
+                    self.row_dropout_share_dim,
+                    self.pair_dropout,
+                    self.training,
+                )
+                z = bias_dropout_residual(
+                    self.tri_att_start,
+                    z,
+                    self.tri_att_start(z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
+                    self.row_dropout_share_dim,
+                    self.pair_dropout,
+                    self.training,
+                )
+                z = bias_dropout_residual(
+                    self.tri_att_end,
+                    z,
+                    self.tri_att_end(z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
+                    self.col_dropout_share_dim,
+                    self.pair_dropout,
+                    self.training,
+                )
+                z = residual(z, self.pair_transition(z, chunk_size=chunk_size))
+                outer = torch.zeros_like(z)
+                outer.requires_grad = z.requires_grad
+
+            m, z = bp.sync_evoformer_results(outer, m, z)
+            z = z.clone()
+            m = m.clone()
 
-        m = bias_dropout_residual(
-            self.msa_att_row,
-            m,
-            self.msa_att_row(
-                m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size
-            ),
-            self.row_dropout_share_dim,
-            self.msa_dropout,
-            self.training,
-        )
-        if self._is_extra_msa_stack:
-            m = residual(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
         else:
-            m = bias_dropout_residual(
-                self.msa_att_col,
-                m,
-                self.msa_att_col(m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
-                self.col_dropout_share_dim,
-                self.msa_dropout,
-                self.training,
-            )
-        m = residual(m, self.msa_transition(m, chunk_size=chunk_size))
-        if not self.outer_product_mean_first:
-            z = residual(
-                z, self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
-            )
+
+            if self.outer_product_mean_pos == 'first':
+                z = residual(
+                    z, self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
+                )
+
+            m = bias_dropout_residual(
+                self.msa_att_row,
+                m,
+                self.msa_att_row(
+                    m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size
+                ),
+                self.row_dropout_share_dim,
+                self.msa_dropout,
+                self.training,
+            )
+            if self._is_extra_msa_stack:
+                m = residual(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+            else:
+                m = bias_dropout_residual(
+                    self.msa_att_col,
+                    m,
+                    self.msa_att_col(m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
+                    self.col_dropout_share_dim,
+                    self.msa_dropout,
+                    self.training,
+                )
+            m = residual(m, self.msa_transition(m, chunk_size=chunk_size))
+            if self.outer_product_mean_pos in ('middle', 'end'):
+                outer = self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
+
+            if self.outer_product_mean_pos == 'middle':
+                z = residual(z, outer)
+
-        z = bias_gated_dropout_residual(
-            self.tri_mul_out,
-            z,
-            self.tri_mul_out(z, mask=pair_mask),
-            self.row_dropout_share_dim,
-            self.pair_dropout,
-            self.training,
-        )
-        z = bias_gated_dropout_residual(
-            self.tri_mul_in,
-            z,
-            self.tri_mul_in(z, mask=pair_mask),
-            self.row_dropout_share_dim,
-            self.pair_dropout,
-            self.training,
-        )
-        z = bias_dropout_residual(
-            self.tri_att_start,
-            z,
-            self.tri_att_start(z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
-            self.row_dropout_share_dim,
-            self.pair_dropout,
-            self.training,
-        )
-        z = bias_dropout_residual(
-            self.tri_att_end,
-            z,
-            self.tri_att_end(z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
-            self.col_dropout_share_dim,
-            self.pair_dropout,
-            self.training,
-        )
-        z = residual(z, self.pair_transition(z, chunk_size=chunk_size))
+            z = bias_gated_dropout_residual(
+                self.tri_mul_out,
+                z,
+                self.tri_mul_out(z, mask=pair_mask),
+                self.row_dropout_share_dim,
+                self.pair_dropout,
+                self.training,
+            )
+            z = bias_gated_dropout_residual(
+                self.tri_mul_in,
+                z,
+                self.tri_mul_in(z, mask=pair_mask),
+                self.row_dropout_share_dim,
+                self.pair_dropout,
+                self.training,
+            )
+            z = bias_dropout_residual(
+                self.tri_att_start,
+                z,
+                self.tri_att_start(z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
+                self.row_dropout_share_dim,
+                self.pair_dropout,
+                self.training,
+            )
+            z = bias_dropout_residual(
+                self.tri_att_end,
+                z,
+                self.tri_att_end(z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
+                self.col_dropout_share_dim,
+                self.pair_dropout,
+                self.training,
+            )
+            z = residual(z, self.pair_transition(z, chunk_size=chunk_size))
+
+            if self.outer_product_mean_pos == 'end':
+                z = residual(z, outer)
+
         return m, z
 
@@ -213,7 +307,7 @@ def __init__(
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        outer_product_mean_first: bool,
+        outer_product_mean_pos: str,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
@@ -239,7 +333,7 @@ def __init__(
             transition_n=transition_n,
             msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
-            outer_product_mean_first=outer_product_mean_first,
+            outer_product_mean_pos=outer_product_mean_pos,
             inf=inf,
             eps=eps,
             _is_extra_msa_stack=_is_extra_msa_stack,
@@ -306,7 +400,7 @@ def __init__(
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        outer_product_mean_first: bool,
+        outer_product_mean_pos: str,
         inf: float,
         eps: float,
         **kwargs,
@@ -325,7 +419,7 @@ def __init__(
             transition_n=transition_n,
            msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
-            outer_product_mean_first=outer_product_mean_first,
+            outer_product_mean_pos=outer_product_mean_pos,
             inf=inf,
             eps=eps,
             _is_extra_msa_stack=True,
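
For reviewers, a minimal, self-contained sketch of the gradient trick behind `z.register_hook(bp.all_reduce)` above: the MSA branch (rank 0) and the pair branch (rank 1) both read the same pair activations `z`, so each rank's backward pass produces only a partial dz, and the hook must sum the partials across the branch-parallel group. Everything below is illustrative, not part of the patch: the file name, the toy losses, the `gloo` backend, and plain `torch.distributed` standing in for unicore's `scg`/`bp` helpers, whose internals this patch does not show.

    # bp_hook_demo.py (hypothetical file name)
    # Run with: torchrun --nproc_per_node=2 bp_hook_demo.py
    import torch
    import torch.distributed as dist

    def allreduce_grad(grad: torch.Tensor) -> torch.Tensor:
        # Sum the partial dz from both branches; the returned tensor
        # replaces the gradient that flows into z.
        out = grad.clone()
        dist.all_reduce(out)
        return out

    def main():
        dist.init_process_group(backend="gloo")
        z = torch.ones(4, 4, requires_grad=True)  # stand-in for pair activations
        z.register_hook(allreduce_grad)
        if dist.get_rank() == 0:
            loss = (2.0 * z).sum()  # stand-in for the MSA branch's use of z
        else:
            loss = (3.0 * z).sum()  # stand-in for the pair branch's use of z
        loss.backward()
        # Every rank now holds the full gradient dL/dz = 2 + 3 = 5.
        assert torch.allclose(z.grad, torch.full_like(z, 5.0))
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

The `z + torch.zeros_like(z)` lines in the patch serve a related purpose: per the author's note, the no-op add keeps `m` and `z` attached to the autograd graph inside the recompute context, so the hook has a graph to fire in when the checkpointed segment is replayed.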