diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 020432b9ec3c..d2bc38aa086e 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -392,4 +392,4 @@ def tokenize_kto(
         "label": data_point["label"],
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
-    }
+    }
\ No newline at end of file
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index 24ddca6545c8..c6a174317c35 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -356,4 +356,4 @@ def _eval(self, epoch: int):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
\ No newline at end of file
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index 6462ba816686..11fda0aa8290 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -346,4 +346,4 @@ def _eval(self, epoch: int):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
\ No newline at end of file
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index c2f75771cdff..948232bff096 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -323,4 +323,4 @@ def _eval(self, epoch: int):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
\ No newline at end of file
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index fec7bc061270..7fd6011deaf4 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -903,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
 
 ## Attention
 
-The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
+The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
\ No newline at end of file
diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt
new file mode 100644
index 000000000000..ba02074c1a03
--- /dev/null
+++ b/applications/ColossalChat/examples/inference/round.txt
@@ -0,0 +1,104 @@
+
+
+==========
+round 1:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] Great, let’s hear a story.
+
+==========
+
+
+==========
+round 2:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2
+
+==========
+
+
+==========
+round 3:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington.
+
+==========
+
+
+==========
+round 1:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+who is the first president of the USA? [/INST] The first president of the United States was George Washington.
+
+==========
+
+
+==========
+round 2:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president.
+
+==========
+
+
+==========
+round 1:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear?
+
+==========
+
+
+==========
+round 2:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump?
+
+==========
+
+
+==========
+round 3:
+[INST] <<SYS>>
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+
+<</SYS>>
+
+tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016.
+
+==========
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index d88750aebc8f..c2fb3bb79d39 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -368,4 +368,4 @@ def train(args):
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)
\ No newline at end of file
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index 598fd8062fcf..db48e33e329d 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -375,4 +375,4 @@ def train(args):
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)
\ No newline at end of file
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index 87860f7ea023..fd99af828528 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -340,4 +340,4 @@ def train(args):
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)
\ No newline at end of file
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index ac40ae821d0a..1978120d1f2e 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -20,4 +20,4 @@ datasets
 ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
-tiktoken
+tiktoken
\ No newline at end of file
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index 69036de635c9..c34250d9816b 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -641,3 +641,171 @@ for lora_rank in ${LORA_RANK[@]}; do
         done
     done
 done
+
+
+
+echo "[Test]: testing ORPO ..."
+
+SKIPPED_TESTS=(
+    llama-3d-20 # 3d plugin doesn't support lora
+    llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
+    llama-gemini-20 # gemini doesn't support lora
+)
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+    for model in ${MODELS[@]}; do
+        for plugin in ${PLUGINS[@]}; do
+            if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+                echo "[Test]: Skipped $model-$plugin-$lora_rank"
+                continue
+            elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+                echo "[Test]: Skipped $model-$plugin"
+                continue
+            fi
+            pretrain=$(get_pretrain $model)
+            tokenizer_dir=$(get_tokenizer_dirs $model)
+            grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+            tp='1'
+            bs='2'
+            if [[ $plugin == "3d" ]]; then
+                tp='4'
+                bs='8'
+            fi
+            if [[ $plugin == "zero2" ]]; then
+                lora_config=$LORA_CONFIG_ENABLE
+            else
+                lora_config=""
+            fi
+            grad_accu='2'
+            # gemini_auto and gemini doesn't support gradient accumulation
+            if [[ $plugin == "gemini_auto" ]]; then
+                grad_accu='1'
+            fi
+            # gemini_auto doesn't support generation
+            # (need to calculate ref_model logits through forwarding in inference mode)
+            if [[ $plugin == "gemini_auto" ]]; then
+                echo "[Test]: Skipped $model-$plugin"
+                continue
+            fi
+            for i in $(seq $NUM_RETRY); do
+                echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+                declare -a dataset=()
+                for split in $(seq -f "%05g" 0 0); do
+                    dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
+                done
+                colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
+                    --pretrain $pretrain \
+                    --tokenizer_dir $tokenizer_dir \
+                    --dataset ${dataset[@]} \
+                    --eval_dataset ${dataset[@]} \
+                    --save_dir $MODEL_SAVE_PATH \
+                    --config_file $MODELS_DIR/config.jsonl \
+                    $lora_config \
+                    --plugin $plugin \
+                    --batch_size $bs \
+                    --max_epochs 1 \
+                    --accumulation_steps $grad_accu \
+                    --tp $tp \
+                    --lr 2e-5 \
+                    $grad_ckpt \
+                    --max_len 400 \
+                    --use_flash_attn
+                passed=$?
+                if [ $passed -eq 0 ]; then
+                    rm -rf ${MODEL_SAVE_PATH:?}/*
+                    rm -rf ${MODELS_DIR:?}/*
+                    break
+                fi
+            done
+            if [ $passed -ne 0 ]; then
+                echo "[Test]: Failed $model-$plugin-$lora_rank"
+                exit 1
+            fi
+        done
+    done
+done
+
+
+
+echo "[Test]: testing KTO ..."
+
+SKIPPED_TESTS=(
+    llama-3d-20 # 3d plugin doesn't support lora
+    llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
+    llama-gemini-20 # gemini doesn't support lora
+)
+GRAD_CKPTS=('--grad_checkpoint')
+for lora_rank in ${LORA_RANK[@]}; do
+    for model in ${MODELS[@]}; do
+        for plugin in ${PLUGINS[@]}; do
+            if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
+                echo "[Test]: Skipped $model-$plugin-$lora_rank"
+                continue
+            elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
+                echo "[Test]: Skipped $model-$plugin"
+                continue
+            fi
+            pretrain=$(get_pretrain $model)
+            tokenizer_dir=$(get_tokenizer_dirs $model)
+            grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+            tp='1'
+            bs='2'
+            if [[ $plugin == "3d" ]]; then
+                tp='4'
+                bs='8'
+            fi
+            if [[ $plugin == "zero2" ]]; then
+                lora_config=$LORA_CONFIG_ENABLE
+            else
+                lora_config=""
+            fi
+            grad_accu='2'
+            # gemini_auto and gemini doesn't support gradient accumulation
+            if [[ $plugin == "gemini_auto" ]]; then
+                grad_accu='1'
+            fi
+            # gemini_auto doesn't support generation
+            # (need to calculate ref_model logits through forwarding in inference mode)
+            if [[ $plugin == "gemini_auto" ]]; then
+                echo "[Test]: Skipped $model-$plugin"
+                continue
+            fi
+            for i in $(seq $NUM_RETRY); do
+                echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
+                declare -a dataset=()
+                for split in $(seq -f "%05g" 0 0); do
+                    dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
+                done
+                colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
+                    --pretrain $pretrain \
+                    --tokenizer_dir $tokenizer_dir \
+                    --dataset ${dataset[@]} \
+                    --eval_dataset ${dataset[@]} \
+                    --save_dir $MODEL_SAVE_PATH \
+                    --config_file $MODELS_DIR/config.jsonl \
+                    $lora_config \
+                    --plugin $plugin \
+                    --batch_size $bs \
+                    --max_epochs 1 \
+                    --accumulation_steps $grad_accu \
+                    --tp $tp \
+                    --lr 2e-5 \
+                    --auto_weight \
+                    --desirable_weight 1.2 \
+                    $grad_ckpt \
+                    --max_len 400 \
+                    --use_flash_attn
+                passed=$?
+                if [ $passed -eq 0 ]; then
+                    rm -rf ${MODEL_SAVE_PATH:?}/*
+                    rm -rf ${MODELS_DIR:?}/*
+                    break
+                fi
+            done
+            if [ $passed -ne 0 ]; then
+                echo "[Test]: Failed $model-$plugin-$lora_rank"
+                exit 1
+            fi
+        done
+    done
+done
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index eba9d60625f5..9c0d64b82716 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1020,6 +1020,7 @@ def __init__(
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        overlap_allgather: bool = False,
         fp8_communication: bool = False,
         overlap_allgather: bool = False,
         inner_ring_size: int = None,
diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
index e892336bcf87..157bc46bee19 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -83,8 +83,6 @@ def backward(ctx, grad_output):
         # all-reduce scheduled first and have GPU resources allocated
         # TODO: This seems to only work if you add torch.cuda.Event.wait()
 
-        # _ = torch.zeros(1, device=grad_output.device)
-
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d61581140592..cea5ffa5710b 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -30,11 +30,9 @@
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, RingAttention, dist_cross_entropy
-
+from ..layer import ColoAttention, cross_entropy_1d, RingAttention, dist_cross_entropy
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
-
 class LlamaPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models
@@ -102,7 +100,7 @@ def llama_model_forward(
         sp_group = shard_config.sequence_parallel_process_group
         sp_size = shard_config.sequence_parallel_size
         if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
-            # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
+            # For correct position ids. The states will be gathered along the seq dim in the attention layer later.
             seq_length *= sp_size
 
         past_seen_tokens = 0
@@ -226,8 +224,10 @@ def llama_model_forward(
 
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
-            if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
+            if sp_mode == "ring" or sp_mode == "split_gather":
+                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            elif sp_mode == "all_to_all":
+                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -551,7 +551,6 @@ def forward(
         )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -683,7 +682,7 @@ def forward(
 
         if shard_config.enable_flash_attention:
             mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
-            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
+            attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
                 inputs_embeds.device,
@@ -850,9 +849,25 @@ def forward(
         else:
             logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.model.dtype,
+            )
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 75ec1f33d195..11b414aee35f 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -22,6 +22,7 @@
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -569,25 +570,29 @@ def step(self, closure=None):
                 working_param = real_working_params[group_id][idx]
                 param_to_gather = master_param.to(device).to(self._dtype)
                 pg = self.param_to_pg[working_param]
-                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
-                    buffer_tensor = torch.empty_like(
-                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
-                    )
-                    if self._fp8_communication:
-                        all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
-                    else:
-                        dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
-                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
-                    continue
-                try:
-                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
-                except RuntimeError:
-                    self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
-                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                padded_working_param = self._working_param_to_padded_working_param[working_param]
+                if self._overlap_allgather:
+                    handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                    set_all_gather_handle(working_param, handle)
+                else:
+                    if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                        if self._fp8_communication:
+                            all_gather_into_tensor_flat_fp8(
+                                padded_working_param, param_to_gather, pg, fp8_format="e4m3"
+                            )
+                        else:
+                            dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                        continue
+                    try:
+                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
+                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
-            if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
+        if not self._overlap_allgather:
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
     def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 093377e7a034..d30c8a552bed 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -213,6 +213,8 @@ def empty_init():
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
+            dp_outside=False,
+            overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             **hybrid_kwargs,
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 8dd516dbf46a..c6f49de21173 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -199,28 +199,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {  # Ulysess + TP
-            "tp_size": 2,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "zero_stage": 0,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-        {  # Ulysess + PP
+        {  # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 2,
             "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
-            "zero_stage": 0,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },