diff --git a/.github/workflows/release_nightly_on_schedule.yml b/.github/workflows/release_nightly_on_schedule.yml index 4125f333f301..072a943aef19 100644 --- a/.github/workflows/release_nightly_on_schedule.yml +++ b/.github/workflows/release_nightly_on_schedule.yml @@ -6,11 +6,13 @@ on: - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time jobs: - build-n-publish: + publish: if: github.repository == 'hpcaitech/ColossalAI' name: Build and publish Python 🐍 distributions 📦 to PyPI runs-on: ubuntu-latest timeout-minutes: 20 + outputs: + status: ${{ steps.publish.outcome }} steps: - uses: actions/checkout@v2 @@ -18,7 +20,9 @@ jobs: with: python-version: '3.8.14' - - run: NIGHTLY=1 python setup.py sdist build + - run: | + python .github/workflows/scripts/update_setup_for_nightly.py + python setup.py sdist build # publish to PyPI if executed on the main branch - name: Publish package to PyPI @@ -31,7 +35,7 @@ jobs: notify: name: Notify Lark via webhook - needs: build-n-publish + needs: publish runs-on: ubuntu-latest if: ${{ always() }} && github.repository == 'hpcaitech/ColossalAI' steps: @@ -62,4 +66,4 @@ jobs: REPO: ${{ github.repository }} RUN_ID: ${{ github.run_id }} WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} - STATUS: ${{ steps.publish.outcome }} + STATUS: ${{ needs.publish.outputs.status }} diff --git a/.github/workflows/release_test_pypi_before_merge.yml b/.github/workflows/release_test_pypi_before_merge.yml index 284ab4d1afb0..7af641fc3056 100644 --- a/.github/workflows/release_test_pypi_before_merge.yml +++ b/.github/workflows/release_test_pypi_before_merge.yml @@ -49,6 +49,6 @@ jobs: # we need to install the requirements.txt first # as test-pypi may not contain the distributions for libs listed in the txt file pip install -r requirements/requirements.txt - pip install --index-url https://test.pypi.org/simple/ colossalai==$VERSION + pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.python.org/pypi colossalai==$VERSION env: VERSION: ${{ steps.prep-version.outputs.version }} diff --git a/.github/workflows/scripts/update_setup_for_nightly.py b/.github/workflows/scripts/update_setup_for_nightly.py new file mode 100644 index 000000000000..d8a3087ef54e --- /dev/null +++ b/.github/workflows/scripts/update_setup_for_nightly.py @@ -0,0 +1,34 @@ +from datetime import datetime + + +def open_setup_file(): + with open("setup.py", "r") as f: + file_lines = f.readlines() + return file_lines + + +def replace_nightly_package_info(file_lines): + version = datetime.today().strftime("%Y.%m.%d") + package_name = "colossalai-nightly" + + for idx, line in enumerate(file_lines): + if "version = get_version()" in line: + file_lines[idx] = f'version = "{version}"\n' + if 'package_name = "colossalai"' in line: + file_lines[idx] = f'package_name = "{package_name}"\n' + return file_lines + + +def write_setup_file(file_lines): + with open("setup.py", "w") as f: + f.writelines(file_lines) + + +def main(): + file_lines = open_setup_file() + file_lines = replace_nightly_package_info(file_lines) + write_setup_file(file_lines) + + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md index 13757eece7db..442e6bbcd8cf 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Documentation | Examples | Forum | - Blog + Blog [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) 
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml) @@ -398,10 +398,10 @@ pip install colossalai **Note: only Linux is supported for now.** -However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`. +However, if you want to build the PyTorch extensions during installation, you can set `BUILD_EXT=1`. ```bash -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` **Otherwise, CUDA kernels will be built during runtime when you actually need them.** @@ -429,7 +429,7 @@ By default, we do not compile CUDA/C++ kernels. ColossalAI will build them durin If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): ```shell -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory. @@ -445,7 +445,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ```
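The `BUILD_EXT=1` switch documented above is read from the environment at install time to decide whether the C++/CUDA kernels are compiled eagerly or deferred until they are first needed. A rough sketch of that decision, assuming a generic `setup.py` layout rather than ColossalAI's actual one; the extension name below is a hypothetical placeholder.

```python
# Illustrative only -- not ColossalAI's real setup.py. Shows how an install-time
# BUILD_EXT=1 environment switch typically toggles eager extension compilation.
import os

build_ext = os.environ.get("BUILD_EXT", "0") == "1"

ext_modules = []
if build_ext:
    # Hypothetical placeholder entry; a real setup would append
    # torch.utils.cpp_extension.CUDAExtension objects here instead of a name.
    ext_modules.append("colossalai._C.fused_optim")

print(f"Eager kernel build requested: {build_ext}; extensions: {ext_modules}")
```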

(back to top)

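Before the application-level changes below, a quick sanity check for the `update_setup_for_nightly.py` script added earlier in this patch: it restates the script's `replace_nightly_package_info` logic and runs it on two hypothetical stand-in lines from `setup.py`. This is a local verification sketch, not part of the patch itself.

```python
# Sanity check for the nightly rewrite logic in
# .github/workflows/scripts/update_setup_for_nightly.py (same logic as the script).
from datetime import datetime


def replace_nightly_package_info(file_lines):
    # Pin the version to today's date and switch to the nightly package name.
    version = datetime.today().strftime("%Y.%m.%d")
    package_name = "colossalai-nightly"
    for idx, line in enumerate(file_lines):
        if "version = get_version()" in line:
            file_lines[idx] = f'version = "{version}"\n'
        if 'package_name = "colossalai"' in line:
            file_lines[idx] = f'package_name = "{package_name}"\n'
    return file_lines


# Hypothetical stand-ins for the two lines the script expects to find in setup.py.
sample = ['package_name = "colossalai"\n', "version = get_version()\n"]
rewritten = replace_nightly_package_info(sample)
assert rewritten[0] == 'package_name = "colossalai-nightly"\n'
assert rewritten[1].startswith('version = "')
print(rewritten)
```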
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index c0e257f54a07..e67e16231cc2 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -49,12 +49,13 @@ def _preprocess( max_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Preprocess the data by tokenizing.""" - sequences = [s + t for s, t in zip(sources, targets)] + sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)] sequences_token = tokenizer( - sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False ) + sources_token = tokenizer( - sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False ) assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently" @@ -65,7 +66,8 @@ def _preprocess( if tokenizer.padding_side == "right": # |prompt|completion|eos|pad| labels[i][:source_len] = IGNORE_INDEX - labels[i][-pad_len:] = IGNORE_INDEX + if pad_len>0: + labels[i][-pad_len:] = IGNORE_INDEX elif tokenizer.padding_side == "left": # |pad|prompt|completion|eos| labels[i][: pad_len + source_len] = IGNORE_INDEX diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh index 0fb4da3d3ce8..b7d176847d9c 100755 --- a/applications/Chat/examples/train_sft.sh +++ b/applications/Chat/examples/train_sft.sh @@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ --accumulation_steps 8 \ --lr 2e-5 \ --max_datasets_size 512 \ - --max_epochs 1 + --max_epochs 1 \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py index a2cfb2ef6264..327651f4e645 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -1,20 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import numpy as np import os -import random from dataclasses import dataclass -from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable +from typing import Dict, Iterator, List, Optional, Sequence, Union import torch -from datasets import dataset_dict, load_from_disk +import torch.nn.functional as F from datasets import Dataset as HFDataset -from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler +from datasets import dataset_dict, load_from_disk +from torch.utils.data import ConcatDataset, Dataset, DistributedSampler from transformers.tokenization_utils import PreTrainedTokenizer -import torch.nn.functional as F DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] @@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object): tokenizer: PreTrainedTokenizer max_length: int = 4096 ignore_index: int = -100 + padding: str = "max_length" def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ @@ -106,10 +103,11 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch 
batch_first=True, padding_value=self.ignore_index, ) # (bsz, max_len) - # pad to max - to_pad = self.max_length - input_ids.size(1) - input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) - labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + if self.padding == "max_length": + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) elif self.tokenizer.padding_side == "left": reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] reversed_input_ids = torch.nn.utils.rnn.pad_sequence( @@ -171,49 +169,3 @@ def __len__(self) -> int: def set_start_index(self, start_index: int) -> None: self.start_index = start_index - - -def setup_distributed_dataloader( - dataset: DatasetType, - batch_size: int = 1, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, - process_group: Optional[ProcessGroup] = None, - **kwargs, -) -> DataLoader: - """ - Setup dataloader for distributed training. - """ - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset=dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle, - seed=seed, - drop_last=drop_last, - ) - - # Deterministic dataloader - def seed_worker(worker_id: int) -> None: - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - worker_init_fn=seed_worker, - **_kwargs, - ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 1926ec78aba8..6c048c3b18cf 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import math from types import MethodType from typing import Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func -from flash_attn.ops.rms_norm import rms_norm +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -19,194 +19,334 @@ repeat_kv, ) +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger logger = get_dist_logger() +if get_accelerator().name == "cuda": + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.ops.rms_norm import rms_norm -def _prepare_decoder_attention_mask( - self: LlamaModel, - attention_mask: torch.BoolTensor, - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, -) -> Optional[torch.Tensor]: - """ - Decoder 
attetion mask - """ - if past_key_values_length > 0 and attention_mask is not None: - attention_mask = torch.cat( - tensors=( - torch.full( - size=(input_shape[0], past_key_values_length), - fill_value=True, - dtype=attention_mask.dtype, - device=attention_mask.device, + def _prepare_decoder_attention_mask( + self: LlamaModel, + attention_mask: torch.BoolTensor, + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ) -> Optional[torch.Tensor]: + """ + Decoder attetion mask + """ + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + tensors=( + torch.full( + size=(input_shape[0], past_key_values_length), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, ), - attention_mask, - ), - dim=-1, - ) # (bsz, past_key_values_length + q_len) - if attention_mask is not None and torch.all(attention_mask): - return None # Faster - return attention_mask - - -def attention_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. - """ - if output_attentions: - logger.warning( - "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " - "return `None` instead." - ) - - bsz, q_len, _ = hidden_states.size() + dim=-1, + ) # (bsz, past_key_values_length + q_len) + if attention_mask is not None and torch.all(attention_mask): + return None # Faster + return attention_mask - if self.config.pretraining_tp > 1: - q_slicing, kv_slicing = ( - dim // self.config.pretraining_tp - for dim in ( - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, + def attention_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. + """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." 
) - ) # `Tuple[int, int]` - q_slices, k_slices, v_slices = ( - proj.weight.split(slicing, dim=0) - for proj, slicing in ( - (self.q_proj, q_slicing), - (self.k_proj, kv_slicing), - (self.v_proj, kv_slicing), + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) ) - ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) q, k, v = ( - torch.cat( - [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], - dim=-1, + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), ) - for slices in (q_slices, k_slices, v_slices) - ) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - else: - q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - - # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) - q, k, v = ( - states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) - for states, num_heads in ( - (q, self.num_heads), - (k, self.num_key_value_heads), - (v, self.num_key_value_heads), ) - ) - kv_len = k.shape[-2] # initially, `kv_len` == `q_len` - past_kv_len = 0 - if past_key_value is not None: - # if `past_key_value` is not None, `kv_len` > `q_len`. 
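Both the original and the re-indented `attention_forward` below rely on `repeat_kv` from `transformers.models.llama.modeling_llama` to expand grouped key/value heads before the flash-attention call. As a stand-alone aside with toy shapes (plain PyTorch, not the patched code), the expansion is equivalent to a `repeat_interleave` along the head axis:

```python
# Toy illustration of the repeat_kv expansion used in attention_forward:
# each grouped K/V head is replicated n_rep times so every query head has a
# matching key/value head. All sizes below are illustrative.
import torch

bsz, num_key_value_heads, seq_len, head_dim = 2, 4, 16, 64
num_heads = 8
n_rep = num_heads // num_key_value_heads  # self.num_key_value_groups in the patch

k = torch.randn(bsz, num_key_value_heads, seq_len, head_dim)
k_expanded = k.repeat_interleave(n_rep, dim=1)  # same result as repeat_kv(k, n_rep)
assert k_expanded.shape == (bsz, num_heads, seq_len, head_dim)
```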
- past_kv_len = past_key_value[0].shape[-2] - kv_len += past_kv_len - - # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) - cos, sin = self.rotary_emb(v, seq_len=kv_len) - # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) - q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) - if past_key_value is not None: - # reuse k, v, self_attention - k = torch.cat([past_key_value[0], k], dim=2) - v = torch.cat([past_key_value[1], v], dim=2) - - past_key_value = (k, v) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - - key_padding_mask = attention_mask - # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) - q, k, v = (states.transpose(1, 2) for states in (q, k, v)) - - if past_kv_len > 0: - q = torch.cat( - tensors=( - torch.full( - size=(bsz, past_kv_len, self.num_heads, self.head_dim), - fill_value=0.0, - dtype=q.dtype, - device=q.device, + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. + past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len + + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) + + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, ), - q, - ), - dim=1, - ) # (bsz, past_kv_len + q_len, num_heads, head_dim) - - if key_padding_mask is None: - # (bsz, past_kv_len + q_len, num_heads, head_dim) - output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) - output = rearrange(output, pattern="... h d -> ... 
(h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim) - else: - q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) - kv, _, cu_kv_lens, max_kv_len = unpad_input( - hidden_states=torch.stack(tensors=(k, v), dim=2), - attention_mask=key_padding_mask, - ) - output_unpad = flash_attn_varlen_kvpacked_func( - q=q, - kv=kv, - cu_seqlens_q=cu_q_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_q_len, - max_seqlen_k=max_kv_len, - dropout_p=0.0, - softmax_scale=None, - causal=True, - ) - output = pad_input( - hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), - indices=indices, - batch=bsz, - seqlen=past_kv_len + q_len, - ) # (bsz, past_kv_len + q_len, num_heads * head_dim) - - if past_kv_len > 0: - # Strip off the zero query outputs. - output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) - output = self.o_proj(output) # (bsz, q_len, hidden_size) - return output, None, past_key_value - - -def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Formard function for RMS Norm - """ - return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) - - -def replace_with_flash_attention(model: LlamaForCausalLM) -> None: - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - module.forward = MethodType(attention_forward, module) - if isinstance(module, LlamaModel): - module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) - if isinstance(module, LlamaRMSNorm): - module.forward = MethodType(rms_norm_forward, module) + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange( + output, pattern="... h d -> ... (h d)" + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] 
# (bsz, q_len, num_heads * head_dim) + output = self.o_proj(output) # (bsz, q_len, hidden_size) + return output, None, past_key_value + + def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Formard function for RMS Norm + """ + return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(attention_forward, module) + if isinstance(module, LlamaModel): + module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) + if isinstance(module, LlamaRMSNorm): + module.forward = MethodType(rms_norm_forward, module) + +elif get_accelerator().name == "npu": + import torch_npu + + class NPULlamaAttention(LlamaAttention): + use_flash: bool = True + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.setup() + + def setup(self): + self._softmax_scale = 1 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if not self.use_flash: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / 
math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + else: + attn_output, *_ = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + self.num_heads, + "BNSD", + atten_mask=attention_mask.bool(), + scale=self._softmax_scale, + padding_mask=None, + pre_tockens=65535, + next_tockens=0, + keep_prob=1.0, + inner_precise=0, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum( + [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class NPURMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.__class__ = NPULlamaAttention + module.setup() + if isinstance(module, LlamaRMSNorm): + module.__class__ = NPURMSNorm diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 9f6c9c1cc6f3..21d769f3c49f 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -17,7 +17,7 @@ def unwrap(model): if hasattr(model, "module"): - return unwrap_model(model.module) + return model.unwrap() else: return model diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py index 7fe2d92abd05..63ce91e50432 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA-2/inference_example.py @@ -1,22 +1,21 @@ import argparse -import os import torch +from colossal_llama2.dataset.conversation import default_conversation +from transformers import AutoModelForCausalLM, AutoTokenizer + from colossalai.logging import get_dist_logger -from transformers import AutoTokenizer, AutoModelForCausalLM logger = get_dist_logger() def load_model(model_path, device="cuda", **kwargs): - logger.info( - "Please check whether the tokenizer and model weights are properly stored in the same folder." 
- ) + logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.") model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) model.to(device) try: - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left') except OSError: raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") @@ -27,31 +26,51 @@ def load_model(model_path, device="cuda", **kwargs): def generate(args): model, tokenizer = load_model(model_path=args.model_path, device=args.device) - BASE_INFERENCE_SUFFIX = "\n\n->\n\n" - input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}" - - inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device) - output = model.generate(**inputs, - max_new_tokens=args.max_new_tokens, - do_sample=args.do_sample, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - num_return_sequences=1) - response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):] - logger.info(f"Question: {input_txt} \n\n Answer: \n{response}") + if args.prompt_style == "sft": + conversation = default_conversation.copy() + conversation.append_message("Human", args.input_txt) + conversation.append_message("Assistant", None) + input_txt = conversation.get_prompt() + else: + BASE_INFERENCE_SUFFIX = "\n\n->\n\n" + input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}" + + inputs = tokenizer(input_txt, return_tensors="pt").to(args.device) + num_input_tokens = inputs["input_ids"].shape[-1] + output = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + num_return_sequences=1, + ) + response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True) + logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}") return response if __name__ == "__main__": parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.") - parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model") - parser.add_argument('--device', type=str, default="cuda:0", help="Set the device") - parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt") - parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling") - parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value") - parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering") - parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation") - parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model") + parser.add_argument( + "--model_path", + type=str, + default="hpcai-tech/Colossal-LLaMA-2-7b-base", + help="HF repo name or local path of the model", + ) + parser.add_argument("--device", type=str, default="cuda:0", help="Set the device") + parser.add_argument( + "--max_new_tokens", + type=int, + default=512, + help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt", + ) + parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling") + parser.add_argument("--temperature", 
type=float, default=0.3, help="Set temperature value") + parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering") + parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation") + parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model") + parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt") args = parser.parse_args() - generate(args) \ No newline at end of file + generate(args) diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt index d8afee768c02..34afaf7e5cfd 100644 --- a/applications/Colossal-LLaMA-2/requirements.txt +++ b/applications/Colossal-LLaMA-2/requirements.txt @@ -1,9 +1,9 @@ torch<2.0.0, >=1.12.1 packaging==23.1 -colossalai==0.3.2 +colossalai==0.3.5 autoflake==2.2.1 black==23.9.1 -transformers +transformers==4.33.3 tensorboard==2.14.0 six==1.16.0 datasets diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh index 276d9ce99d42..6a1c887bf6cc 100644 --- a/applications/Colossal-LLaMA-2/train.example.sh +++ b/applications/Colossal-LLaMA-2/train.example.sh @@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train. --warmup_steps 100 \ --use_grad_checkpoint \ --use_flash_attn \ + --pad_token "unk" diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 92863e8e4bba..2e4bab75a085 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team """ import argparse @@ -16,22 +16,24 @@ DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, - setup_distributed_dataloader, ) from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention from colossal_llama2.utils.froze import freeze_non_embeds_parameters +from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device def get_model_numel(model: torch.nn.Module) -> int: @@ -83,6 +85,7 @@ def main() -> None: parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") parser.add_argument("--config_file", type=str, default="config_file", help="Config file") parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") parser.add_argument("--micro_batch_size", 
type=int, default=2, help="Batch size of each process") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=4096, help="Model max length") @@ -108,6 +111,12 @@ def main() -> None: default=False, help="Use flash-attention", ) + parser.add_argument( + "--use_neft", + action="store_true", + default=False, + help="Use NEFTune", + ) parser.add_argument( "--freeze_non_embeds_params", action="store_true", @@ -116,6 +125,8 @@ def main() -> None: ) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--zero", type=int, default=1) + parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") + parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") args = parser.parse_args() with open(args.config_file, "w") as f: @@ -125,6 +136,7 @@ def main() -> None: # Initialize Distributed Training # ============================== colossalai.launch_from_torch({}) + accelerator = get_accelerator() coordinator = DistCoordinator() # ============================== @@ -142,6 +154,7 @@ def main() -> None: precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -149,6 +162,7 @@ def main() -> None: placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin( @@ -182,7 +196,10 @@ def main() -> None: # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.unk_token + if args.pad_token == "eos": + tokenizer.pad_token = tokenizer.eos_token + elif args.pad_token == "unk": + tokenizer.pad_token = tokenizer.unk_token tokenizer.add_bos_token = False tokenizer.add_eos_token = False @@ -193,38 +210,36 @@ def main() -> None: coordinator.print_on_master(f"Load dataset: {args.dataset}") dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( + data_collator = DataCollatorForSupervisedDataset( + tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode + ) + dataloader = plugin.prepare_dataloader( dataset=dataset, batch_size=args.micro_batch_size, shuffle=True, drop_last=True, collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, ) coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - - # colossalai has changed api for get_current_device in 0.3.4 version or newer - try: - from colossalai.accelerator import get_accelerator - - current_device = get_accelerator().get_current_device() - except: - from colossalai.utils import get_current_device - - current_device = get_current_device() - - init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + init_ctx = ( + 
LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) + model = LlamaForCausalLM.from_pretrained(args.pretrained) # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) + # this is essential, otherwise the grad checkpoint will not work. + model.train() if args.use_grad_checkpoint: model.gradient_checkpointing_enable() @@ -246,12 +261,14 @@ def main() -> None: adamw_mode=True, ) + if args.warmup_steps is None: + args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + lr_scheduler = CosineAnnealingWarmupLR( optimizer=optimizer, - total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps - if args.warmup_steps is not None - else int(args.num_epochs * len(dataloader) * 0.025), + total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), + warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr, ) @@ -267,11 +284,9 @@ def main() -> None: torch.set_default_dtype(torch.float) - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) @@ -298,85 +313,109 @@ def main() -> None: coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) - num_steps_per_epoch = len(dataloader) + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + num_steps_per_epoch = len(dataloader) // args.accumulation_steps # If resume training, set the sampler start index to the correct value assert isinstance(dataloader.sampler, StatefulDistributedSampler) dataloader.sampler.set_start_index(start_index=sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch=epoch) - with tqdm( - iterable=enumerate(dataloader, start=start_step), + pbar = tqdm( desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step, batch in pbar: - batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(dataloader, start=start_step): + batch = {k: 
v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - batch_output = model(**batch) + batch_output = model(**batch) - loss = batch_output.loss + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) - booster.backward(loss=loss, optimizer=optimizer) + booster.backward(loss=loss, optimizer=optimizer) + if (step + 1) % args.accumulation_steps == 0: optimizer.step() lr_scheduler.step() optimizer.zero_grad() - all_reduce_mean(tensor=loss) - pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) + all_reduce_mean(tensor=total_loss) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) if coordinator.is_master(): - global_step = epoch * num_steps_per_epoch + step - writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step) + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) writer.add_scalar( tag="Learning Rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=global_step, ) - # Save modeling. - - if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) - - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() + total_loss.fill_(0.0) + pbar.update() + # Save modeling. + + if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( + step + 1 + ) == len(dataloader): + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.micro_batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + # Delete cache. + # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(start_index=0) start_step = 0 + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune.") + deactivate_neftune(model, handle) + # Final save. 
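The reworked loop above adds gradient accumulation: each micro-batch loss is divided by `accumulation_steps`, gradients accumulate across micro-batches, and the optimizer, scheduler, progress bar and logging only advance at the end of each window. A minimal plain-PyTorch sketch of the same pattern, with toy data and no booster or plugin:

```python
# Minimal gradient-accumulation sketch mirroring the loop structure in train.py.
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accumulation_steps = 4
data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]  # toy micro-batches

total_loss = torch.tensor(0.0)
for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
    total_loss += loss.detach()
    loss.backward()  # gradients keep accumulating in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one parameter update per accumulation window
        optimizer.zero_grad()
        print(f"optimizer step taken, window loss: {total_loss.item():.4f}")
        total_loss.zero_()
```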
coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh index dcb11515d48f..d87f9ef82f4f 100755 --- a/applications/Colossal-LLaMA-2/train_sft.example.sh +++ b/applications/Colossal-LLaMA-2/train_sft.example.sh @@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}" CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" -colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \ +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \ --pretrained $PRETRAINED_MODEL_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ @@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_ --use_grad_checkpoint \ --use_flash_attn \ --use_neft \ + --pad_token "eos" diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py deleted file mode 100644 index fd9e1cd3e747..000000000000 --- a/applications/Colossal-LLaMA-2/train_sft.py +++ /dev/null @@ -1,403 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team -""" - -import argparse -import json -import os -import resource -from contextlib import nullcontext - -import torch -import torch.distributed as dist -from colossal_llama2.dataset.loader import ( - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, - load_tokenized_dataset, - setup_distributed_dataloader, -) -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.froze import freeze_non_embeds_parameters -from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - - -def get_model_numel(model: torch.nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor.div_(dist.get_world_size()) - return tensor - - -def main() -> None: - # 
============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument( - "--pretrained", - type=str, - default=None, - help="Address of the pre-trained modeling", - ) - parser.add_argument("--dataset", nargs="+", default=[]) - parser.add_argument( - "--plugin", - type=str, - default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], - help="Choose which plugin to use", - ) - parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") - parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") - parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") - parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps") - parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=4096, help="Model max length") - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["fp16", "bf16"], - help="Mixed precision", - ) - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") - parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") - parser.add_argument( - "--use_grad_checkpoint", - action="store_true", - default=False, - help="Use gradient checkpointing", - ) - parser.add_argument( - "--use_flash_attn", - action="store_true", - default=False, - help="Use flash-attention", - ) - parser.add_argument( - "--use_neft", - action="store_true", - default=False, - help="Use NEFTune", - ) - parser.add_argument( - "--freeze_non_embeds_params", - action="store_true", - default=False, - help="Freeze non embeddings parameters", - ) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--zero", type=int, default=1) - args = parser.parse_args() - - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Tensorboard - # ============================== - if coordinator.is_master(): - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin( - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, - placement_policy="auto", - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2_cpu": - plugin = 
LowLevelZeroPlugin( - stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip, - ) - elif args.plugin == "3d": - plugin = HybridParallelPlugin( - tp_size=args.tp, - pp_size=1, - zero_stage=args.zero, - max_norm=args.grad_clip, - precision=args.mixed_precision, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - # ====================================================== - # Initialize Tokenizer, Dataset, Collator and Dataloader - # ====================================================== - tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.add_bos_token = False - tokenizer.add_eos_token = False - - coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") - coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") - coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") - - coordinator.print_on_master(f"Load dataset: {args.dataset}") - - dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( - dataset=dataset, - batch_size=args.micro_batch_size, - shuffle=True, - drop_last=True, - collate_fn=data_collator, - ) - coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - - # ====================================================== - # Initialize Model, Objective, Optimizer and LR Scheduler - # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) - with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) - # Freeze part of parameters. - if args.freeze_non_embeds_params: - freeze_non_embeds_parameters(model=model) - - if args.use_grad_checkpoint: - model.gradient_checkpointing_enable() - coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - if args.use_flash_attn: - replace_with_flash_attention(model=model) - coordinator.print_on_master(msg="Flash-attention enabled successfully") - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam( - model_params=filter(lambda p: p.requires_grad, model.parameters()) - if args.freeze_non_embeds_params - else model.parameters(), - lr=args.lr, - betas=(0.9, 0.95), - weight_decay=args.weight_decay, - adamw_mode=True, - ) - - if args.warmup_steps is None: - args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) - coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") - - lr_scheduler = CosineAnnealingWarmupLR( - optimizer=optimizer, - total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), - warmup_steps=args.warmup_steps, - eta_min=0.1 * args.lr, - ) - - # Flash attention will be disabled because it does NOT support fp32. 
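Both training scripts wrap `booster.boost(...)` in a temporary default-dtype switch so that parameters materialized while boosting pick up fp16/bf16, after which the default is restored to fp32. A small stand-alone illustration of that switch (plain PyTorch, no booster; `mixed_precision` is set inline here instead of coming from `--mixed_precision`):

```python
# Stand-alone illustration of the default-dtype switch used around booster.boost().
import torch

mixed_precision = "bf16"  # illustrative value; the scripts read --mixed_precision
default_dtype = torch.float16 if mixed_precision == "fp16" else torch.bfloat16

torch.set_default_dtype(default_dtype)
layer = torch.nn.Linear(4, 4)         # parameters created under the low-precision default
torch.set_default_dtype(torch.float)  # restore fp32 afterwards, as the scripts do

print(layer.weight.dtype)  # torch.bfloat16
```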
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - dataloader=dataloader, - ) - - torch.set_default_dtype(torch.float) - - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load_checkpoint is not None: - if "modeling" in args.load_checkpoint: - coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") - booster.load_model(model, args.load_checkpoint) - else: - coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") - start_epoch, start_step, sampler_start_idx = load_checkpoint( - load_dir=args.load_checkpoint, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - coordinator.print_on_master( - f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" - ) - coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") - - coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - num_steps_per_epoch = len(dataloader) // args.accumulation_steps - # If resume training, set the sampler start index to the correct value - assert isinstance(dataloader.sampler, StatefulDistributedSampler) - dataloader.sampler.set_start_index(start_index=sampler_start_idx) - - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch=epoch) - pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch) - total_loss = torch.tensor(0.0).to(torch.cuda.current_device()) - for step, batch in enumerate(dataloader): - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - - batch_output = model(**batch) - - loss = batch_output.loss / args.accumulation_steps - total_loss += loss.item() - - booster.backward(loss=loss, optimizer=optimizer) - - if (step + 1) % args.accumulation_steps == 0: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - all_reduce_mean(tensor=total_loss) - pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) - if coordinator.is_master(): - global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps - writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) - writer.add_scalar( - tag="Learning Rate", - scalar_value=lr_scheduler.get_last_lr()[0], - global_step=global_step, - ) - total_loss.fill_(0.0) - pbar.update() - # Save modeling. 
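The periodic save in this loop counts `--save_interval` in optimizer steps, so the trigger compares the micro-batch index against `save_interval * accumulation_steps` and also fires on the last batch of the epoch. A small sketch of that trigger with hypothetical numbers:

```python
# Hypothetical numbers; mirrors the save condition used in the training loop.
accumulation_steps = 8
save_interval = 1000      # measured in optimizer steps
dataloader_len = 20_000   # micro-batches per epoch

def should_save(step: int) -> bool:
    # step is the 0-based micro-batch index within the epoch
    periodic = save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
    return periodic or (step + 1) == dataloader_len

assert should_save(7_999)        # after exactly 1000 optimizer steps
assert not should_save(8_000)
assert should_save(19_999)       # always save at the end of the epoch
```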
- - if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( - step + 1 - ) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune before saving model.") - deactivate_neftune(model, handle) - - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() - - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(start_index=0) - start_step = 0 - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune.") - deactivate_neftune(model, handle) - - # Final save. - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py index f293c4f699cd..9c70c0d2a1ad 100644 --- a/applications/ColossalEval/colossal_eval/models/chatglm.py +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -3,6 +3,8 @@ import torch +from colossalai.utils import get_current_device + from .huggingface import HuggingFaceModel IGNORE_INDEX = -100 @@ -126,9 +128,9 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) outputs = self.model(input_ids)[0] @@ -197,7 +199,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str truncation=True, return_tensors="pt", max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. 
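The `_calculate_loss` hunk above pads labels with `IGNORE_INDEX = -100` so that padded positions drop out of the loss. A self-contained sketch of why that padding value matters, using invented token ids and random logits:

```python
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # positions carrying this label contribute nothing to the loss

# Two variable-length label sequences, padded the same way as in the hunk above.
labels = [torch.tensor([5, 6, 7]), torch.tensor([3, 4])]
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)  # shape (2, 3)

vocab_size = 10
logits = torch.randn(2, 3, vocab_size)  # stand-in for model outputs

loss_fct = CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
print(loss)  # the padded position (batch 1, step 2) is ignored
```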
outputs = self.model.generate( diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 741c884f0043..fff697e21e34 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -11,6 +11,7 @@ from colossalai.logging import DistributedLogger from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils import get_current_device from .base import BaseModel @@ -128,12 +129,12 @@ def _load_model( self.model = AutoModel.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) self.model.eval() @@ -155,11 +156,11 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) - attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device()) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device()) outputs = self.model(input_ids, attention_mask=attention_mask)[0] @@ -464,7 +465,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return_tensors="pt", return_token_type_ids=False, max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. 
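These hunks replace hard-coded `torch.cuda.current_device()` calls with ColossalAI's `get_current_device()` so the same code can run on non-CUDA accelerators. The sketch below uses a stand-in helper (an assumption for illustration, not the library function) to show the device-agnostic padding and attention-mask pattern:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def current_device() -> torch.device:
    # Stand-in for colossalai.utils.get_current_device: prefer the active
    # accelerator, fall back to CPU so the snippet also runs without a GPU.
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

pad_token_id = 0  # hypothetical tokenizer pad id
batch = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]
input_ids = pad_sequence(batch, batch_first=True, padding_value=pad_token_id).to(current_device())
attention_mask = input_ids.ne(pad_token_id).to(current_device())
print(input_ids.device, attention_mask.shape)
```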
outputs = self.model.generate( @@ -598,12 +599,12 @@ def _load_model( self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 5b09f9de8da6..a340f3bfd281 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -8,6 +8,7 @@ from colossal_eval import dataset, models, utils import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig @@ -82,6 +83,7 @@ def rm_and_merge( def main(args): colossalai.launch_from_torch(config={}, seed=42) + accelerator = get_accelerator() world_size = dist.get_world_size() rank = dist.get_rank() @@ -235,10 +237,10 @@ def main(args): ), ) - logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ - torch.cuda.empty_cache() + accelerator.empty_cache() dist.barrier() if rank == 0: diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md new file mode 100644 index 000000000000..be50a8f9f251 Binary files /dev/null and b/applications/ColossalMoE/README.md differ diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py new file mode 100644 index 000000000000..d08dfd5f8120 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py @@ -0,0 +1,629 @@ +import copy +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.checkpoint_io import CheckpointIndexFile +from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict_shards, + 
search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.moe import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose) + moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size] + self.ep_group = moe_info.ep_group + self.ep_size = moe_info.ep_size + self.ep_rank = moe_info.ep_rank + self.real_dp_rank = moe_info.dp_rank + + @staticmethod + def _model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + param_name_pattern: Optional[str] = None, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + if param_name_pattern is not None and param_name_pattern not in name: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. 
Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + if self.real_dp_rank != 0: + dist.barrier() + return + + # ep_rank 0 saves all the parameters and buffers. + # other ep_ranks save only experts + ep_param_pattern = "experts." if self.ep_rank != 0 else None + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern + ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + weights_name = weights_name.replace( + ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors" + ) + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. 
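As the comment above notes, the per-stage index files are later merged by the global master rank. A plain-dictionary sketch of that merge, following the `weight_map`/`total_size` layout of a Hugging Face style `*.bin.index.json` file; the sizes and filenames are made up:

```python
# Invented per-stage indices; only the merge logic is of interest here.
stage_indices = [
    {"metadata": {"total_size": 100},
     "weight_map": {"model.embed_tokens.weight": "pytorch_model-stage-00001-00001-shard-00001.bin"}},
    {"metadata": {"total_size": 250},
     "weight_map": {"model.layers.16.mlp.w1.weight": "pytorch_model-stage-00002-00001-shard-00001.bin"}},
]

final_index = {"metadata": {"total_size": 0}, "weight_map": {}}
for stage_index in stage_indices:
    final_index["metadata"]["total_size"] += stage_index["metadata"]["total_size"]
    final_index["weight_map"].update(stage_index["weight_map"])

print(final_index["metadata"]["total_size"])  # 350
```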
+ if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + @staticmethod + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + is_moe_param: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # First gather Zero shards. + if use_zero and not is_moe_param: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().to(device) + + return state_ + + @staticmethod + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + size_per_shard: int = 1024, + only_moe_param: bool = False, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. 
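`StateDictSharder` groups tensors into blocks bounded by `size_per_shard` (in MB, per the docstrings above) and yields each finished block plus the final partial one. A simplified, generic sketch of that greedy grouping by byte budget; it is not the library class, just the idea behind it:

```python
import torch
from typing import Dict, Iterator, Tuple

def shard_state_dict(
    state_dict: Dict[str, torch.Tensor], max_shard_bytes: int
) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    """Greedy sharding: open a new block whenever the next tensor would exceed the budget."""
    block, block_size = {}, 0
    for name, tensor in state_dict.items():
        tensor_bytes = tensor.numel() * tensor.element_size()
        if block and block_size + tensor_bytes > max_shard_bytes:
            yield block, block_size
            block, block_size = {}, 0
        block[name] = tensor
        block_size += tensor_bytes
    if block:  # last partial block
        yield block, block_size

tensors = {f"layer{i}.weight": torch.zeros(256, 256) for i in range(8)}  # 256 KiB each (fp32)
for i, (shard, nbytes) in enumerate(shard_state_dict(tensors, max_shard_bytes=600 * 1024)):
    print(i, len(shard), nbytes)  # four shards of two tensors each
```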
+ + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + is_moe_param=is_moe_tensor(working_param), + ) + + if only_moe_param and not is_moe_tensor(working_param): + continue + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.real_dp_rank != 0: + dist.barrier() + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. 
+ state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + size_per_shard=size_per_shard, + only_moe_param=self.ep_rank != 0, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.real_dp_rank == 0 and self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. 
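Param groups are stored without their live parameter tensors and later rebuilt around them: hyperparameters come from the checkpoint, while `"params"` keeps the current objects. A rough sketch of that round trip on a toy optimizer; the ids are invented and, for simplicity, the "saved" groups come from the same optimizer:

```python
import copy
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)

# Store: keep every hyperparameter the optimizer needs, replace parameters with checkpoint ids.
param_groups_to_save = [
    {**{k: v for k, v in pg.items() if k != "params"}, "params": [0, 1]}  # ids are invented
    for pg in optimizer.param_groups
]

# Load: take hyperparameters from the saved groups, keep the live parameter objects.
updated_groups = []
for old_pg, saved_pg in zip(optimizer.param_groups, param_groups_to_save):
    new_pg = copy.deepcopy(saved_pg)
    new_pg["params"] = old_pg["params"]  # parameters in the same group should not change
    updated_groups.append(new_pg)
optimizer.param_groups = updated_groups

print(optimizer.param_groups[0]["weight_decay"], len(optimizer.param_groups[0]["params"]))  # 0.0 2
```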
+ final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + # ep param groups + if len(optimizer.optim.param_groups) == len(saved_groups) + 1: + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. 
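Several parameters usually map to the same shard file, so the loader remembers which files it has already read. A trivial sketch of that dedup pattern with invented filenames; the actual load call is left as a comment:

```python
# Many parameter ids point at the same shard file, but each file is read only once.
weight_map = {
    0: "pytorch_optim-00001.bin",
    1: "pytorch_optim-00001.bin",
    2: "pytorch_optim-00002.bin",
}

loaded_files = set()
for param_id, filename in weight_map.items():
    if filename in loaded_files:
        continue
    # state_dict = torch.load(os.path.join(ckpt_root_path, filename))  # real code loads here
    print(f"loading {filename}")
    loaded_files.add(filename)
# -> loads pytorch_optim-00001.bin and pytorch_optim-00002.bin exactly once each
```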
+ loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + is_moe_param=is_moe_tensor(working_param), + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + is_moe_param: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. 
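The ZeRO branch flattens a state tensor, pads it to a multiple of the data-parallel world size, and keeps one slice per rank; the matching gather path concatenates the slices, drops the padding and restores the original shape. A single-process simulation of that round trip with a hypothetical `dp_size` (no process group is needed here):

```python
import torch

dp_size, dp_rank = 4, 1                                      # hypothetical data-parallel world
state = torch.arange(10, dtype=torch.float32).reshape(2, 5)  # a full optimizer state tensor

# Shard: flatten, pad to a multiple of dp_size, keep this rank's slice.
flat = state.flatten()
padding = (dp_size - flat.numel() % dp_size) % dp_size       # 10 elements -> pad by 2
padded = torch.nn.functional.pad(flat, [0, padding])
shard = padded.split(padded.numel() // dp_size, dim=0)[dp_rank]  # 3 elements per rank

# Gather: concatenate all slices (stand-in for dist.all_gather), drop padding, reshape.
all_shards = list(padded.split(padded.numel() // dp_size, dim=0))
restored = torch.cat(all_shards)[: state.numel()].reshape_as(state)
assert torch.equal(restored, state)
print(shard)  # tensor([3., 4., 5.])
```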
+ if self.use_zero and not is_moe_param: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): + raise NotImplementedError diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py new file mode 100644 index 000000000000..a2b78a2bd18c --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py @@ -0,0 +1,92 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from colossalai.lazy import LazyInitContext +from colossalai.moe import MOE_MANAGER +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_info + + +class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + super().__init__(config) + self.setup_ep() + + def setup_ep(self): + _, moe_info = MOE_MANAGER.get_info(self.num_experts) + ep_group = moe_info.ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + set_moe_tensor_info(p, moe_info) + + @staticmethod + def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + module.setup_ep() + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = 
torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + # compute expert output + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py new file mode 100644 index 000000000000..218b05b27fad --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -0,0 +1,557 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.models.mixtral.modeling_mixtral import ( + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MoeCausalLMOutputWithPast, + _prepare_4d_causal_attention_mask, + load_balancing_loss_func, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.shard import ShardConfig + +from .mixtral_layer import EPMixtralSparseMoeBlock + +__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] + + +class MixtralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 
0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in mixtral.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class MixtralModelPolicy(MixtralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralModel, 
+ new_forward=MixtralPipelineForwards.mixtral_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class MixtralForCausalLMPolicy(MixtralPolicy): + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralForCausalLM, + new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class MixtralPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def mixtral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + output_router_logits, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if output_router_logits and past_router_logits is not None: + all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + + @staticmethod + def mixtral_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MixtralPipelineForwards.mixtral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_logits=past_router_logits, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + if output_router_logits: 
+ out["past_router_logits"] = outputs["past_router_logits"] + return out diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py new file mode 100644 index 000000000000..a2a0a7e78239 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/utils.py @@ -0,0 +1,84 @@ +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. 
+ booster.load_model(model, os.path.join(load_dir, "modeling")) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py new file mode 100644 index 000000000000..46ff70ff33ab --- /dev/null +++ b/applications/ColossalMoE/infer.py @@ -0,0 +1,111 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="ep", + choices=["ep"], + help="Parallel methos.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. 
Raise error if not installed.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + config = MixtralConfig.from_pretrained(args.model_name) + ep_size = min(dist.get_world_size(), config.num_local_experts) + # Set plugin + if args.plugin == "ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + ep_size=ep_size, + zero_stage=1, + precision=args.precision, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish load model") + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Set booster + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + coordinator.print_on_master(f"Finish init booster") + + model.eval() + + if coordinator.rank == 0: + text = ["Hello my name is"] + else: + text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] + tokenizer.pad_token = tokenizer.unk_token + inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) + + with torch.no_grad(): + outputs = model.module.generate(**inputs, max_new_tokens=20) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(f"[{coordinator.rank}] {outputs}") + + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh new file mode 100644 index 000000000000..0487fe9c1562 --- /dev/null +++ b/applications/ColossalMoE/infer.sh @@ -0,0 +1,7 @@ +NUM_GPU=2 +MODEL="mistralai/Mixtral-8x7B-v0.1" + +# ep +torchrun --standalone --nproc_per_node $NUM_GPU infer.py \ + --model_name $MODEL \ + --plugin "ep" \ diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt new file mode 100644 index 000000000000..9a5738c412b9 --- /dev/null +++ b/applications/ColossalMoE/requirements.txt @@ -0,0 +1,5 @@ +colossalai >= 0.3.3 +torch >= 1.8.1 +transformers == 4.36.0 +sentencepiece +datasets diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py new file mode 100644 index 000000000000..275f59e10a06 --- /dev/null +++ b/applications/ColossalMoE/setup.py @@ -0,0 +1,43 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_moe", + version=fetch_version(), + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI MoE", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License 
:: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py new file mode 100644 index 000000000000..57589ab20d22 --- /dev/null +++ b/applications/ColossalMoE/tests/test_mixtral_layer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock +from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.moe import MOE_MANAGER +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + MOE_MANAGER.setup( + parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1 + ) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + ) + torch.manual_seed(0) + orig_model = MixtralSparseMoeBlock(config).cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output, orig_logits = orig_model(x) + model = deepcopy(orig_model) + model = EPMixtralSparseMoeBlock.from_native_module(model) + ep_output, ep_logits = model(x) + assert_close(orig_logits, ep_logits) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(2) diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py new file mode 100644 index 000000000000..822e7410f016 --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -0,0 +1,146 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from torch.optim import Adam +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_model_equal(model1, model2): 
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert torch.equal(p1.half(), p2.half()) + + +def get_optimizer_snapshot(optim): + state = {id(k): deepcopy(v) for k, v in optim.state.items()} + param_groups = [] + for group in optim.param_groups: + params = [id(p) for p in group["params"]] + new_group = {"params": params} + for k, v in group.items(): + if k != "params": + new_group[k] = v + param_groups.append(new_group) + return { + "state": state, + "param_groups": param_groups, + } + + +def check_optimizer_snapshot_equal(snapshot1, snapshot2): + # check param_groups + assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) + for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]): + assert set(group1.keys()) == set(group2.keys()) + for k in group1.keys(): + assert group1[k] == group2[k] + # check state + assert set(snapshot1["state"].keys()) == set( + snapshot2["state"].keys() + ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}" + for pid in snapshot1["state"].keys(): + state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid] + assert set(state1.keys()) == set(state2.keys()) + for k in state1.keys(): + if isinstance(state1[k], torch.Tensor): + assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}" + else: + assert state1[k] == state2[k] + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, + ) + torch.manual_seed(0) + input_ids = torch.randint(0, 100, (2, tokens)).cuda() + orig_model = MixtralForCausalLM(config).cuda() + model = deepcopy(orig_model) + optimizer = Adam(model.parameters(), lr=1e-3) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=2, + ep_size=2, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + microbatch_size=1, + zero_stage=1, + ) + booster = Booster(plugin=plugin) + model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) + # initialize grads + data_iter = iter( + [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}] + ) + booster.execute_pipeline( + data_iter, + model, + lambda outputs, inputs: outputs.loss, + optimizer, + ) + + # check save model + booster.save_model(model, "mixtral_model", shard=True) + dist.barrier() + if dist.get_rank() == 0: + saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda() + check_model_equal(orig_model, saved_model) + saved_model.save_pretrained("mixtral_hf_model") + dist.barrier() + + # check load model + new_model = MixtralForCausalLM(config).cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) + booster.load_model(new_model, "mixtral_hf_model") + check_model_equal(model, new_model) + + # check save optimizer + optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 + snapshot = get_optimizer_snapshot(optimizer.unwrap()) + booster.save_optimizer(optimizer, "mixtral_optim", shard=True) + dist.barrier() + # reset optimizer state + for state in optimizer.unwrap().state.values(): + for v in state.values(): + if isinstance(v, torch.Tensor): + v.zero_() + booster.load_optimizer(optimizer, 
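# Hedged aside: the save -> reset -> load -> compare cycle exercised in this test
# can also be reproduced in isolation with a plain PyTorch optimizer, reusing the
# helpers defined above. The toy model below and the in-memory deepcopy are
# assumptions; they stand in for booster.save_optimizer / booster.load_optimizer.
toy_model = torch.nn.Linear(8, 8)
toy_optim = Adam(toy_model.parameters(), lr=1e-3)
toy_model(torch.randn(2, 8)).sum().backward()
toy_optim.step()

toy_before = get_optimizer_snapshot(toy_optim)     # snapshot before "saving"
toy_saved = deepcopy(toy_optim.state_dict())       # stands in for save_optimizer(...)
for toy_state in toy_optim.state.values():         # reset state, as done before loading
    for toy_value in toy_state.values():
        if isinstance(toy_value, torch.Tensor):
            toy_value.zero_()
toy_optim.load_state_dict(toy_saved)               # stands in for load_optimizer(...)
check_optimizer_snapshot_equal(toy_before, get_optimizer_snapshot(toy_optim))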
"mixtral_optim") + loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap()) + check_optimizer_snapshot_equal(snapshot, loaded_snapshot) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(4) diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py new file mode 100644 index 000000000000..c567038ec252 --- /dev/null +++ b/applications/ColossalMoE/train.py @@ -0,0 +1,295 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +@torch.no_grad() +def get_global_loss(loss, booster): + global_loss = loss.clone().detach() + dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group) + global_loss.div_(booster.plugin.dp_size) + return global_loss + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["hybrid"], + help="Parallel methods.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help=" The interval (steps) of saving checkpoints.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.") + 
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # optim + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + + # lr scheduler + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + + # zero stage for all plugins + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin") + parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin") + parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + + # load balance + parser.add_argument( + "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable." + ) + parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.") + # communicate overlap + parser.add_argument( + "--comm_overlap", + action="store_true", + help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + ) + # hierarchical all-to-all + parser.add_argument( + "--hierarchical_alltoall", + action="store_true", + help="Use hierarchical all-to-all for MoE. 
Recommended to enable for muiti-node training.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + if args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + ep_size=args.ep_size, + microbatch_size=args.microbatch_size, + custom_policy=MixtralForCausalLMPolicy(), + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + precision=args.precision, + zero_stage=args.zero_stage, + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + ) + + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build Mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish init model") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + dataset = RandomDataset(num_samples=100, tokenizer=tokenizer) + collate_fn = None + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Set optimizer + optimizer = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # Set lr scheduler + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Set booster + booster = Booster(plugin=plugin) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + coordinator.print_on_master(f"Finish init booster") + + # Load ckpt + if args.load_checkpoint is not None: + load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler) + coordinator.print_on_master(f"Finish load optimizer") + + # Start finetuning + coordinator.print_on_master(f"Start finetuning") + for epoch in range(args.num_epoch): + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage, + ) as pbar: + for step in pbar: + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + global_loss = get_global_loss(loss, booster) + if coordinator._local_rank == "0": + pbar.set_postfix({"Loss": global_loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + 
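For reference, `load_checkpoint` returns the saved `epoch`, `step`, and `sample_start_index`, which the script does not yet consume. The sketch below shows one way those return values could drive a resume; `args`, `booster`, `model`, `optimizer`, `lr_scheduler`, `coordinator`, and `dataloader` refer to the objects built earlier in `train.py`, and the step-skipping logic is an assumption rather than part of this patch.

```python
# Hedged resume sketch built on colossal_moe.utils.save_checkpoint / load_checkpoint.
# Checkpoints are written to "{output_path}/epoch-{epoch}_step-{step}" by save_checkpoint.
start_epoch, start_step = 0, 0
if args.load_checkpoint is not None:
    start_epoch, start_step, _ = load_checkpoint(
        args.load_checkpoint, booster, model, optimizer, lr_scheduler
    )
    coordinator.print_on_master(f"Resumed from epoch {start_epoch}, step {start_step}")

for epoch in range(start_epoch, args.num_epoch):
    for step, batch in enumerate(dataloader):
        if epoch == start_epoch and step <= start_step:
            continue  # assumption: skip steps already covered by the loaded checkpoint
        ...  # forward / backward / optimizer.step(), as in the training loop above
        if (step + 1) % args.save_interval == 0:
            save_checkpoint(
                args.output_path, booster, model, optimizer, lr_scheduler,
                epoch, step, args.batch_size, coordinator,
            )
```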
optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Apply load balance + # if ( + # args.load_balance + # and args.load_balance_interval > 0 + # and (step + 1) % args.load_balance_interval == 0 + # ): + # coordinator.print_on_master(f"Apply load balance") + # apply_load_balance(model, optimizer) + # save ckeckpoint + if (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + save_checkpoint( + args.output_path, + booster, + model, + optimizer, + lr_scheduler, + epoch, + step, + args.batch_size, + coordinator, + ) + + # save checkpoint at the end of each epochs + booster.save_model(model, args.output_path, shard=True, size_per_shard=5120) + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + + # Finish training + coordinator.print_on_master(f"Finish training") + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh new file mode 100644 index 000000000000..bee7f5c8fdf8 --- /dev/null +++ b/applications/ColossalMoE/train.sh @@ -0,0 +1,19 @@ +NUM_GPU=8 +MODEL="mistralai/Mixtral-8x7B-v0.1" +SEQ_LENGTH=2048 +BATCH_SIZE=1 +LR=0.00001 + +# hybrid +# torchrun --standalone --nproc_per_node $NUM_GPU \ +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \ + train.py \ + --num_epoch 1 \ + --model_name $MODEL \ + --plugin "hybrid" \ + --batch_size $BATCH_SIZE \ + --lr $LR \ + --zero_stage 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 8 \ diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt new file mode 100644 index 000000000000..3eefcb9dd5b3 --- /dev/null +++ b/applications/ColossalMoE/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index 0b7b51a71955..7439ad5d3526 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -4,9 +4,9 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size +from colossalai.auto_parallel.tensor_shard.constants import BCAST_FUNC_OP from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from ..constants import BCAST_FUNC_OP from ..registry import meta_register __all__ = ["binary_elementwise_meta_info"] diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py index d2dd00453e32..27285f95ce52 100644 --- a/colossalai/booster/plugin/dp_plugin_base.py +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -21,7 +21,16 @@ def __init__(self) -> None: self.world_size = dist.get_world_size() def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -45,7 +54,8 @@ def prepare_dataloader( :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
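The `prepare_dataloader` changes in this and the following plugins add a `distributed_sampler_cls` hook that falls back to `torch.utils.data.DistributedSampler` when left as `None`. Below is a hedged sketch of plugging in a custom sampler class; the subclass and its logging behaviour are purely illustrative, and `plugin` / `dataset` stand for an already-constructed plugin and a `torch.utils.data.Dataset`.

```python
# Illustrative only: per the diffs here, the plugin instantiates the class as
# cls(dataset, num_replicas=..., rank=..., shuffle=...), so any
# DistributedSampler-compatible class can be supplied.
from torch.utils.data import DistributedSampler


class LoggingDistributedSampler(DistributedSampler):
    """Hypothetical subclass that just reports the size of its shard."""

    def __iter__(self):
        print(f"[rank {self.rank}] sampling {self.num_samples} indices")
        return super().__iter__()


# `plugin` and `dataset` are assumed to exist already (e.g. a HybridParallelPlugin
# created after colossalai.launch_from_torch and a tokenized training dataset).
dataloader = plugin.prepare_dataloader(
    dataset,
    batch_size=8,
    shuffle=True,
    distributed_sampler_cls=LoggingDistributedSampler,
)
```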
""" _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) # Deterministic dataloader def seed_worker(worker_id): diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index d14109dd43e5..95b96bbfd9ed 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -456,7 +456,16 @@ def supported_devices(self) -> List[str]: return ["cuda", "npu"] def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -484,7 +493,8 @@ def prepare_dataloader( extra_dp_world_size = self.pg_mesh.size(DP_AXIS) zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) - sampler = DistributedSampler( + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 5837156a90cd..da67e6b41fbf 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,5 +1,6 @@ import ctypes import random +import warnings from contextlib import contextmanager from functools import partial from types import MethodType @@ -1134,7 +1135,12 @@ def configure( tp_process_group=self.tp_group, ) else: - assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + if self.dp_size == 1: + warnings.warn( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you are not intended to use cpu_offload, please consider set zero_stage=0." + ) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer( optimizer, @@ -1199,7 +1205,16 @@ def execute_pipeline( return outputs def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -1223,7 +1238,8 @@ def prepare_dataloader( :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() - sampler = DistributedSampler( + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle ) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index e976d0aaf014..45e5a23c1b22 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ ) from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MoECheckpintIO +from colossalai.moe import MOE_MANAGER, MoECheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig @@ -150,6 +150,7 @@ def __init__( self, tp_size: int, pp_size: int, + ep_size: int, extra_dp_size: int = 1, precision: str = "fp16", zero_stage: int = 0, @@ -181,6 +182,7 @@ def __init__( overlap_communication: bool = True, use_ep_inside: bool = True, custom_policy: Policy = None, + checkpoint_io: Optional[MoECheckpintIO] = None, ) -> None: assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -188,10 +190,26 @@ def __init__( if enable_sequence_parallelism: assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" - + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + dist.get_world_size() % (tp_size * pp_size * ep_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" + self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=self.real_dp_size, + fixed_ep_size=ep_size, + fixed_pp_size=pp_size, + use_ep_inside=use_ep_inside, + ) self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.ep_size = ep_size + self.moe_info = MOE_MANAGER.get_info(0)[1] self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -200,6 +218,7 @@ def __init__( self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.checkpoint_io = checkpoint_io # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) @@ -323,7 +342,10 @@ def seed_worker(worker_id): ) def get_checkpoint_io(self) -> MoECheckpintIO: - self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + if self.checkpoint_io is None: + self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + else: + self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def configure( diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 2ea7593a5cc5..5445b4a6349d 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,3 +1,5 @@ +import logging 
+import os import warnings from pathlib import Path from typing import Callable, Iterable, Iterator, List, Optional, Tuple @@ -25,7 +27,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -74,17 +76,54 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, def save_sharded_model( self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool, - prefix: Optional[str], - size_per_shard: int, - use_safetensors: bool, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, ): """ Save model to checkpoint but only on master process. """ - raise NotImplementedError("Sharded model checkpoint is not supported yet.") + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + with FSDP.state_dict_type( + model.unwrap(), + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + state_dict = model.unwrap().state_dict() + + state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) + + weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + utils.save_config_file(model.unwrap(), checkpoint_path) + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_model( self, @@ -97,7 +136,24 @@ def load_sharded_model( """ Load model to checkpoint but only on master process. """ - raise NotImplementedError("Sharded model checkpoint is not supported yet.") + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" 
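With sharded saving and loading implemented for the FSDP checkpoint IO, `booster.save_model(..., shard=True)` and the optimizer counterparts become usable under `TorchFSDPPlugin`. Below is a hedged end-to-end sketch; the toy model, directory names, and hyperparameters are assumptions, and it presumes torch >= 1.12 and a distributed launch (e.g. via `torchrun`).

```python
# Hedged sketch of sharded FSDP checkpointing through the booster API.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch(config={})
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)

# One dummy step so the optimizer has state worth checkpointing.
loss = model(torch.randn(4, 1024, device="cuda")).sum()
booster.backward(loss, optimizer)
optimizer.step()

# Each rank writes its shards; only the master rank writes the index files.
booster.save_model(model, "fsdp_ckpt/model", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "fsdp_ckpt/optimizer", shard=True)

# Later (same world size), the shards can be loaded back.
booster.load_model(model, "fsdp_ckpt/model")
booster.load_optimizer(optimizer, "fsdp_ckpt/optimizer")
```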
+ use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not utils.is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + fsdp_state_dict = {} + for shard_file in checkpoint_files: + fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors)) + + with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT): + model.unwrap().load_state_dict(fsdp_state_dict, strict=False) def save_sharded_optimizer( self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int @@ -105,13 +161,86 @@ def save_sharded_optimizer( """ Save optimizer to checkpoint but only on master process. """ - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + with FSDP.state_dict_type( + optimizer.unwrap_model().unwrap(), + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + fsdp_optim_state = FSDP.full_optim_state_dict( + optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True + ) + + if self.coordinator.is_master(): + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + utils.save_param_groups(fsdp_optim_state, group_file_path) + + sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): """ Load optimizer to checkpoint but only on master process. """ - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. " + "Looking param group file under current directory." 
+ ) + + saved_param_groups = torch.load(param_group_path) + + # Load param + fsdp_optim_state = {} + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + for shard_file in checkpoint_files: + state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + fsdp_optim_state.update(state_dict_shard) + + fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) + + with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): + fsdp_state = FSDP.optim_state_dict_to_load( + model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict + ) + optimizer.load_state_dict(fsdp_state) + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ @@ -190,7 +319,7 @@ def __init__( raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") def support_no_sync(self) -> bool: - False + return False def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 780117598e18..71232421586d 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper -from .utils import has_index_file +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file __all__ = ["CheckpointIO"] @@ -90,7 +90,15 @@ def load_model( if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: - self.load_unsharded_model(model, checkpoint, strict) + path = Path(checkpoint, SAFE_WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + path = Path(checkpoint, WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + self.load_unsharded_model(model, checkpoint, strict) return origin_model diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b7900bc0f217..36df30335dd7 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,7 +1,7 @@ import copy -from functools import reduce import logging import os +from functools import reduce from pathlib import Path from shutil import rmtree from typing import Dict, Iterator, Optional, OrderedDict, Tuple @@ -14,6 +14,7 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -445,7 +446,11 @@ def save_sharded_optimizer( # Store param groups. index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) # Store index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) @@ -504,7 +509,11 @@ def save_sharded_optimizer( # Store param groups. 
final_index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) @@ -713,12 +722,16 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, tp_group=self.tp_group, use_zero=self.use_zero, inplace=False, - device=torch.device("cuda"), + device=get_current_device(), ) if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. - state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states} + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + state_dict = {"param_groups": param_groups, "state": local_states} if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) else: @@ -729,7 +742,11 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, # Only the master rank do the saving. if self.coordinator.is_master(): - state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()} + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + state_dict = {"param_groups": param_groups, "state": dict()} for _states in states_list: state_dict["state"].update(_states) save_state_dict(state_dict, checkpoint, use_safetensors=False) @@ -838,7 +855,7 @@ def gather_from_sharded_optimizer_state( if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. 
if use_zero: - v = v.cuda() + v = v.to(get_current_device()) gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] dist.all_gather(gather_tensor, v, group=dp_group) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 721da69d0741..6dd0a5fc3c52 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,6 +1,7 @@ from .checkpoint import MoECheckpintIO from .experts import MLPExperts -from .layers import SparseMLP +from .layers import SparseMLP, apply_load_balance +from .manager import MOE_MANAGER from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator @@ -14,4 +15,6 @@ "UniformNoiseGenerator", "SparseMLP", "MoECheckpintIO", + "MOE_MANAGER", + "apply_load_balance", ] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 34342436f263..01c837ee36ad 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.distributed as dist @@ -329,3 +329,68 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: if ctx.ep_size != 1: grad = grad / ctx.ep_size return grad, None + + +def _all_to_all( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + async_op: bool = False, +): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + outputs_shape = list(inputs.shape) + if output_split_sizes is not None: + outputs_shape[0] = sum(output_split_sizes) + outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device) + inputs = inputs.contiguous() + outputs = outputs.contiguous() + handle = dist.all_to_all_single( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op + ) + return outputs, handle + + +class AllToAllUneven(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inputs, + input_split_sizes=None, + output_split_sizes=None, + group=None, + overlap: bool = False, + ): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + ctx.group = group + return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap) + + @staticmethod + def backward(ctx: Any, *grad_outputs): + return ( + _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0], + None, + None, + None, + None, + ) + + +def all_to_all_uneven( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + overlap: bool = False, +): + return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index a8c50eab66e3..b37ffabea41f 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -224,6 +224,7 @@ def save_sharded_model( size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. 
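The new `all_to_all_uneven` helper in `colossalai/moe/_operation.py` wraps `dist.all_to_all_single` with per-rank split sizes and an autograd backward that performs the reverse exchange. Below is a hedged two-rank sketch; the split sizes and tensor shapes are illustrative, and it assumes a `torchrun --nproc_per_node 2` launch with NCCL available.

```python
# Hedged sketch: uneven row exchange between two ranks. Each rank's
# output_split_sizes must mirror the peers' input_split_sizes.
import torch
import torch.distributed as dist

from colossalai.moe._operation import all_to_all_uneven

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# rank 0 keeps 1 row and sends 3 rows to rank 1;
# rank 1 sends 2 rows to rank 0 and keeps 2 rows.
input_split_sizes = [[1, 3], [2, 2]][rank]
output_split_sizes = [[1, 2], [3, 2]][rank]

x = torch.randn(sum(input_split_sizes), 8, device="cuda", requires_grad=True)
out, _ = all_to_all_uneven(x, input_split_sizes, output_split_sizes)
assert out.shape[0] == sum(output_split_sizes)

# Gradients flow back through the reverse uneven all-to-all.
out.sum().backward()
```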
""" + torch.cuda.empty_cache() if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -265,6 +266,7 @@ def save_sharded_model( f"index located at {save_index_file}." ) dist.barrier() + torch.cuda.empty_cache() # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -332,10 +334,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None ): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param return optimizer.param_info["param2id"][id(working_param)] @@ -347,7 +351,7 @@ def _get_param_id_from_optimizer_param( master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) id_map[param_id] = param # Read checkpoint index file. @@ -371,14 +375,10 @@ def _get_param_id_from_optimizer_param( new_pg = copy.deepcopy(saved_pg) new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. updated_groups.append(new_pg) - # ep extra group - if MOE_MANAGER.parallel == "EP": + # ep param group + if len(optimizer.optim.param_groups) > len(saved_groups): new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1][ - "params" - ] # Only keep the parameters kept by current pipeline stage. - for param in new_pg["params"]: - param.data = param.data.to(torch.float32) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) @@ -389,7 +389,7 @@ def _get_param_id_from_optimizer_param( for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -400,27 +400,34 @@ def _get_param_id_from_optimizer_param( file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + + # Then shard the loaded optimizer states if using tp/zero. 
+ for pid, state in list(state_dict.items()): + if pid in id_map: + param = id_map[pid] + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + elif ( + hasattr(optimizer, "moe_master_to_working_map") + and id(param) in optimizer.moe_master_to_working_map + ): + working_param = optimizer.moe_master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + working_param, + current_shape=working_param.shape, + original_shape=original_shape, + device="cpu", + inplace=True, + ) + state_dict[pid] = sharded_state + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - device = param.device - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - param, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True, - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -576,6 +583,8 @@ def _optimizer_sharder( if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param @@ -618,6 +627,7 @@ def save_sharded_optimizer( prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + torch.cuda.empty_cache() assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -723,6 +733,7 @@ def save_sharded_optimizer( f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}." 
) + torch.cuda.empty_cache() def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 477b76547c7e..8e6ea3884df4 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -67,7 +67,11 @@ def __init__( self.ep_size = 1 if gated: - self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) else: self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a585..2ac5b186d116 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,6 +51,8 @@ def __init__( hidden_size: int, intermediate_size: int, router_top_k: int = 1, + router_loss: bool = True, + router_norm: bool = False, router_capacity_factor_train: float = 1.25, router_capacity_factor_eval: float = 2.0, router_min_capacity: int = 4, @@ -65,15 +67,19 @@ def __init__( enable_kernel: bool = False, enable_comm_overlap: bool = False, enable_hierarchical_comm: bool = False, + return_gate_logits: bool = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts self.gated = mlp_gated + self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() + self.router_loss = router_loss + self.router_norm = router_norm # moe router noisy_func = get_noise_generator(router_noisy_policy, num_experts) @@ -150,9 +156,8 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: tokens = inputs.reshape(-1, self.hidden_size) # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) + gate_logits = F.linear(tokens, self.gate_weight) + gate_output = gate_logits.to(torch.float) # update expert load if self.enable_load_balance == True: @@ -165,7 +170,12 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # the result from the router used_capacity, *route_result_list = self.router( - inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + inputs=gate_output, + use_kernel=self.enable_kernel, + ep_group=self.ep_group, + use_loss=self.router_loss, + use_norm=self.router_norm, + ) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -177,22 +187,15 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind 
of communication has not been implemented yet.\n" - "Please use Experts build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n" "Please use Experts build function." + ) if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) @@ -204,7 +207,11 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - return ans + + if self.return_gate_logits: + return ans, gate_logits + else: + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) @@ -212,10 +219,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: return expert_out def _ep_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ Expert Parallel @@ -228,10 +232,14 @@ def _ep_process( """ if not overlap or dist.get_world_size(self.ep_group) == 1: if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_input = HierarchicalAllToAll.apply( + dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank + ) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_output = HierarchicalAllToAll.apply( + expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank + ) return expert_output else: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] @@ -249,7 +257,7 @@ class Capsule: NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -262,13 +270,15 @@ class Capsule: for i in range(NUM_CHUNK + NUM_STAGES - 1): if expert_out is not None: expert_out.handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out.data + output[:, :, offset : offset + chunk_size, :] = expert_out.data offset += chunk_size expert_out = None # all2all last output if _expert_out is not None: - expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) _expert_out = None # all2all next input @@ -288,10 +298,7 @@ class Capsule: return output def _tp_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ without overlap: @@ -326,8 +333,9 @@ class Capsule: NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert ( + dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = 
torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index f5815d05d111..e40674c9bb44 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -45,9 +45,13 @@ def __init__( self._z_loss = None self.use_kernel = use_kernel - def get_capacity(self, logits_shape): + def get_capacity(self, num_tokens, num_experts, ep_group=None): + if ep_group is not None: + num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) + dist.all_reduce(num_tokens_tensor, group=ep_group) + num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 @@ -150,7 +154,14 @@ def __init__( high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_loss: bool = False, + use_norm: bool = False, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -168,7 +179,8 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -207,7 +219,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): @@ -240,7 +252,14 @@ def __init__( drop_tks=drop_tks, ) - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_norm: bool = False, + use_loss: bool = True, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). 
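For reference, a minimal standalone sketch of the capacity rule that the updated `get_capacity` (earlier in this file's diff) implements: capacity is now derived from the token count and the number of experts, and the local token count is averaged over the expert-parallel group when one is supplied. Names and defaults below are simplified illustrations, not the library code.

```python
import math

import torch
import torch.distributed as dist


def capacity_sketch(num_tokens, num_experts, k_value=2, capacity_factor=1.25, min_capacity=4, ep_group=None):
    """Illustrative re-statement of the capacity computation after this change."""
    if ep_group is not None:
        # average the per-rank token count across the expert-parallel group
        num_tokens_tensor = torch.tensor(num_tokens)
        dist.all_reduce(num_tokens_tensor, group=ep_group)
        num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
    capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)
    capacity += capacity % 2  # keep the capacity even
    return max(capacity, min_capacity)
```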
@@ -257,8 +276,13 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) + if use_norm: + routing_weights, _ = torch.topk(probs, 2, dim=-1) + probs = probs / routing_weights.sum(dim=-1, keepdim=True) + num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -270,10 +294,11 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() + if use_loss: + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index e25e7dd48892..c642f1a4450f 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable: return torch.nn.GELU() elif act == "swiglu": return SwiGLU + elif act == "silu": + return torch.nn.SiLU() else: raise NotImplementedError("Unsupported activation function") diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index 9d1d8f01dd2d..e55e82280a5f 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -6,6 +6,8 @@ else: from torch.optim.lr_scheduler import _LRScheduler +from colossalai.logging import get_dist_logger + class _enable_get_lr_call: def __init__(self, o): @@ -19,7 +21,39 @@ def __exit__(self, type, value, traceback): self.o._get_lr_called_within_step = False -class DelayerScheduler(_LRScheduler): +class TwoStageScheduler(_LRScheduler): + def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1): + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] + else: + raise NotImplementedError() + return state_dict + + def load_state_dict(self, state_dict): + if "after_scheduler_dict" not in state_dict: + logger = get_dist_logger() + logger.warning( + "after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior." 
+ ) + else: + self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"]) + state_dict = { + key: value + for key, value in state_dict.items() + if key not in ("after_scheduler_type", "after_scheduler_dict") + } + super().load_state_dict(state_dict) + + +class DelayerScheduler(TwoStageScheduler): """Starts with a flat lr schedule until it reaches N epochs then applies the specific scheduler (For example: ReduceLROnPlateau) @@ -35,19 +69,7 @@ def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") self.delay_epochs = delay_epochs - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.delay_epochs: @@ -71,7 +93,7 @@ def step(self, epoch=None): return super(DelayerScheduler, self).step(epoch) -class WarmupScheduler(_LRScheduler): +class WarmupScheduler(TwoStageScheduler): """Starts with a linear warmup lr schedule until it reaches N epochs then applies the specific scheduler (For example: ReduceLROnPlateau). @@ -85,19 +107,7 @@ class WarmupScheduler(_LRScheduler): def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): self.warmup_epochs = int(warmup_epochs) - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.warmup_epochs: @@ -120,7 +130,7 @@ def step(self, epoch=None): return super().step(epoch) -class WarmupDelayerScheduler(_LRScheduler): +class WarmupDelayerScheduler(TwoStageScheduler): """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). 
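The schedulers in this file now inherit the shared `state_dict`/`load_state_dict` logic from `TwoStageScheduler`, so the wrapped `after_scheduler` is serialized and restored together with the outer scheduler. A minimal usage sketch, assuming a plain torch `CosineAnnealingLR` as the second stage (this snippet is illustrative and not part of the patch):

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from colossalai.nn.lr_scheduler.delayed import WarmupScheduler

model = nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = WarmupScheduler(optimizer, warmup_epochs=2, after_scheduler=CosineAnnealingLR(optimizer, T_max=10))

for _ in range(3):
    optimizer.step()
    scheduler.step()

# the saved state carries "after_scheduler_type" and "after_scheduler_dict"
state = scheduler.state_dict()

new_scheduler = WarmupScheduler(optimizer, warmup_epochs=2, after_scheduler=CosineAnnealingLR(optimizer, T_max=10))
new_scheduler.load_state_dict(state)
assert new_scheduler.state_dict() == state
```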
@@ -140,19 +150,7 @@ def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") self.warmup_epochs = warmup_epochs self.delay_epochs = delay_epochs - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.warmup_epochs + self.delay_epochs: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e10a7ed7da0c..92c709218a26 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -16,6 +16,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d +from ..layer._operation import _gather try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -288,6 +289,9 @@ def llama_for_causal_lm_forward( shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -588,6 +592,9 @@ def forward( shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b5c9e66e0b87..415fc6dd5f06 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -34,6 +34,7 @@ class ShardConfig: enable_all_optimization: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False + parallel_output = True extra_kwargs: Dict[str, Any] = field(default_factory=dict) # pipeline_parallel_size: int # data_parallel_size: int diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 5301c87b9836..acb9fc4ae8fc 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -7,11 +7,12 @@ from .colo_tensor import _convert_output -WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point} +WHITE_LIST_FUNCS = {torch.Tensor.__getitem__} +NO_HOOK_FUNCS = {torch.Tensor.is_floating_point} def is_no_hook_op(func) -> bool: - return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS + return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS def filter_colo_parameters(*args, **kwargs): diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index ba6c77056222..5ac3c2b3a57e 100644 --- 
a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -26,3 +26,5 @@ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1 self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) + self.ep_rank = self.pg.coordinate(self.ep_axis) + self.dp_rank = self.pg.coordinate(self.dp_axis) diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 1fe99cd89a4e..40de43c43b05 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -92,7 +92,10 @@ def pre_op(params: List[torch.Tensor], *args: Any) -> list: @staticmethod def post_op(params: List[torch.Tensor], arg: Any) -> Any: ColoParamOpHookManager._trigger_post_forward(params) - return PostFwdPreBwd.apply(params, arg) + # in case the output is a tuple, we have to flatten it + grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg) + new_grad_args = PostFwdPreBwd.apply(params, *grad_args) + return _merge_args(new_grad_args, other_args, grad_flags, spec) @staticmethod def has_hook() -> bool: @@ -113,7 +116,7 @@ def backward(ctx, *grads): class PostFwdPreBwd(torch.autograd.Function): @staticmethod - def forward(ctx, params, args): + def forward(ctx, params, *args): ctx.params = params return args @@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]: grad_args.append(arg) else: other_args.append(arg) - assert len(grad_args) > 0 return grad_args, other_args, grad_flags, spec diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 79831cf33dbc..bc6c9d088094 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -726,11 +726,13 @@ def load_parameter(chunk_slice, data): chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk - if self.reuse_fp16_chunk: - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.payload.copy_(chunk_32.payload) + + # sync running weights and master weights + if self.master_weights: + for loaded_chunk in chunk_list: + paired_chunk = loaded_chunk.paired_chunk + assert paired_chunk is not None + paired_chunk.payload.copy_(loaded_chunk.payload) for name, buf in persistent_buffers.items(): if buf is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 98fbb0c50e24..18367af59d80 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -621,7 +621,10 @@ def get_param_groups_for_saving(self) -> list: Return the param_groups in Pytorch format when saving to checkpoint. """ - param_groups = copy.deepcopy(self.param_groups_backup) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(self.optim.param_groups, self.param_groups_backup) + ] # To be compatible with pytorch checkpointing, # store extra hyperparameters used by pytorch Adam optimizer. 
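With the change to `get_param_groups_for_saving` above, the saved param groups are rebuilt from the live `self.optim.param_groups`, so hyperparameters adjusted after boosting (such as a new learning rate) are what get checkpointed, while the `params` entries still come from the backup. A toy illustration of that merge pattern with made-up values (not ColossalAI code):

```python
# live optimizer groups hold the up-to-date hyperparameters
live_groups = [{"lr": 0.1, "betas": (0.9, 0.999), "params": ["<working tensors>"]}]
# the backup holds the param entries expected by the checkpoint format
backup_groups = [{"lr": 0.001, "params": [0, 1, 2]}]

saved_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(live_groups, backup_groups)
]
assert saved_groups[0]["lr"] == 0.1 and saved_groups[0]["params"] == [0, 1, 2]
```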
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e01c852bee50..a2433d1b261c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -141,7 +141,7 @@ def __init__( # because they have different parallel strategy # so we need to store them separately in param_groups # instead of working_groups - moe_params = list() + self.working_moe_params = list() # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -153,7 +153,7 @@ def __init__( if self.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): - moe_params.append(param) + self.working_moe_params.append(param) continue group_params.append(param) @@ -168,13 +168,23 @@ def __init__( # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in additional group in optim - if len(moe_params) > 0: + # if there are moe params, store in additional group in optim + if len(self.working_moe_params) > 0: + self._sync_master_param = False param_group = dict() + # create fp32 master param for key, value in self.optim.param_groups[0].items(): if key != "params": param_group[key] = value - param_group["params"] = moe_params + self.master_moe_params = [] + for param in self.working_moe_params: + self.master_moe_params.append(param.clone().to(torch.float32).detach()) + # create mapping from master to working for optimizer io + self.moe_master_to_working_map = {} + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param + # add to optim + param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) # initialize communication stream for @@ -593,24 +603,40 @@ def step(self, closure=None): # update the params in the optimizer self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + # update param for moe ep + # move grad to master param and compute norm + if len(self.working_moe_params) > 0: + moe_grads = [] + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + if master_moe_param.grad is not None: + raise RuntimeError("Moe param should not have grad here") + grad = working_moe_param.grad + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) + master_moe_param.grad = grad + working_moe_param.grad = None + moe_grads.append(grad) + grad_partition_groups.append(grad) + norm_group = self._compute_grad_norm(gradients=moe_grads) + norm_groups.append(norm_group) + self.optim.param_groups[-1]["params"] = self.master_moe_params + del moe_grads + # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) - # TODO: we should store master param for ep - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.data = param.data.to(torch.float32) - param.grad = param.grad.to(torch.float32) - # update the parameters self.optim.step() - # release the moe gradm - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.grad = None - param.data = param.data.to(self._dtype) + # release moe grad + if 
len(self.working_moe_params) > 0: + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.grad = None + working_moe_param.data = ( + master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach() + ) # release the grad grad_partition_groups = [] @@ -885,9 +911,14 @@ def update_master_params(self, model: nn.Module) -> None: master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) else: master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if hasattr(self, "master_moe_params"): + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.copy_(working_moe_param) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self._param_store.working_to_master_param def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + if hasattr(self, "moe_master_to_working_map"): + return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} return self._param_store.master_to_working_param diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 0c438c726baa..c25f19795a20 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -9,7 +9,7 @@ 文档 | 例程 | 论坛 | - 博客 + 博客 [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) [![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml) diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md index 18607a34cf65..f9c8fe4758c8 100644 --- a/docs/source/en/get_started/installation.md +++ b/docs/source/en/get_started/installation.md @@ -23,7 +23,7 @@ pip install colossalai If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime. ```shell -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` @@ -39,7 +39,7 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`: @@ -61,7 +61,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index fb6fd90ec4c2..481efe98ac12 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -19,10 +19,8 @@ 当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为 $$ \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] -``` -这就是所谓的行并行方式. $$ - +这就是所谓的行并行方式. 
为了计算 $$ Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] diff --git a/docs/source/zh-Hans/get_started/installation.md b/docs/source/zh-Hans/get_started/installation.md index e75e42530fc1..9e4f34707c13 100755 --- a/docs/source/zh-Hans/get_started/installation.md +++ b/docs/source/zh-Hans/get_started/installation.md @@ -20,10 +20,10 @@ pip install colossalai **注:现在只支持Linux。** -如果你想同时安装PyTorch扩展的话,可以添加`CUDA_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。 +如果你想同时安装PyTorch扩展的话,可以添加`BUILD_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。 ```shell -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` ## 从源安装 @@ -38,10 +38,10 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` -如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`: +如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`BUILD_EXT=1`: ```shell pip install . @@ -60,7 +60,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py deleted file mode 100644 index 2b2356b18b70..000000000000 --- a/examples/language/llama2/attn.py +++ /dev/null @@ -1,84 +0,0 @@ -from types import MethodType -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv - -SUPPORT_XFORMERS = False -SUPPORT_FLASH2 = False -try: - import xformers.ops as xops - - SUPPORT_XFORMERS = True -except ImportError: - pass - -try: - from flash_attn import flash_attn_func - - SUPPORT_FLASH2 = True -except ImportError: - pass - -SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2 - - -def llama_flash_attention( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. 
returns [B, S, H, K] - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if SUPPORT_FLASH2: - attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) - else: - attn_output = xops.memory_efficient_attention( - query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def replace_xformers(model: nn.Module): - for module in model.modules(): - if isinstance(module, LlamaAttention): - module.forward = MethodType(llama_flash_attention, module) diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py new file mode 120000 index 000000000000..4e95c7bfa519 --- /dev/null +++ b/examples/language/llama2/attn.py @@ -0,0 +1 @@ +../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py \ No newline at end of file diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index b8f70ce9c9d8..54b023f64742 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -3,7 +3,7 @@ from contextlib import nullcontext import torch -from attn import SUPPORT_FLASH, replace_xformers +from attn import replace_with_flash_attention from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator @@ -188,8 +188,7 @@ def empty_init(): model.gradient_checkpointing_enable() if args.xformers: - assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed" - replace_xformers(model) + replace_with_flash_attention(model) model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 66b5400765f7..3dbd0cf357b4 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from attn import SUPPORT_XFORMERS, replace_xformers +from attn import replace_with_flash_attention from data_utils import load_json, prepare_dataloader, save_json from datasets import load_dataset from torch.optim import Optimizer @@ -219,8 +219,7 @@ def main(): if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: - assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" - replace_xformers(model) + replace_with_flash_attention(model) model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index 4cdf93e1914b..fe7d958307e9 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from attn import SUPPORT_XFORMERS, replace_xformers +from attn import replace_with_flash_attention from data_utils import load_json, prepare_dataloader, save_json from datasets import load_dataset from torch.optim import Optimizer @@ -238,8 +238,7 @@ def main(): if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: 
- assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" - replace_xformers(model) + replace_with_flash_attention(model) model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py index b4c40c9f1105..3adb65fb8f4e 100644 --- a/extensions/cpp_extension.py +++ b/extensions/cpp_extension.py @@ -126,7 +126,7 @@ def cxx_flags(self) -> List[str]: def load(self): try: op_kernel = self.import_op() - except ImportError: + except (ImportError, ModuleNotFoundError): # if import error occurs, it means that the kernel is not pre-built # so we build it jit op_kernel = self.build_jit() diff --git a/setup.py b/setup.py index 1244bfff0327..e54ec41ea9f8 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ import os import sys -from datetime import datetime from typing import List from setuptools import find_packages, setup @@ -15,7 +14,6 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 -IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 # we do not support windows currently if sys.platform == "win32": @@ -96,23 +94,15 @@ def get_version() -> str: else: ext_modules = [] -# always put not nightly branch as the if branch -# otherwise github will treat colossalai-nightly as the project name -# and it will mess up with the dependency graph insights -if not IS_NIGHTLY: - version = get_version() - package_name = "colossalai" -else: - # use date as the nightly version - version = datetime.today().strftime("%Y.%m.%d") - package_name = "colossalai-nightly" +version = get_version() +package_name = "colossalai" setup( name=package_name, version=version, packages=find_packages( exclude=( - "op_builder", + "extensions", "benchmark", "docker", "tests", @@ -121,8 +111,9 @@ def get_version() -> str: "tests", "scripts", "requirements", + "extensions", "*.egg-info", - ) + ), ), description="An integrated large-scale model training system with efficient parallelization techniques", long_description=fetch_readme(), @@ -153,10 +144,7 @@ def get_version() -> str: ], package_data={ "colossalai": [ - "_C/*.pyi", - "kernel/cuda_native/csrc/*", - "kernel/cuda_native/csrc/kernel/*", - "kernel/cuda_native/csrc/kernels/include/*", + "kernel/extensions/csrc/**/*", ] }, ) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 67b0bef50594..d629e769d715 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): @parameterize( "test_args", [ + { + "batch_size": 8, + "num_steps": 4, + "tp": 2, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 1, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, { "batch_size": 8, "num_steps": 4, diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 708a1906b118..61cac1d8369b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha new_model = model_fn() optimizer = HybridAdam(model.parameters(), lr=0.001) model, optimizer, criterion, _, _ = 
booster.boost(model, optimizer, criterion) - new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_optimizer = HybridAdam(new_model.parameters(), lr=0.01) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) data = data_gen_fn() @@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha booster.backward(loss, optimizer) optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" @@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha check_state_dict_equal( optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False ) + for group in new_optimizer.param_groups: + assert group["lr"] == 0.1 # Check the new model/optimizer can successfully run. data = data_gen_fn() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index a42b550cd6fc..b5cb31715aed 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -83,7 +83,8 @@ def _preprocess_data(data): optimizer.backward(loss) optimizer.step() - + for group in optimizer.param_groups: + group["lr"] = 0.1 with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index dd41f8185c2b..dca562a3b837 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -10,6 +10,7 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -99,6 +100,43 @@ def run_model(): outputs_sec = fsdp_model(inputs) assert criterion(outputs_sec) == criterion(outputs) + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optim_ckpt_path = f"{tempdir}/optimizer" + + run_model() + + booster.save_model(fsdp_model, model_ckpt_path, shard=True) + booster.save_optimizer(optimizer, optim_ckpt_path, shard=True) + + full_msd = fsdp_model.unwrap().state_dict() + full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + import copy + sharded_osd = copy.deepcopy(full_osd) + + run_model() + + full_msd_updated = fsdp_model.unwrap().state_dict() + full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + # cost much time led to timeout + # assert not compare_nested_dict(full_osd_updated, sharded_osd) + # assert not compare_nested_dict(full_msd_updated, full_msd) + outputs_first = fsdp_model(inputs) + assert criterion(outputs_first) != criterion(outputs) + + booster.load_model(fsdp_model, model_ckpt_path) + booster.load_optimizer(optimizer, optim_ckpt_path) + + full_msd_restore = fsdp_model.unwrap().state_dict() + sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) + + assert compare_nested_dict(sharded_osd, sharded_osd_restore) + assert compare_nested_dict(full_msd_restore, full_msd) + outputs_sec = fsdp_model(inputs) + assert 
criterion(outputs_sec) == criterion(outputs) + def run_dist(rank, world_size, port): # init dist env diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 721a4796abfd..17b790e3e87a 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,13 +1,22 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "moe_info"): + delattr(param, "moe_info") class MoeModel(nn.Module): @@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose(a, b), \ - (f"expected tensors on rank {i} and {i + 1} not to be equal " - f"but they are, {a} vs {b}") + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name in ep_name, print(f"{local_name} != {ep_name}") + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) + + assert_close(a, b, rtol=rtol, atol=atol) diff 
--git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 8f51e1663727..d6dad2d7fb41 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -12,7 +12,6 @@ from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn sys.path.append( @@ -95,6 +94,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=1, zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -103,6 +103,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=dist.get_world_size(), zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -111,6 +112,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=2, zero_stage=2, extra_dp_size=2, custom_policy=OpenMoeForCausalLMPolicy(), @@ -120,6 +122,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=2, + ep_size=2, zero_stage=1, microbatch_size=1, custom_policy=OpenMoeForCausalLMPolicy(), @@ -130,27 +133,6 @@ def get_model(parallel): def _test_moe_checkpoint(rank, parallel): - if parallel == None: - MOE_MANAGER.setup( - parallel=None, - ) - elif parallel == "ep": - MOE_MANAGER.setup( - parallel="EP", - ) - elif parallel == "ep_zero": - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=2, - ) - elif parallel == "hybrid": - MOE_MANAGER.setup( - parallel="EP", - mode="fixed", - fixed_dp_size=1, - fixed_ep_size=2, - fixed_pp_size=2, - ) model1, booster1, optim1 = get_model(parallel) model2, booster2, optim2 = get_model(parallel) model3, booster3, optim3 = get_model(parallel) @@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel): _test_moe_checkpoint(rank, parallel) +@pytest.mark.skip(reason="This is tested in ColossalMOE") @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"]) diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index 7ba7fa6f6b7d..9f6167692d61 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -4,15 +4,21 @@ from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter -@pytest.mark.parametrize(["router", "num_groups"], [ - (Top1Router(), 1), - (Top2Router(), 1), - # (TopKRouter(num_selected_experts=3), 4), -]) -@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ - (4, 5, 8), - (3, 4, 4), -]) +@pytest.mark.parametrize( + ["router", "num_groups"], + [ + (Top1Router(), 1), + (Top2Router(), 1), + # (TopKRouter(num_selected_experts=3), 4), + ], +) +@pytest.mark.parametrize( + ["batch_size", "seq_len", "num_experts"], + [ + (4, 5, 8), + (3, 4, 4), + ], +) def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): x = torch.randn((batch_size * seq_len, num_experts)).cuda() if num_groups > 1: @@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex router.train() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert 
combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) router.eval() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index f0795a4c738f..1bff2106675e 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -4,102 +4,75 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) - - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) - - # assert zero model - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.module.named_parameters() - ): - assert zero_name == torch_name - assert torch.allclose(zero_param.data, torch_param.data) - - data = torch.randn(16, 4).cuda() + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters()) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + zero_plugin = 
LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + sync_local_from_ep(zero_model, moe_model) + + data = torch.randn(16, 4).bfloat16().cuda() label = torch.randint(0, 4, (16,)).cuda() - torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) - zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) - assert torch.allclose(torch_out, zero_out) - grad_handler.handle_gradient() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + assert torch.allclose(zero_out, moe_out) - for (zero_name, zero_param), (torch_name, torch_param) in zip( - zero_model.module.named_parameters(), torch_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.module.named_parameters(), zero_model.module.named_parameters() ): - assert zero_name == torch_name - zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) - if hasattr(zero_param, "moe_info"): - assert len(zero_grad_list) == 0 - assert torch.allclose(zero_param.grad, torch_param.grad) + assert moe_name == zero_name + moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(moe_param, "moe_info"): + assert len(moe_grad_list) == 0 + if stage == 1: + zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) + else: + zero_grad = zero_grad_list[0].view(moe_param.grad.shape) + assert torch.allclose( + moe_param.grad, zero_grad, atol=1e-5 + ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" else: - assert len(zero_grad_list) > 0 - torch_grad_list = split_ddp_grad(torch_param.grad, world_size) - if stage == 2: - torch_grad_list = torch_grad_list[local_rank : local_rank + 1] - assert len(zero_grad_list) == len(torch_grad_list) - for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): - assert torch.allclose(zero_grad, torch_grad) + assert len(moe_grad_list) > 0 + assert len(moe_grad_list) == len(zero_grad_list) + for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): + assert torch.allclose(moe_grad, zero_grad) -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank) - run_zero_test(rank, world_size, stage=1) - run_zero_test(rank, world_size, stage=2) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) +def test_moe_zero_model(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_model(world_size=2) + test_moe_zero_model(world_size=2, stage=1) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 0d2e2fb1b2d8..4f6067aaa10a 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -4,89 +4,80 @@ import colossalai from 
colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_optim_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - zero_optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters()) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) - - for _ in range(2): - data = torch.randn(16, 4).cuda() / (local_rank + 1) - label = torch.randint(0, 4, (16,)).cuda() - run_fwd_bwd(torch_model, data, label, criterion, None) - run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - grad_handler.handle_gradient() - - torch_optimizer.step() + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + sync_local_from_ep(zero_model, moe_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() + ): + if ".experts." 
in moe_name: + continue + assert moe_name == zero_name + assert torch.allclose( + moe_param.data, zero_param.data + ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" + + for _ in range(1): + data = torch.randn(2, 4).bfloat16().cuda() + label = torch.randint(0, 4, (2,)).cuda() + + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, moe_out) + moe_optimizer.step() zero_optimizer.step() - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() ): - assert torch.allclose( - torch_param.data, zero_param.data - ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + assert moe_name == zero_name + if is_moe_tensor(moe_param): + param_size = moe_param.shape[0] + zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] + loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) - torch_optimizer.zero_grad() + moe_optimizer.zero_grad() zero_optimizer.zero_grad() -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") - run_zero_optim_test(rank, world_size, stage=1) - run_zero_optim_test(rank, world_size, stage=2) + seed_all(42 + rank) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): - spawn(run_dist, world_size) +def test_moe_zero_optim(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_optim(world_size=2) + test_moe_zero_optim(world_size=2, stage=1) diff --git a/tests/test_optimizer/test_lr_scheduler.py b/tests/test_optimizer/test_lr_scheduler.py new file mode 100644 index 000000000000..e0b084140595 --- /dev/null +++ b/tests/test_optimizer/test_lr_scheduler.py @@ -0,0 +1,20 @@ +import torch.nn as nn +from torch.optim import Adam + +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + + +def test_lr_scheduler_save_load(): + model = nn.Linear(10, 10) + optimizer = Adam(model.parameters(), lr=1e-3) + scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2) + new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2) + for _ in range(5): + scheduler.step() + state_dict = scheduler.state_dict() + new_scheduler.load_state_dict(state_dict) + assert state_dict == new_scheduler.state_dict() + + +if __name__ == "__main__": + test_lr_scheduler_save_load() diff --git a/version.txt b/version.txt index 42045acae20f..c2c0004f0e2a 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.4 +0.3.5
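Circling back to the router change earlier in this patch: the new `use_norm` flag in `Top2Router.forward` rescales the softmax probabilities by the sum of the two selected experts' weights, so the top-2 routing weights of each token sum to 1 before the capacity mask is applied. A toy sketch of that renormalization on random inputs (illustrative only, not library code):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8)  # (tokens, experts)
probs = F.softmax(logits, dim=-1)

# use_norm=True: rescale so each token's two selected expert weights sum to 1
top2_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / top2_weights.sum(dim=-1, keepdim=True)
```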