From 5b077c43c565a70196583f167ada7478ff26c066 Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Sun, 25 Feb 2024 01:47:51 -0800 Subject: [PATCH] add alpha scaling to lora (#8248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * removed pdeprecated eft model Signed-off-by: arendu * add alpha Signed-off-by: arendu * default for alpha Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add alpha scaling to lora (#8483) * coldfix (#8412) Signed-off-by: George Zelenfroynd Signed-off-by: Michal Futrega * Fixed errors in the CTM gen functions (#8416) (#8420) Signed-off-by: Taejin Park Co-authored-by: Taejin Park Signed-off-by: Michal Futrega * Add change_vocabulary and save_tokenizers() support to Multitask ASR models (#8357) (#8367) * Add change_vocabulary and save_tokenizers() support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update nemo/collections/asr/models/aed_multitask_models.py --------- Signed-off-by: smajumdar Signed-off-by: Somshubra Majumdar Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Piotr Żelasko Signed-off-by: Michal Futrega * fix path location and branch (#8314) * fix path location and branch (#8304) * fix path location and branch Signed-off-by: Nithin Rao Koluguri * change to a floating point number Signed-off-by: Nithin Rao Koluguri --------- Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Koluguri Co-authored-by: Somshubra Majumdar * updat ebranch in tutorial Signed-off-by: Nithin Rao Koluguri --------- Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Co-authored-by: Somshubra Majumdar Co-authored-by: Nithin Rao Koluguri Signed-off-by: Michal Futrega * Add TP comm overlap knobs to AutocastTransformerLayer (#8290) Signed-off-by: Jaemin Choi Co-authored-by: Jaemin Choi Signed-off-by: Michal Futrega * add deallocate pipeline output optimization (#8279) (#8318) * add deallocate pipeline output optimization * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jimmy Zhang Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com> Co-authored-by: Jimmy Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Michal Futrega * remove assertion (#8302) (#8321) Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Signed-off-by: Michal Futrega * Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP>1 (#8334) (#8346) Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * Enable megatron core loggers for GPT pretraining (#8354) (#8384) * Logging changes tested for gpt_pretraining * Additional args * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aishwarya Bhandare Co-authored-by: ashbhandare Co-authored-by: Aishwarya Bhandare Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * Fix dreambooth data sampler issue (#8400) (#8413) * Turn on drop last * Some neva fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --------- Signed-off-by: yaoyu-33 Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Michal Futrega * add ensemble decoding fix (#8427) (#8433) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Michal Futrega * NeVA Tutorial Notebook (#8217) * init commit - neva tutorial Signed-off-by: Pratyush Muthukumar * NeVA tutorial notebook Signed-off-by: Pratyush Muthukumar * init commit - neva tutorial Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar * NeVA tutorial notebook Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar * requested changes Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar * add inference via script Signed-off-by: Pratyush Muthukumar * requested changes Signed-off-by: Pratyush Muthukumar * requested changes Signed-off-by: Pratyush Muthukumar * add codeblocks to run torchrun in notebook Signed-off-by: Pratyush Muthukumar --------- Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar Co-authored-by: Pratyush Muthukumar Signed-off-by: Michal Futrega * mcore customization doc minor fix (#8421) (#8437) Signed-off-by: Huiying Li Co-authored-by: Huiying Signed-off-by: Michal Futrega * Add `loop_labels` algorithm for TDT greedy decoding (#8215) * Add `loop_labels` algorithm for TDT greedy decoding Signed-off-by: Vladimir Bataev * Use `loop_labels` by default Signed-off-by: Vladimir Bataev * Loop labels greedy decoding v2 Signed-off-by: Vladimir Bataev * Add comments. Clean up Signed-off-by: Vladimir Bataev * Add comments Signed-off-by: Vladimir Bataev * Add comments Signed-off-by: Vladimir Bataev * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev * Add tests for batched alignments Signed-off-by: Vladimir Bataev * Add comments Signed-off-by: Vladimir Bataev * Fix comment Signed-off-by: Vladimir Bataev * Fix test Signed-off-by: Vladimir Bataev * Add computer for TDT Signed-off-by: Vladimir Bataev * Fix TDT decoding algorithm Signed-off-by: Vladimir Bataev * Use loop frames by default for TDT Signed-off-by: Vladimir Bataev * Remove "loop frames" implementation for TDT Signed-off-by: Vladimir Bataev * Clean up Signed-off-by: Vladimir Bataev * Add comments Signed-off-by: Vladimir Bataev * Fix confidence. Use tensor for durations. Signed-off-by: Vladimir Bataev --------- Signed-off-by: Vladimir Bataev Signed-off-by: Michal Futrega * Add dist ckpt support for regular optimizers (#7749) (#8293) * Add dist ckpt support for regular optimizers * [tutorial] fixed missing RIR scripts file. 
(#8257) * fix imports * imports fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ci imports fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert asr notebook * revert asr notebook --------- Signed-off-by: Mikołaj Błaż Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: dimapihtar Co-authored-by: mikolajblaz Co-authored-by: Eric Harper Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: dimapihtar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Michal Futrega * Multimodal r1.23.0 bug fix (#8315) (#8339) * Rename quick-gelu * ddpm config guard * Fix ddpm edit api * Fix insert_image_token cfg issue * neva updates * reformat * Add back jenkins * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix jenkins * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bugs * Update default neva template --------- Signed-off-by: yaoyu-33 Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: Eric Harper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Michal Futrega * mcore ds fix (#8283) (#8385) * [tutorial] fixed missing RIR scripts file. (#8257) * add values to en tts dict (#7879) * mcore ds fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update mcore * revert asr files * add comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support for mcore mock dataset * update mcore version * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update gpt cfg * update mcore commit * fix Bert unit tests * update bert tests * fix bert mcore test * fix gpt jenkins tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update apex & TE commits * revert apex installation * turn off the fusion for jenkins --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Mariana Graterol Fuenmayor Signed-off-by: Dmytro Pykhtar Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pablo Garay Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * MCore dataset compatibility for tokenizers (#8390) (#8397) * Add unique_identifiers for all tokenizers and eod for SentencePieceTokenizer * Add generalized token aliases to TokenizerSpec to conform with MegatronTokenizer's interface. Remove now-redundant individual fixes from AutoTokenizer and SentencePieceTokenizer. 
--------- Signed-off-by: Valerie Sarge Co-authored-by: Valerie Sarge Co-authored-by: Pablo Garay Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * Canary: inference tokenization improvements; preserving custom keys when creating tarred manifests (#8432) * Improvements for Canary: - carry over custom keys when creatin tarred manifests - selectable text field in ASR eval - get rid of prompt slicing, create proper inference prompts Signed-off-by: Piotr Żelasko * set ensure_ascii=False in tarred conversion to avoid breaking tokenizers trained on UTF-8 encoding Signed-off-by: Piotr Żelasko --------- Signed-off-by: Piotr Żelasko Signed-off-by: Michal Futrega * add sbert to IR (#8445) * add sbert to IR Signed-off-by: ataghibakhsh * add doc Signed-off-by: ataghibakhsh * fix the auto_tokenizer property method reset bug Signed-off-by: ataghibakhsh * addressed bot comments Signed-off-by: ataghibakhsh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: ataghibakhsh Co-authored-by: Eric Harper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Michal Futrega * Update readme (#8440) * update Signed-off-by: eharper * udpate Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * landing pages added * landing page added for vision * landing pages updated * some minor changes to the main readme * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * update Signed-off-by: eharper * typo fixed * update Signed-off-by: eharper --------- Signed-off-by: eharper Co-authored-by: ntajbakhsh Signed-off-by: Michal Futrega * NeMo-Mistral to HF converter bugfix. (#8353) (#8442) Signed-off-by: Alexandros Koumparoulis Co-authored-by: akoumpa <153118171+akoumpa@users.noreply.github.com> Signed-off-by: Michal Futrega * Fixing mcore bert for TP, PP and SP (#8336) (#8443) * Fixing mcore bert for TP, PP and SP * Fixing mcore bert for TP, PP and SP * Fixing mcore version * Fixing mcore version * Update Jenkinsfile * Update Jenkinsfile * Update Jenkinsfile --------- Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Co-authored-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Co-authored-by: Shanmugam Ramasamy Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * Add LoRA support to all linear layers (#7988) * Added LoRA support for the Dense layer of Attention * Added LoRA MLP support to MCore and NeMo models. * Change LoRA config default to QKV. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed bug with ddp training. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * MCoreMixin chages. 
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * using new commit of meg-LM Signed-off-by: arendu * add cpu_offloading_num_layers to conversion script until bug in megatron is fixed Signed-off-by: Chen Cui * fix peft mixin arguments to follow mcore 0.5 Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update megatron commit to fix ci error Signed-off-by: Chen Cui * try to fix ci Signed-off-by: Chen Cui * try to fix ci Signed-off-by: Chen Cui * add cfg default Signed-off-by: Chen Cui --------- Signed-off-by: Adi Renduchintala Signed-off-by: Jiaqi Zeng Signed-off-by: arendu Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Jiaqi Zeng Co-authored-by: arendu Co-authored-by: HeyyyyyyG <49757268+HeyyyyyyG@users.noreply.github.com> Co-authored-by: Chen Cui Co-authored-by: Eric Harper Signed-off-by: Michal Futrega * Add Neva Template for NV-DPO Models (#8358) * add/rename from nvgpt to nv_steerlm, add nv_dpo template Signed-off-by: HuiyingLi * add nv_dpo conversation to accomendate empty system message Signed-off-by: HuiyingLi * handle nv_dpo template text generation Signed-off-by: HuiyingLi * add prompt string to nvgpt Signed-off-by: HuiyingLi * bugfix for inference prompt template Signed-off-by: HuiyingLi * bug fix for grabbing clean text Signed-off-by: Huiying Li * fix code format Signed-off-by: Huiying Li --------- Signed-off-by: HuiyingLi Signed-off-by: Huiying Li Signed-off-by: Michal Futrega * Rebase scaling alpha Signed-off-by: Michal Futrega * default for alpha Signed-off-by: arendu Signed-off-by: Michal Futrega * Rebase scaling alpha Signed-off-by: Michal Futrega --------- Signed-off-by: George Zelenfroynd Signed-off-by: Michal Futrega Signed-off-by: Taejin Park Signed-off-by: smajumdar Signed-off-by: Somshubra Majumdar Signed-off-by: Nithin Rao Koluguri Signed-off-by: Jaemin Choi Signed-off-by: Jimmy Zhang Signed-off-by: dimapihtar Signed-off-by: Sangkug Lym Signed-off-by: Aishwarya Bhandare Signed-off-by: yaoyu-33 Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar Signed-off-by: Huiying Li Signed-off-by: Vladimir Bataev Signed-off-by: Mikołaj Błaż Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Mariana Graterol Fuenmayor Signed-off-by: Dmytro Pykhtar Signed-off-by: Valerie Sarge Signed-off-by: Piotr Żelasko Signed-off-by: ataghibakhsh Signed-off-by: eharper Signed-off-by: Alexandros Koumparoulis Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: Jiaqi Zeng Signed-off-by: arendu Signed-off-by: Chen Cui Signed-off-by: HuiyingLi Co-authored-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Taejin Park Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Piotr Żelasko Co-authored-by: Nithin Rao Co-authored-by: Jaemin Choi Co-authored-by: Jaemin Choi Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com> Co-authored-by: Jimmy Zhang Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Sangkug Lym Co-authored-by: Eric Harper Co-authored-by: 
ashbhandare Co-authored-by: Aishwarya Bhandare Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: Pratyush Muthukumar <30813477+PannuMuthu@users.noreply.github.com> Co-authored-by: Pratyush Muthukumar Co-authored-by: Huiying Co-authored-by: Vladimir Bataev Co-authored-by: mikolajblaz Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: dimapihtar Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: Pablo Garay Co-authored-by: Valerie Sarge Co-authored-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Co-authored-by: ntajbakhsh Co-authored-by: akoumpa <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Co-authored-by: Shanmugam Ramasamy Co-authored-by: Tugrul Konuk Co-authored-by: Adi Renduchintala Co-authored-by: Jiaqi Zeng Co-authored-by: arendu Co-authored-by: HeyyyyyyG <49757268+HeyyyyyyG@users.noreply.github.com> Co-authored-by: Chen Cui --------- Signed-off-by: arendu Signed-off-by: George Zelenfroynd Signed-off-by: Michal Futrega Signed-off-by: Taejin Park Signed-off-by: smajumdar Signed-off-by: Somshubra Majumdar Signed-off-by: Nithin Rao Koluguri Signed-off-by: Jaemin Choi Signed-off-by: Jimmy Zhang Signed-off-by: dimapihtar Signed-off-by: Sangkug Lym Signed-off-by: Aishwarya Bhandare Signed-off-by: yaoyu-33 Signed-off-by: Pratyush Muthukumar Signed-off-by: Pratyush Muthukumar Signed-off-by: Huiying Li Signed-off-by: Vladimir Bataev Signed-off-by: Mikołaj Błaż Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Mariana Graterol Fuenmayor Signed-off-by: Dmytro Pykhtar Signed-off-by: Valerie Sarge Signed-off-by: Piotr Żelasko Signed-off-by: ataghibakhsh Signed-off-by: eharper Signed-off-by: Alexandros Koumparoulis Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: Jiaqi Zeng Signed-off-by: Chen Cui Signed-off-by: HuiyingLi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michal Futrega Co-authored-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Taejin Park Co-authored-by: Somshubra Majumdar Co-authored-by: Piotr Żelasko Co-authored-by: Nithin Rao Co-authored-by: Jaemin Choi Co-authored-by: Jaemin Choi Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com> Co-authored-by: Jimmy Zhang Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Sangkug Lym Co-authored-by: Eric Harper Co-authored-by: ashbhandare Co-authored-by: Aishwarya Bhandare Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: Pratyush Muthukumar <30813477+PannuMuthu@users.noreply.github.com> Co-authored-by: Pratyush Muthukumar Co-authored-by: Huiying Co-authored-by: Vladimir Bataev Co-authored-by: mikolajblaz Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: dimapihtar Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: Pablo Garay Co-authored-by: Valerie Sarge Co-authored-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com> Co-authored-by: ntajbakhsh Co-authored-by: akoumpa 
<153118171+akoumpa@users.noreply.github.com>
Co-authored-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Co-authored-by: Shanmugam Ramasamy
Co-authored-by: Tugrul Konuk
Co-authored-by: Jiaqi Zeng
Co-authored-by: HeyyyyyyG <49757268+HeyyyyyyG@users.noreply.github.com>
Co-authored-by: Chen Cui
---
 .../tuning/conf/megatron_gpt_finetuning_config.yaml       | 1 +
 .../modules/common/megatron/adapters/parallel_adapters.py | 6 ++++++
 nemo/collections/nlp/parts/peft_config.py                 | 1 +
 3 files changed, 8 insertions(+)

diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
index af561ffe0aad..96752696da41 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
@@ -96,6 +96,7 @@ model:
     lora_tuning:
       target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
       adapter_dim: 32
+      alpha: ${model.peft.lora_tuning.adapter_dim}
       adapter_dropout: 0.0
       column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
       row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index ac85ea7a1d2e..9690d5d21697 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -139,6 +139,7 @@ def __init__(
         input_is_parallel: bool = False,  # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
         dropout: float = 0.0,
         model_parallel_config: Optional[ModelParallelConfig] = None,
+        alpha: float | None = None,
         **kwargs,
     ):
         super().__init__()
@@ -151,7 +152,9 @@ def __init__(
         self.activation = activation_registry[activation]()
         self.norm_position = norm_position
         self.dim = dim
+        self.alpha = alpha if alpha is not None else self.dim
         self.input_is_parallel = input_is_parallel
+
         # megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
         # in case this arg is not provided, use the dummy default config.
         if model_parallel_config is None:
@@ -274,6 +277,8 @@ def forward(self, x):
         if self.dropout is not None:
             x = self.dropout(x)
 
+        x = x * (self.alpha / self.dim)
+
         return x
@@ -290,6 +295,7 @@ class ParallelLinearAdapterConfig(AdapterConfig):
     gather_output: bool = True
     input_is_parallel: bool = False
     dropout: float = 0.0
+    alpha: float | None = None
     network_alpha: int | None = None
     _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)
diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py
index 815ad4d9e952..97305991d0b3 100644
--- a/nemo/collections/nlp/parts/peft_config.py
+++ b/nemo/collections/nlp/parts/peft_config.py
@@ -182,6 +182,7 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
             "row_init_method": lora_cfg.get("row_init_method", "zero"),
             "gather_output": False,
             "dropout": lora_cfg.adapter_dropout,
+            "alpha": lora_cfg.get("alpha", lora_cfg.adapter_dim),
         }
 
         if lora_cfg.weight_tying:
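
Taken together, the three hunks wire one new knob through the stack: `alpha` can be set under `model.peft.lora_tuning` (and defaults to `adapter_dim` via interpolation), `_create_lora_config` forwards it to the adapter, and `ParallelLinearAdapter` multiplies its output by `alpha / dim`. The sketch below is editorial and not NeMo code: it is a minimal plain-PyTorch stand-in (class and layer names such as `LoraAdapterSketch`, `linear_in`, `linear_out` are illustrative) that reproduces only the scaling behavior this patch introduces.

# Editorial sketch, not part of the patch and not NeMo code.
# Illustrates the alpha / dim scaling added to ParallelLinearAdapter above.
import torch
import torch.nn as nn


class LoraAdapterSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, dim: int = 32, alpha: float | None = None):
        super().__init__()
        self.dim = dim
        # Same default as the patch: when alpha is not provided, fall back to dim,
        # so alpha / dim == 1.0 and the pre-patch behavior is preserved.
        self.alpha = alpha if alpha is not None else dim
        # Real LoRA zero-initializes the output projection so training starts from
        # the base model; random init is kept here so the scaling is visible.
        self.linear_in = nn.Linear(in_features, dim, bias=False)    # stand-in for the column-parallel projection
        self.linear_out = nn.Linear(dim, out_features, bias=False)  # stand-in for the row-parallel projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_out(self.linear_in(x))
        # The behavior this PR adds to ParallelLinearAdapter.forward():
        # scale the low-rank update by alpha / dim (the usual LoRA alpha / r factor).
        return x * (self.alpha / self.dim)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16)
    base = LoraAdapterSketch(16, 16, dim=32)              # alpha defaults to 32 -> scale 1.0
    scaled = LoraAdapterSketch(16, 16, dim=32, alpha=64)  # alpha / dim = 2.0
    scaled.load_state_dict(base.state_dict())             # copy weights; alpha is a plain attribute and stays 64
    print(torch.allclose(scaled(x), 2.0 * base(x)))       # True: output is exactly doubled

Because the YAML default is `${model.peft.lora_tuning.adapter_dim}`, the effective scale stays 1.0 unless a user overrides `alpha`, so existing LoRA configs trained before this change keep their behavior.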