diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index d0b5c2164119..b7522ffbdf74 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -31,18 +31,18 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install --no-cache-dir -v -e . - name: Install ChatGPT run: | cd applications/ColossalChat - pip install -v . + pip install --no-cache-dir -v . export BUILD_EXT=1 - pip install -r examples/requirements.txt + pip install --no-cache-dir -r examples/requirements.txt - name: Install Transformers run: | - pip install transformers==4.36.2 + pip install --no-cache-dir transformers==4.36.2 - name: Execute Examples run: | diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 7b361d38e6d0..5a4bb905f4ea 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -161,3 +161,9 @@ applications/ColossalChat/sft_data applications/ColossalChat/prompt_data applications/ColossalChat/preference_data applications/ColossalChat/temp + +# Testing data +/kto_data/ +/preference_data/ +/prompt_data/ +/sft_data/ diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py index 63c903a51940..bef4ccc3e078 100755 --- a/applications/ColossalChat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -16,7 +16,7 @@ from coati.experience_maker import Experience from torch.optim import Optimizer -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from .utils import is_rank_0 @@ -38,6 +38,7 @@ def __init__( max_epochs: int, model: nn.Module, optimizer: Optimizer, + plugin: Plugin, start_epoch: int = 0, ) -> None: super().__init__() @@ -45,6 +46,7 @@ def __init__( self.max_epochs = max_epochs self.model = model self.optimizer = optimizer + self.plugin = plugin self.start_epoch = start_epoch @abstractmethod diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 24ddca6545c8..faa7a90d92de 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -16,7 +16,7 @@ from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -50,6 +50,7 @@ def __init__( ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -63,7 +64,9 @@ def __init__( save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 6462ba816686..f0b23afb667f 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -17,7 +17,7 @@ from tqdm import trange from transformers import PreTrainedTokenizerBase 
-from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -53,6 +53,7 @@ def __init__( ref_model: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -66,7 +67,9 @@ def __init__( save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index c2f75771cdff..761fd305a6ff 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -16,7 +16,7 @@ from tqdm import trange from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ def __init__( actor: Any, booster: Booster, actor_optim: Optimizer, + plugin: Plugin, actor_lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, max_epochs: int = 1, @@ -59,7 +60,9 @@ def __init__( save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss() diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index b9e84ef557fa..82e4625b9c8e 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase -from colossalai.booster import Booster +from colossalai.booster import Booster, Plugin from colossalai.cluster import DistCoordinator from colossalai.utils import get_current_device @@ -48,6 +48,7 @@ def __init__( model: Any, booster: Booster, optimizer: Optimizer, + plugin: Plugin, lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, loss_fn: Optional[Callable] = None, @@ -59,7 +60,9 @@ def __init__( save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = lr_scheduler self.tokenizer = tokenizer self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index d37676ada3e0..3aedcf7a99af 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -6,14 +6,16 @@ from typing import Optional import torch +import 
torch.distributed as dist from coati.trainer.utils import all_reduce_mean from coati.utils import AccumulativeMeanMeter, save_checkpoint from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import trange +from tqdm import tqdm, trange from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin, Plugin from colossalai.cluster import DistCoordinator from .base import SLTrainer @@ -40,6 +42,7 @@ def __init__( optim: Optimizer, lr_scheduler: _LRScheduler, max_epochs: int = 2, + plugin: Plugin = None, accumulation_steps: int = 8, apply_loss_mask: bool = True, start_epoch=0, @@ -47,7 +50,7 @@ def __init__( save_dir: str = None, coordinator: Optional[DistCoordinator] = None, ) -> None: - super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) + super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch) self.accumulation_steps = accumulation_steps self.scheduler = lr_scheduler @@ -94,90 +97,152 @@ def _before_fit( def _train(self, epoch: int): self.model.train() - step_bar = trange( - len(self.train_dataloader) // self.accumulation_steps, - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for i, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) - batch_size = batch["input_ids"].size(0) - outputs = self.model( - batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.train_dataloader) + step_bar = tqdm( + range(len(self.train_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), ) - loss = outputs.loss - - self.booster.backward(loss=loss, optimizer=self.optimizer) + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] - loss_mean = all_reduce_mean(tensor=loss) - self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + if self.booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, self.plugin) + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"train/loss": global_loss.item()}) - # Gradient accumulation - if (i + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() - self.scheduler.step() - - step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) - if self.writer: - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) - self.num_train_step += 1 - self.accumulative_meter.reset() - step_bar.update() - - # Save checkpoint - if ( - self.save_dir is not None - and self.save_interval is not None - and (self.num_train_step + 1) % self.save_interval == 0 - ): - save_checkpoint( - save_dir=self.save_dir, - booster=self.booster, - model=self.model, - optimizer=self.optimizer, - lr_scheduler=self.scheduler, - epoch=epoch, - step=self.num_train_step + 1, - batch_size=batch_size, - coordinator=self.coordinator, - ) - self.coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" - ) 
- step_bar.close() - - def _eval(self, epoch: int): - if self.eval_dataloader is None: - self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") - return - self.accumulative_meter.reset() - self.model.eval() - with torch.no_grad(): + else: step_bar = trange( - len(self.eval_dataloader), + len(self.train_dataloader) // self.accumulation_steps, desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0(), ) - for batch in self.eval_dataloader: + for i, batch in enumerate(self.train_dataloader): batch = to_device(batch, torch.cuda.current_device()) + batch_size = batch["input_ids"].size(0) outputs = self.model( batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], ) - loss_mean = all_reduce_mean(tensor=outputs.loss) - self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) - step_bar.update() - loss_mean = self.accumulative_meter.get("loss") - msg = "Evaluation Result:\n" - for tag in ["loss"]: - msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" - self.coordinator.print_on_master(msg) - os.makedirs(self.save_dir, exist_ok=True) - with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: - f.write(msg) - step_bar.close() + loss = outputs.loss + + self.booster.backward(loss=loss, optimizer=self.optimizer) + + loss_mean = all_reduce_mean(tensor=loss) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + + # Gradient accumulation + if (i + 1) % self.accumulation_steps == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) + if self.writer: + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 + self.accumulative_meter.reset() + step_bar.update() + + # Save checkpoint + if ( + self.save_dir is not None + and self.save_interval is not None + and (self.num_train_step + 1) % self.save_interval == 0 + ): + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.scheduler, + epoch=epoch, + step=self.num_train_step + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" + ) + step_bar.close() + + def _eval(self, epoch: int): + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.accumulative_meter.reset() + self.model.eval() + with torch.no_grad(): + if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1: + data_iter = iter(self.eval_dataloader) + step_bar = tqdm( + range(len(self.eval_dataloader)), + desc="Step", + disable=not (dist.get_rank() == dist.get_world_size() - 1), + ) + for step in step_bar: + outputs = self.booster.execute_pipeline( + data_iter, + self.model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if self.booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, self.plugin) + if dist.get_rank() == dist.get_world_size() - 1: + step_bar.set_postfix({"eval/loss": global_loss.item()}) + 
self.accumulative_meter.add("loss", global_loss.item()) + + if dist.get_rank() == dist.get_world_size() - 1: + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + print(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() + + else: + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"], + ) + loss_mean = all_reduce_mean(tensor=outputs.loss) + self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) + step_bar.update() + + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 3c836b4b4db1..217a87cf0419 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -9,6 +9,8 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.booster import Plugin + class CycledDataLoader: """ @@ -85,7 +87,7 @@ def _to(t: Any): return tree_map(_to, x) -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: +def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Perform all-reduce operation on the given tensor and compute the mean across all processes. @@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The reduced tensor with mean computed across all processes. 
""" - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor.div_(dist.get_world_size()) + # All reduce mean across DP group + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + tensor.div_(plugin.dp_size) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) return tensor diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index d88750aebc8f..3b324ee784e0 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -267,6 +267,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py index 598fd8062fcf..931c1657710e 100755 --- a/applications/ColossalChat/examples/training_scripts/train_kto.py +++ b/applications/ColossalChat/examples/training_scripts/train_kto.py @@ -286,6 +286,7 @@ def train(args): ref_model=ref_model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py index 87860f7ea023..0f2fbfa2ba44 100755 --- a/applications/ColossalChat/examples/training_scripts/train_orpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py @@ -250,6 +250,7 @@ def train(args): actor=model, booster=booster, actor_optim=optim, + plugin=plugin, actor_lr_scheduler=lr_scheduler, tokenizer=tokenizer, max_epochs=args.max_epochs, diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py index 4c0a782b4766..5ea1a06acc36 100755 --- a/applications/ColossalChat/examples/training_scripts/train_rm.py +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -262,6 +262,7 @@ def train(args): model, booster, optim, + plugin, lr_scheduler, tokenizer, loss_fn=loss_fn, diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index c4ef3b783d4d..62acad32f66a 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -114,7 +114,7 @@ def train(args): parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, - microbatch_size=args.batch_size, + microbatch_size=args.microbatch_size, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -269,6 +269,7 @@ def train(args): model=model, booster=booster, optim=optim, + plugin=plugin, lr_scheduler=lr_scheduler, max_epochs=args.max_epochs, accumulation_steps=args.accumulation_steps, @@ -344,6 +345,7 @@ def train(args): parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true") + parser.add_argument("--microbatch_size", type=int, default=1) args = parser.parse_args() if args.config_file is not None: 
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py
index 7787592105b6..a6365051758f 100755
--- a/applications/ColossalChat/tests/test_lora.py
+++ b/applications/ColossalChat/tests/test_lora.py
@@ -61,7 +61,7 @@ def test_overfit():
     _, predicted = torch.max(outputs.data, 1)
     total = labels.size(0)
     correct = (predicted == Y).sum().item()
-    assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
+    assert correct / total > 0.95, "The model has not overfitted to the synthesized dataset"
     assert (weight_to_compare - model.fc1.weight).sum() < 0.01
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index 69036de635c9..2935a6369986 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }

-set_n_least_used_CUDA_VISIBLE_DEVICES 2
+set_n_least_used_CUDA_VISIBLE_DEVICES 4

 set -xu

@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
 MODELS_DIR=$TEMP_DIR/models_config
 # Skip those tests due to CI tests timeout
 MODELS=('llama')
-ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
+ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
 PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
 LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
 LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
@@ -91,7 +91,7 @@ SKIPPED_TESTS=(
     llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
     llama-gemini-20 # gemini doesn't support lora
 )
-
+skip_eval=false
 GRAD_CKPTS=('--grad_checkpoint')
 for lora_rank in ${LORA_RANK[@]}; do
     for model in ${MODELS[@]}; do
@@ -129,15 +129,17 @@
             plugin='3d'
         fi
         if [[ $plugin == "tp_pp" ]]; then
             tp='2'
             bs='8'
             pp='2'
             plugin='3d'
+            skip_eval=true
         fi
         if [[ $plugin == "pp" ]]; then
             bs='8'
             pp='2'
             plugin='3d'
+            skip_eval=true
         fi
         if [[ $plugin == "sp_split_gather" ]]; then
             enable_sequence_parallelism='--enable_sequence_parallelism'
@@ -175,28 +177,53 @@
         for split in $(seq -f "%05g" 0 0); do
             dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
         done
-        colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
-            --pretrain $pretrain \
-            --tokenizer_dir $tokenizer_dir \
-            --dataset ${dataset[@]} \
-            --eval_dataset ${dataset[@]} \
-            --save_path $MODEL_SAVE_PATH \
-            --config_file $MODELS_DIR/config.jsonl \
-            $lora_config \
-            --plugin $plugin \
-            --batch_size $bs \
-            --max_epochs 1 \
-            --accumulation_steps $grad_accu \
-            --tp $tp \
-            --pp $pp \
-            --zero_stage $zero_stage \
-            --sp $sp \
-            --sp_mode $sp_mode \
-            $enable_sequence_parallelism \
-            --lr 2e-5 \
-            $grad_ckpt \
-            --max_len 400 \
-            --use_flash_attn
+
+        if [[ $skip_eval == true ]]; then
+            colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+                --pretrain $pretrain \
+                --tokenizer_dir $tokenizer_dir \
+                --dataset ${dataset[@]} \
+                --save_path $MODEL_SAVE_PATH \
+                --config_file $MODELS_DIR/config.jsonl \
+                $lora_config \
+                --plugin $plugin \
+                --batch_size $bs \
+                --max_epochs 1 \
+                --accumulation_steps $grad_accu \
+                --tp $tp \
+                --pp $pp \
+                --zero_stage $zero_stage \
+                --sp $sp \
+                --sp_mode $sp_mode \
+                $enable_sequence_parallelism \
+                --lr 2e-5 \
+                $grad_ckpt \
+                --max_len 400 \
+                --use_flash_attn
+        else
+            colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+                --pretrain $pretrain \
+                --tokenizer_dir $tokenizer_dir \
+                --dataset ${dataset[@]} \
+                --eval_dataset ${dataset[@]} \
+                --save_path $MODEL_SAVE_PATH \
+                --config_file $MODELS_DIR/config.jsonl \
+                $lora_config \
+                --plugin $plugin \
+                --batch_size $bs \
+                --max_epochs 1 \
+                --accumulation_steps $grad_accu \
+                --tp $tp \
+                --pp $pp \
+                --zero_stage $zero_stage \
+                --sp $sp \
+                --sp_mode $sp_mode \
+                $enable_sequence_parallelism \
+                --lr 2e-5 \
+                $grad_ckpt \
+                --max_len 400 \
+                --use_flash_attn
+        fi
         passed=$?
         if [ $passed -eq 0 ]; then
             rm -rf ${MODEL_SAVE_PATH:?}/*
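
Note on the group-aware reduction in coati/trainer/utils.py: with pipeline parallelism, only the last-stage ranks hold a loss tensor and call all_reduce_mean, so reducing over the default (world) process group would block on ranks that never enter the collective; reducing over the plugin's dp_group keeps the collective among exactly the participating data-parallel replicas. Below is a minimal, runnable sketch of the same idea; the file name demo.py and the helper name group_mean are illustrative assumptions, not part of the patch. Launch with: torchrun --nproc_per_node=2 demo.py

    # demo.py -- hypothetical standalone demo, not part of this patch
    import torch
    import torch.distributed as dist

    def group_mean(tensor: torch.Tensor, group=None, group_size=None) -> torch.Tensor:
        # Sum across the given process group (the whole world when group is None),
        # then divide in place by the number of participants to obtain the mean.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        tensor.div_(group_size if group_size is not None else dist.get_world_size())
        return tensor

    if __name__ == "__main__":
        dist.init_process_group("gloo")  # torchrun supplies rank and world size
        t = torch.tensor([float(dist.get_rank())])
        print(group_mean(t))  # with 2 processes: tensor([0.5000]) on both ranks
        dist.destroy_process_group()

Relatedly, train_sft.py now passes args.microbatch_size (default 1) to HybridParallelPlugin instead of reusing --batch_size, so a pipeline run such as --plugin 3d --pp 2 --batch_size 8 --microbatch_size 1 splits each batch into pipeline microbatches independently of the global batch size.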