Github Actions tests for Llava Next and modify pretrain recipe to have language model path (#11424)

* modified pretrain recipe to have language_model_from_pretrained

* ci test for llava next

* fixed indent/lint issue in cicd yml file

* fix lint issues

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* Update .github/workflows/cicd-main.yml

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>

* Update .github/workflows/cicd-main.yml

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>

---------

Signed-off-by: yashaswikarnati <[email protected]>
Signed-off-by: Yashaswi Karnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
Co-authored-by: oliver könig <[email protected]>
3 people authored Dec 13, 2024
1 parent 98efd37 commit aa6eba2
Showing 5 changed files with 188 additions and 13 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4607,6 +4607,21 @@ jobs:
       AFTER_SCRIPT: |
         rm -rf /tmp/nemo2_ckpt
         rm -rf /tmp/nemo2_ptq_engine

+  L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure-gpus-1
+      SCRIPT: |
+        python tests/collections/vlm/test_llava_next_train.py \
+          --devices=1 \
+          --max-steps=5 \
+          --experiment-dir=/tmp/nemo2_llava_next_results/${{ github.run_id }}
+      AFTER_SCRIPT: |
+        rm -rf /tmp/nemo2_llava_next_results
+
   Nemo_CICD_Test:
     needs:
@@ -4771,6 +4786,7 @@ jobs:
       - L2_Megatron_GPT_Reranker
       - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact
      - L2_NeMo_2_PTQ_Llama2_FP8
+      - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
     if: always()
     runs-on: ubuntu-latest
     steps:
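For local debugging outside GitHub Actions, the same entry point can be launched with the arguments the new job passes; a minimal sketch, assuming a NeMo dev environment with one visible GPU and the repository root as the working directory (the experiment directory is illustrative):

# Reproduce the L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING step locally.
import subprocess
import sys

subprocess.run(
    [
        sys.executable,
        "tests/collections/vlm/test_llava_next_train.py",
        "--devices=1",
        "--max-steps=5",
        "--experiment-dir=/tmp/nemo2_llava_next_results/local",  # illustrative run directory
    ],
    check=True,
)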
12 changes: 6 additions & 6 deletions nemo/collections/vlm/llava_next/data/mock.py
@@ -20,7 +20,9 @@
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils import data
 from torch.utils.data import DataLoader, Dataset
+from transformers import AutoProcessor

+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils import logging
@@ -79,20 +81,18 @@ def __init__(
         self.persistent_workers = persistent_workers
         self.micro_batch_size = micro_batch_size
         self.global_batch_size = global_batch_size

+        model_name = ''
+        processor = None
         if tokenizer is None or image_processor is None:
             logging.warning(
                 f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-v1.6-vicuna-7b-hf`."
             )
-            from transformers import AutoProcessor
-
-            from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-
             model_name = "llava-hf/llava-v1.6-vicuna-7b-hf"

             processor = AutoProcessor.from_pretrained(model_name)
-            self.tokenizer = tokenizer or AutoTokenizer(model_name)
-            self.image_processor = image_processor or processor.image_processor
+        self.tokenizer = tokenizer or AutoTokenizer(model_name)
+        self.image_processor = image_processor or processor.image_processor
         self.data_sampler = MegatronDataSampler(
             seq_len=self.seq_length,
             decoder_seq_len=self.decoder_seq_len,
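The refactor above hoists the AutoProcessor and AutoTokenizer imports to module level and keeps the fallback to `llava-hf/llava-v1.6-vicuna-7b-hf` when no tokenizer or image processor is supplied. A minimal sketch of both construction paths, assuming network access to the HuggingFace Hub (batch sizes are illustrative):

from transformers import AutoProcessor

from nemo.collections import vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

# Explicit path: pass an existing tokenizer/image processor pair.
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
data = vlm.LlavaNextMockDataModule(
    seq_length=1024,
    tokenizer=AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf"),
    image_processor=processor.image_processor,
    global_batch_size=2,
    micro_batch_size=2,
    num_workers=0,
)

# Fallback path: omit both arguments and the module logs a warning, then loads
# the llava-hf/llava-v1.6-vicuna-7b-hf processor and tokenizer itself.
data_with_defaults = vlm.LlavaNextMockDataModule(
    seq_length=1024,
    global_batch_size=2,
    micro_batch_size=2,
    num_workers=0,
)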
3 changes: 2 additions & 1 deletion nemo/collections/vlm/recipes/llava_next_7b.py
@@ -163,7 +163,7 @@ def pretrain_recipe(
     name: str = "default",
     num_nodes: int = 1,
     num_gpus_per_node: int = 8,
-    peft_scheme: Optional[str] = 'none',
+    language_model_from_pretrained: Optional[str] = None,
 ) -> run.Partial:
     """
     Create a Pre-training recipe for Llava1.6 7B model.
@@ -223,6 +223,7 @@ def pretrain_recipe(
                 freeze_language_model=True,
                 freeze_vision_model=True,
                 freeze_vision_projection=False,
+                language_model_from_pretrained=language_model_from_pretrained,
             )
         ),
         trainer=trainer,
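With the signature change above, the pre-training recipe no longer takes a `peft_scheme` and instead accepts an optional path to pre-trained language-model weights. A minimal sketch of building the recipe with the new argument (the checkpoint path is illustrative; leaving it at the default `None` skips loading pre-trained weights for the language model):

from nemo.collections import vlm

# Llava Next 7B pre-training recipe with the language backbone initialized
# from a local checkpoint path (illustrative).
recipe = vlm.llava_next_7b.pretrain_recipe(
    dir="./outputs/checkpoints/llava",  # where experiment checkpoints are written
    name="llava_pretrain",
    num_nodes=1,
    num_gpus_per_node=8,
    language_model_from_pretrained="/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/",
)

# Same recipe without pre-trained language-model weights.
recipe_from_scratch = vlm.llava_next_7b.pretrain_recipe(
    dir="./outputs/checkpoints/llava",
    name="llava_pretrain",
    num_nodes=1,
    num_gpus_per_node=8,
)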
13 changes: 7 additions & 6 deletions scripts/vlm/llava_next_nemo_run.py
@@ -17,14 +17,15 @@
 from nemo.collections import vlm


-def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False):
+def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False, language_model_from_pretrained=None):
     """Configure the recipe"""
     if pretrain:
         recipe = vlm.llava_next_7b.pretrain_recipe(
             dir="./outputs/checkpoints/llava",  # Path to store checkpoints
             name="llava_pretrain",
             num_nodes=nodes,
             num_gpus_per_node=gpus_per_node,
+            language_model_from_pretrained=language_model_from_pretrained,
         )
     else:
         recipe = vlm.llava_next_7b.finetune_recipe(
@@ -33,8 +34,8 @@ def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False):
             num_nodes=nodes,
             num_gpus_per_node=gpus_per_node,
         )
-    recipe.trainer.max_steps = 100
-    recipe.trainer.val_check_interval = 100
+    recipe.trainer.max_steps = 20
+    recipe.trainer.val_check_interval = 20
     recipe.model.config.freeze_vision_model = True
     return recipe

@@ -49,9 +50,9 @@ def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
     return executor


-def run_pretraining():
+def run_pretraining(language_model_from_pretrained=None):
     # pylint: disable=C0115,C0116
-    recipe = configure_recipe(pretrain=True)
+    recipe = configure_recipe(pretrain=True, language_model_from_pretrained=language_model_from_pretrained)
     executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

     run.run(recipe, executor=executor)
@@ -67,5 +68,5 @@ def run_finetuning():

 # This condition is necessary for the script to be compatible with Python's multiprocessing module.
 if __name__ == "__main__":
-    run_pretraining()
+    run_pretraining(language_model_from_pretrained='/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/')
     # run_finetuning()
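The runner script now threads the new argument from `run_pretraining` through `configure_recipe` into the recipe. A minimal sketch of driving it from another script rather than editing the `__main__` block (the `sys.path` setup and checkpoint path are illustrative assumptions):

import sys

# Make the runner module importable from the repository root (illustrative setup).
sys.path.append("scripts/vlm")
from llava_next_nemo_run import configure_recipe, run_pretraining

# Inspect the recipe the runner would launch; max_steps is 20 after this change.
recipe = configure_recipe(
    nodes=1,
    gpus_per_node=1,
    pretrain=True,
    language_model_from_pretrained="/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/",
)
print(recipe.trainer.max_steps)

# Launch pre-training the same way the script's __main__ block does.
run_pretraining(language_model_from_pretrained="/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5/")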
157 changes: 157 additions & 0 deletions tests/collections/vlm/test_llava_next_train.py
@@ -0,0 +1,157 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## NOTE: This script is present for github-actions testing only.
## There are no guarantees that this script is up-to-date with latest NeMo.

import argparse

import torch
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import AutoProcessor

from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.api import train
from nemo.lightning import AutoResume, NeMoLogger
from nemo.lightning.pytorch.callbacks import ModelCheckpoint, ParameterDebugger
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule


def get_args():
    # pylint: disable=C0115,C0116
    parser = argparse.ArgumentParser(description='Train a small Llava Next model using NeMo 2.0')
    parser.add_argument('--devices', type=int, default=1, help="Number of devices to use for training")
    parser.add_argument('--max-steps', type=int, default=5, help="Number of steps to train for")
    parser.add_argument(
        '--experiment-dir', type=str, default=None, help="directory to write results and checkpoints to"
    )

    return parser.parse_args()


if __name__ == '__main__':

    args = get_args()

    gbs = 2
    mbs = 2
    decoder_seq_length = 1024
    processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
    tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf")

    data = vlm.LlavaNextMockDataModule(
        seq_length=decoder_seq_length,
        tokenizer=tokenizer,
        image_processor=processor.image_processor,
        global_batch_size=gbs,
        micro_batch_size=mbs,
        num_workers=0,
    )

    # Transformer configurations
    language_transformer_config = llm.Llama2Config7B(seq_length=decoder_seq_length, num_layers=2)

    vision_transformer_config = vlm.HFCLIPVisionConfig(
        pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"
    )
    vision_projection_config = vlm.MultimodalProjectorConfig(
        projector_type="mlp2x_gelu",
        input_size=1024,
        hidden_size=4096,
        ffn_hidden_size=4096,
    )

    # Llava Next model configuration
    neva_config = vlm.LlavaNextConfig(
        language_transformer_config=language_transformer_config,
        vision_transformer_config=vision_transformer_config,
        vision_projection_config=vision_projection_config,
        freeze_language_model=True,
        freeze_vision_model=True,
    )

    model = vlm.LlavaNextModel(neva_config, tokenizer=data.tokenizer)

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        encoder_pipeline_model_parallel_size=0,
        pipeline_dtype=torch.bfloat16,
    )
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5000,
        save_optim_on_train_end=True,
    )

    def create_verify_precision(precision: torch.dtype):
        def verify_precision(tensor: torch.Tensor) -> None:
            assert tensor.dtype == precision

        return verify_precision

    debugger = ParameterDebugger(
        param_fn=create_verify_precision(torch.bfloat16),
        grad_fn=create_verify_precision(torch.float32),
        log_on_hooks=["on_train_start", "on_train_end"],
    )
    callbacks = [checkpoint_callback, debugger]

    loggers = []
    tensorboard_logger = TensorBoardLogger(
        save_dir='dummy',  ## NOTE: this gets overwritten by default
    )
    loggers.append(tensorboard_logger)

    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=6e-4,
        min_lr=6e-5,
        use_distributed_optimizer=False,
        bf16=True,
    )
    opt = MegatronOptimizerModule(config=opt_config)

    trainer = nl.Trainer(
        devices=args.devices,
        max_steps=args.max_steps,
        accelerator="gpu",
        strategy=strategy,
        logger=loggers,
        callbacks=callbacks,
        log_every_n_steps=1,
        limit_val_batches=2,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    nemo_logger = NeMoLogger(
        log_dir=args.experiment_dir,
    )

    resume = AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    )

    train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        resume=resume,
        tokenizer='data',
        optim=opt,
    )
