From 20bf678536673b2a2632641dc435815cc6ba3b72 Mon Sep 17 00:00:00 2001
From: Yashaswi Karnati <144376261+yashaswikarnati@users.noreply.github.com>
Date: Thu, 19 Dec 2024 23:20:13 -0800
Subject: [PATCH] Rename multimodal data module - EnergonMultiModalDataModule
 (#11654)

* rename multimodal data module

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati

* fix long lengths

* fix lint issues

* fix long lint issues

---------

Signed-off-by: yashaswikarnati
Co-authored-by: yashaswikarnati
---
 .../data/diffusion_energon_datamodule.py      |  6 ++---
 nemo/collections/diffusion/train.py           |  6 +++--
 nemo/collections/multimodal/data/__init__.py  |  6 ++---
 .../multimodal/data/energon/__init__.py       |  4 +--
 .../multimodal/data/energon/base.py           | 27 ++++++++++++-------
 scripts/vlm/llava_next_finetune.py            |  4 +--
 scripts/vlm/llava_next_pretrain.py            |  4 +--
 .../data/energon/test_data_module.py          |  6 ++---
 8 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
index 07747528363a..5ad15c654555 100644
--- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
+++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -19,10 +19,10 @@
 from megatron.core import parallel_state
 from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
 
-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
 
 
-class DiffusionDataModule(SimpleMultiModalDataModule):
+class DiffusionDataModule(EnergonMultiModalDataModule):
     """
     A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
 
@@ -62,7 +62,7 @@ def __init__(
         max_samples_per_sequence: int | None = None,
     ) -> None:
         """
-        Initialize the SimpleMultiModalDataModule.
+        Initialize the EnergonMultiModalDataModule.
 
         Parameters:
         path (str): Path to the dataset.
diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py
index 404602084b85..0db2e8fd2326 100644
--- a/nemo/collections/diffusion/train.py
+++ b/nemo/collections/diffusion/train.py
@@ -38,7 +38,7 @@
     DiTXLConfig,
     ECDiTLlama1BConfig,
 )
-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
 from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
@@ -64,7 +64,7 @@ def multimodal_datamodule() -> pl.LightningDataModule:
 @run.autoconvert
 def simple_datamodule() -> pl.LightningDataModule:
     """Simple Datamodule Initialization"""
-    data_module = SimpleMultiModalDataModule(
+    data_module = EnergonMultiModalDataModule(
         seq_length=2048,
         micro_batch_size=1,
         global_batch_size=32,
@@ -221,6 +221,7 @@ def train_mock() -> run.Partial:
 
 @run.cli.factory(target=llm.train)
 def mock_ditllama5b_8k() -> run.Partial:
+    """DiT-5B mock Recipe"""
     recipe = pretrain()
     recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1)
     recipe.data = multimodal_fake_datamodule()
@@ -256,6 +257,7 @@ def mock_ditllama5b_8k() -> run.Partial:
 
 @run.cli.factory(target=llm.train)
 def mock_dit7b_8k() -> run.Partial:
+    """DiT-7B mock Recipe"""
     recipe = mock_ditllama5b_8k()
     recipe.model.config = run.Config(DiT7BConfig, max_frames=1)
     recipe.data.model_config = recipe.model.config
diff --git a/nemo/collections/multimodal/data/__init__.py b/nemo/collections/multimodal/data/__init__.py
index 7e6ac24828f5..9a78712f026d 100644
--- a/nemo/collections/multimodal/data/__init__.py
+++ b/nemo/collections/multimodal/data/__init__.py
@@ -14,7 +14,7 @@
 
 from nemo.utils.import_utils import safe_import_from
 
-SimpleMultiModalDataModule, _ = safe_import_from(
-    "nemo.collections.multimodal.data.energon", "SimpleMultiModalDataModule"
+EnergonMultiModalDataModule, _ = safe_import_from(
+    "nemo.collections.multimodal.data.energon", "EnergonMultiModalDataModule"
 )
-__all__ = ["SimpleMultiModalDataModule"]
+__all__ = ["EnergonMultiModalDataModule"]
diff --git a/nemo/collections/multimodal/data/energon/__init__.py b/nemo/collections/multimodal/data/energon/__init__.py
index 04926758cbac..8c7465880b39 100644
--- a/nemo/collections/multimodal/data/energon/__init__.py
+++ b/nemo/collections/multimodal/data/energon/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
 from nemo.collections.multimodal.data.energon.config import (
     ImageTextSample,
     ImageToken,
@@ -28,7 +28,7 @@
 )
 
 __all__ = [
-    "SimpleMultiModalDataModule",
+    "EnergonMultiModalDataModule",
     "ImageToken",
     "ImageTextSample",
     "MultiModalSampleConfig",
diff --git a/nemo/collections/multimodal/data/energon/base.py b/nemo/collections/multimodal/data/energon/base.py
index 8c7819c3d7dd..3dfd495edd82 100644
--- a/nemo/collections/multimodal/data/energon/base.py
+++ b/nemo/collections/multimodal/data/energon/base.py
@@ -30,7 +30,7 @@
 from nemo.utils import logging
 
 
-class SimpleMultiModalDataModule(pl.LightningDataModule, IOMixin):
+class EnergonMultiModalDataModule(pl.LightningDataModule, IOMixin):
     """
     A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
@@ -70,7 +70,7 @@ def __init__(
         decoder_seq_length: Optional[int] = None,
     ) -> None:
         """
-        Initialize the SimpleMultiModalDataModule.
+        Initialize the EnergonMultiModalDataModule.
 
         Parameters:
         path (str): Path to the dataset.
@@ -80,8 +80,10 @@ def __init__(
         micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
         num_workers (int, optional): Number of workers for data loading. Defaults to 1.
         pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
-        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. Defaults to MultiModalSampleConfig().
-        task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
+        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
+            Defaults to MultiModalSampleConfig().
+        task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
+            If not provided, a default (MultiModalTaskEncoder) encoder will be created. Defaults to None.
         """
 
         super().__init__()
@@ -113,7 +115,7 @@ def __init__(
         self.val_dataloader_object = None
 
     def io_init(self, **kwargs) -> fdl.Config[Self]:
-        # (pleasefixme) image_processor and task_encoder are problematic with Fiddle so we skip serializing them for now
+
         cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']}
 
         for val in cfg_kwargs.values():
@@ -168,7 +170,8 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
             return self.train_dataloader_object
         if not parallel_state.is_initialized():
             logging.info(
-                f"Muiltimodal data loader parallel state is not initialized, using default worker config with no_workers {self.num_workers}"
+                f"Multimodal data loader parallel state is not initialized, "
+                f"using default worker config with no_workers {self.num_workers}"
             )
             worker_config = WorkerConfig.default_worker_config(self.num_workers)
         else:
@@ -176,7 +179,8 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
             world_size = parallel_state.get_data_parallel_world_size()
             data_parallel_group = parallel_state.get_data_parallel_group()
             logging.info(
-                f" Multimodal train dataloader initializing with rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** "
+                f" Multimodal train dataloader initializing with "
+                f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** "
             )
             worker_config = WorkerConfig(
                 rank=rank,
@@ -206,7 +210,8 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
 
         if not parallel_state.is_initialized():
             logging.info(
-                f"Muiltimodal val data loader parallel state is not initialized, using default worker config with no_workers {self.num_workers}"
+                f"Multimodal val data loader parallel state is not initialized, "
+                f"using default worker config with no_workers {self.num_workers}"
             )
             worker_config = WorkerConfig.default_worker_config(self.num_workers)
         else:
@@ -276,7 +281,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
         if not 'dataloader_state' in state_dict:
             logging.warning(
-                f"Data loader state cannot be resumed from state_dict, it does not have the required key dataloader_state. It has {state_dict.keys()}"
+                f"Data loader state cannot be resumed from state_dict, "
+                f"it does not have the required key dataloader_state. It has {state_dict.keys()}"
             )
             return
 
@@ -288,7 +294,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             else:
                 logging.error(f"Cannot restore state from state_dict {state_dict}")
                 raise ValueError(
-                    f"Cannot restore state from state_dict: Is the trainer object is initialized and attached to datamodule???"
+                    f"Cannot restore state from state_dict: "
+                    f"Is the trainer object initialized and attached to the datamodule?"
                 )
         except Exception as e:
             raise RuntimeError(f"Failed to dataloader restore state due to: {e}")
diff --git a/scripts/vlm/llava_next_finetune.py b/scripts/vlm/llava_next_finetune.py
index 334b360d7c70..91df8a39452d 100644
--- a/scripts/vlm/llava_next_finetune.py
+++ b/scripts/vlm/llava_next_finetune.py
@@ -49,7 +49,7 @@ def main(args):
     from transformers import AutoProcessor
 
     from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-    from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
+    from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
     from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
     from nemo.collections.vlm import LlavaNextTaskEncoder
 
@@ -65,7 +65,7 @@ def main(args):
         image_processor=processor.image_processor,
         multimodal_sample_config=multimodal_sample_config,
     )
-    data = SimpleMultiModalDataModule(
+    data = EnergonMultiModalDataModule(
         path=data_path,
         tokenizer=tokenizer,
         image_processor=processor.image_processor,
diff --git a/scripts/vlm/llava_next_pretrain.py b/scripts/vlm/llava_next_pretrain.py
index bb84e3dae1e5..0beb9b5b08d0 100644
--- a/scripts/vlm/llava_next_pretrain.py
+++ b/scripts/vlm/llava_next_pretrain.py
@@ -49,7 +49,7 @@ def main(args):
     from transformers import AutoProcessor
 
     from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-    from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
+    from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
     from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
     from nemo.collections.vlm import LlavaNextTaskEncoder
 
@@ -67,7 +67,7 @@ def main(args):
         image_processor=processor.image_processor,
         multimodal_sample_config=multimodal_sample_config,
     )
-    data = SimpleMultiModalDataModule(
+    data = EnergonMultiModalDataModule(
         path=data_path,
         tokenizer=tokenizer,
         image_processor=processor.image_processor,
diff --git a/tests/collections/multimodal/data/energon/test_data_module.py b/tests/collections/multimodal/data/energon/test_data_module.py
index 179d3f09f2df..c499ecfe9ca4 100644
--- a/tests/collections/multimodal/data/energon/test_data_module.py
+++ b/tests/collections/multimodal/data/energon/test_data_module.py
@@ -25,10 +25,10 @@
 from PIL import Image
 from transformers import AutoProcessor
 
-from nemo.collections.multimodal.data.energon import ImageToken, MultiModalSampleConfig, SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule, ImageToken, MultiModalSampleConfig
 
 
-class TestSimpleMultiModalDataModuleWithDummyData(unittest.TestCase):
+class TestEnergonMultiModalDataModuleWithDummyData(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
@@ -47,7 +47,7 @@ def setUp(self):
 
         self.create_vqa_test_dataset(self.dataset_path, 10)
 
-        self.data_module = SimpleMultiModalDataModule(
+        self.data_module = EnergonMultiModalDataModule(
             path=str(self.dataset_path),
             tokenizer=self.tokenizer,
             image_processor=self.image_processor,