[FP8] rebase main (hpcaitech#5963)
* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (hpcaitech#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (hpcaitech#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (hpcaitech#5872)

* [Feature] remove modeling file, use auto config. (hpcaitech#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (hpcaitech#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (hpcaitech#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <[email protected]>
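
For context, a minimal sketch of the setting this hotfix concerns: with `CUDA_DEVICE_MAX_CONNECTIONS=1` kernels are funneled into a single hardware queue, so a communication kernel launched first keeps its GPU resources and can overlap with later compute. The variable must be set before the CUDA context is created; the placement below illustrates typical usage and is not the exact ColossalAI code.

```python
import os

# Must run before the first CUDA call creates the context in this process.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

import torch  # noqa: E402  (imported only after the env var is pinned)
```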

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (hpcaitech#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI, import, requirements-test for hpcaitech#5838 (hpcaitech#5892)

* [Hot Fix] CI, import, requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (hpcaitech#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross-entropy function for all shardformer models
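
A hedged sketch of what a single shared cross-entropy helper looks like; the real shardformer version also handles vocab-parallel logits and sequence parallelism, and the name and signature here are illustrative only.

```python
import torch
import torch.nn.functional as F

def causal_lm_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Shift so tokens < n predict token n, then flatten for F.cross_entropy.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

loss = causal_lm_cross_entropy(torch.randn(2, 8, 32000), torch.randint(0, 32000, (2, 8)))
```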

---------

Co-authored-by: Edenzzzz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (hpcaitech#5897)

* add benchmarks for SFT, DPO, SimPO, and ORPO; add benchmarking results; support LoRA with gradient checkpointing

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (hpcaitech#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
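
A minimal sketch of the overlap pattern (not ColossalAI's actual implementation; assumes an initialized default process group): launch the all-gather asynchronously, do independent compute, and wait on the handle only when the gathered tensor is needed.

```python
import torch
import torch.distributed as dist

def all_gather_overlapped(shard: torch.Tensor, compute_fn):
    world_size = dist.get_world_size()
    out = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    work = dist.all_gather_into_tensor(out, shard, async_op=True)  # non-blocking collective
    result = compute_fn()  # independent compute overlaps the communication
    work.wait()            # synchronize only when the full tensor is required
    return out, result
```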

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (hpcaitech#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up the construction of DimSpec objects: the difference_dict is identical for every DimSpec instance, so a single shared copy is enough.
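
A sketch of the memoization described above, with the expensive dict built once at class level instead of once per instance; names and contents are illustrative, not the exact ColossalAI code.

```python
class DimSpec:
    _difference_dict = None  # shared by all instances; built exactly once

    def __init__(self, dim_spec: str):
        self.dim_spec = dim_spec
        if DimSpec._difference_dict is None:
            DimSpec._difference_dict = self._build_difference_dict()

    @staticmethod
    def _build_difference_dict():
        # stand-in for the expensive one-time construction of conversion costs
        return {("S0", "R"): 1, ("R", "S0"): 1, ("R", "R"): 0, ("S0", "S0"): 0}

    def difference(self, other: "DimSpec") -> int:
        # cost of converting this dim's sharding into the other's
        return DimSpec._difference_dict[(self.dim_spec, other.dim_spec)]

assert DimSpec("R").difference(DimSpec("S0")) == 1
```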

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (hpcaitech#5903)

* [compatibility] support torch 2.2 (hpcaitech#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (hpcaitech#5820)
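
A hedged sketch of one way to shim the private helper whose signature changed across torch releases; which arguments were added in which version is an assumption here, so the sketch probes the installed signature instead of hard-coding a version.

```python
import inspect
from torch.distributed.distributed_c10d import _object_to_tensor

def object_to_tensor_compat(obj, device, group=None):
    params = inspect.signature(_object_to_tensor).parameters
    if "group" in params:  # newer torch: the helper also accepts a process group
        return _object_to_tensor(obj, device, group)
    return _object_to_tensor(obj, device)
```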

* [misc] support torch2.3 (hpcaitech#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (hpcaitech#5912)

* [plugin] support all-gather overlap for hybrid parallel (hpcaitech#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (hpcaitech#5924)

Co-authored-by: Edenzzzz <[email protected]>

* [ColossalChat] Hotfix for ColossalChat (hpcaitech#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (hpcaitech#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set a default value for 'default_conversation' so it is always bound before use
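
The general shape of the bug and the fix: a variable assigned only inside a conditional raises UnboundLocalError on any path that skips the assignment, and binding a default first makes every path safe. Names mirror the error message; the surrounding ColossalChat logic is elided.

```python
def pick_conversation_template(config: dict):
    default_conversation = None  # the fix: bound on every path
    if config.get("chat_template") is not None:
        default_conversation = config["chat_template"]
    if default_conversation is None:
        raise ValueError("no conversation template configured")
    return default_conversation

print(pick_conversation_template({"chat_template": "llama-2"}))
```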

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (hpcaitech#5894)

* [FIX BUG] convert env param to int in (hpcaitech#5934)
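
A minimal illustration of this class of fix: values read from os.environ are strings, so numeric parameters must be cast before arithmetic or comparisons. The variable name is illustrative.

```python
import os

os.environ.setdefault("WORLD_SIZE", "4")    # env values are always strings
world_size = int(os.environ["WORLD_SIZE"])  # the fix: cast before numeric use
print(world_size * 2)                       # 8, not "44"
```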

* [Hotfix] Fix ZeRO typo hpcaitech#5936

Co-authored-by: Edenzzzz <[email protected]>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (hpcaitech#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
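
A hedged sketch of such a switch; the flag name is an assumption, and the save call is a stand-in for the trainer's actual checkpoint routine.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_checkpoint_per_epoch",
    action="store_true",
    help="save a model checkpoint when each epoch ends",
)
args = parser.parse_args([])  # default: switch off

for epoch in range(3):
    ...  # training loop body elided
    if args.save_checkpoint_per_epoch:
        print(f"saving checkpoint for epoch {epoch}")  # stand-in for booster.save_model(...)
```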

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (hpcaitech#5945)

* [shardformer] hotfix attn mask (hpcaitech#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (hpcaitech#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (hpcaitech#5951)

* [release] update version (hpcaitech#5952)

* [Chat] Fix lora (hpcaitech#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (hpcaitech#5958)

* [hotfix] Remove unused plan section (hpcaitech#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by renaming the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (hpcaitech#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (hpcaitech#5950)

* support auto distributed data loader
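
A sketch of the distributed-evaluation idea using stock torch primitives (ColossalAI wraps this in its own loader API): each rank sees a disjoint shard of the eval set via DistributedSampler. Rank and world size are fixed here for illustration; they normally come from torch.distributed.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
print(sum(len(batch[0]) for batch in loader))  # 25 samples on this rank
```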

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (hpcaitech#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8
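
For readers, a self-contained sketch of the per-tensor quantization such cast operators perform; this is a simplification, not the shipped kernels (which also handle scaling granularity and communication fusion), and all_reduce_fp8 composes casts like these around a standard all-reduce. Requires torch >= 2.1 for the float8_e4m3fn dtype.

```python
import torch

def cast_to_fp8(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max      # 448 for e4m3
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max   # per-tensor scale
    return (x / scale).to(torch.float8_e4m3fn), scale

def cast_from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    return x_fp8.to(dtype) * scale

x = torch.randn(4, 4)
x_fp8, s = cast_to_fp8(x)
print((cast_from_fp8(x_fp8, s) - x).abs().max())  # small quantization error
```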

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (hpcaitech#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <[email protected]>
Co-authored-by: Haze188 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <[email protected]>
Co-authored-by: Edenzzzz <[email protected]>
Co-authored-by: Runyu Lu <[email protected]>
Co-authored-by: Guangyao Zhang <[email protected]>
Co-authored-by: YeAnbang <[email protected]>
Co-authored-by: Hongxin Liu <[email protected]>
Co-authored-by: Stephan Kö <[email protected]>
Co-authored-by: アマデウス <[email protected]>
Co-authored-by: Tong Li <[email protected]>
Co-authored-by: zhurunhua <[email protected]>
Co-authored-by: Insu Jang <[email protected]>
Co-authored-by: Gao, Ruiyuan <[email protected]>
Co-authored-by: hxwang <[email protected]>
Co-authored-by: Michelle <[email protected]>
Co-authored-by: Wang Binluo <[email protected]>
Co-authored-by: HangXu <[email protected]>
20 people committed Aug 16, 2024
1 parent 10fbe0e commit db2007c
Showing 17 changed files with 336 additions and 55 deletions.
@@ -392,4 +392,4 @@ def tokenize_kto(
         "label": data_point["label"],
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
-    }
\ No newline at end of file
+    }
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/dpo.py
@@ -356,4 +356,4 @@ def _eval(self, epoch: int):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/kto.py
@@ -346,4 +346,4 @@ def _eval(self, epoch: int):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/trainer/orpo.py
@@ -323,4 +323,4 @@ def _eval(self, epoch: int):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
2 changes: 1 addition & 1 deletion applications/ColossalChat/examples/README.md
@@ -903,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
 ## Attention
 
 
-The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
\ No newline at end of file
+The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
104 changes: 104 additions & 0 deletions applications/ColossalChat/examples/inference/round.txt
@@ -0,0 +1,104 @@


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>

==========


==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>

==========


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>

==========


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>

==========


==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>

==========
@@ -368,4 +368,4 @@ def train(args):
     os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
     with open(args.config_file, "w") as f:
         json.dump(args.__dict__, f, indent=4)
-    train(args)
\ No newline at end of file
+    train(args)
@@ -375,4 +375,4 @@ def train(args):
     os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
     with open(args.config_file, "w") as f:
         json.dump(args.__dict__, f, indent=4)
-    train(args)
\ No newline at end of file
+    train(args)
@@ -340,4 +340,4 @@ def train(args):
     os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
     with open(args.config_file, "w") as f:
         json.dump(args.__dict__, f, indent=4)
-    train(args)
\ No newline at end of file
+    train(args)
2 changes: 1 addition & 1 deletion applications/ColossalChat/requirements.txt
@@ -20,4 +20,4 @@ datasets
 ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
-tiktoken
\ No newline at end of file
+tiktoken
168 changes: 168 additions & 0 deletions applications/ColossalChat/tests/test_train.sh
@@ -641,3 +641,171 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
done



echo "[Test]: testing ORPO ..."

SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done



echo "[Test]: testing KTO ..."

SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
--auto_weight \
--desirable_weight 1.2 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
1 change: 1 addition & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1020,6 +1020,7 @@ def __init__(
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        overlap_allgather: bool = False,
         fp8_communication: bool = False,
         inner_ring_size: int = None,
2 changes: 0 additions & 2 deletions colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -83,8 +83,6 @@ def backward(ctx, grad_output):
         # all-reduce scheduled first and have GPU resources allocated
         # TODO: This seems to only work if you add torch.cuda.Event.wait()

-        # _ = torch.zeros(1, device=grad_output.device)
-
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None