Commit 20bf678

Rename multimodal data module - EnergonMultiModalDataModule (#11654)
* rename multimodal data module

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* fix long lengths

* fix lint issues

* fix long lint issues

---------

Signed-off-by: yashaswikarnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
yashaswikarnati authored Dec 20, 2024
1 parent ec6df08 commit 20bf678
Showing 8 changed files with 36 additions and 27 deletions.
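
The rename is a pure find-and-replace: the constructor and behavior are unchanged, only SimpleMultiModalDataModule becomes EnergonMultiModalDataModule. A minimal migration sketch (import paths are taken from the diffs below; the constructor values are illustrative, mirroring those used in nemo/collections/diffusion/train.py, and the dataset path is a placeholder):

# Before this commit:
# from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule

# After this commit, only the class name changes:
from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule

data_module = EnergonMultiModalDataModule(
    path="/path/to/energon/dataset",  # placeholder dataset path
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=32,
)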
@@ -19,10 +19,10 @@
 from megatron.core import parallel_state
 from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset

-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule


-class DiffusionDataModule(SimpleMultiModalDataModule):
+class DiffusionDataModule(EnergonMultiModalDataModule):
     """
     A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
@@ -62,7 +62,7 @@ def __init__(
         max_samples_per_sequence: int | None = None,
     ) -> None:
         """
-        Initialize the SimpleMultiModalDataModule.
+        Initialize the EnergonMultiModalDataModule.
         Parameters:
         path (str): Path to the dataset.
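
Subclasses such as DiffusionDataModule only need their base-class reference updated. A minimal sketch of the pattern with a hypothetical subclass (parameters beyond those visible in the hunks above are assumptions based on the base.py docstring):

from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule


class MyEnergonDataModule(EnergonMultiModalDataModule):  # hypothetical subclass
    """Forwards construction to the renamed base unchanged."""

    def __init__(self, path: str, seq_length: int = 2048, **kwargs) -> None:
        # Assumption: the base accepts path/seq_length plus the keyword
        # arguments documented in base.py's __init__ docstring.
        super().__init__(path=path, seq_length=seq_length, **kwargs)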
6 changes: 4 additions & 2 deletions nemo/collections/diffusion/train.py
@@ -38,7 +38,7 @@
     DiTXLConfig,
     ECDiTLlama1BConfig,
 )
-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
 from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
@@ -64,7 +64,7 @@ def multimodal_datamodule() -> pl.LightningDataModule:
 @run.autoconvert
 def simple_datamodule() -> pl.LightningDataModule:
     """Simple Datamodule Initialization"""
-    data_module = SimpleMultiModalDataModule(
+    data_module = EnergonMultiModalDataModule(
         seq_length=2048,
         micro_batch_size=1,
         global_batch_size=32,
@@ -221,6 +221,7 @@ def train_mock() -> run.Partial:
 @run.cli.factory(target=llm.train)
 def mock_ditllama5b_8k() -> run.Partial:
+    """DiT-5B mock Recipe"""
     recipe = pretrain()
     recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1)
     recipe.data = multimodal_fake_datamodule()
@@ -256,6 +257,7 @@ def mock_ditllama5b_8k() -> run.Partial:
 @run.cli.factory(target=llm.train)
 def mock_dit7b_8k() -> run.Partial:
+    """DiT-7B mock Recipe"""
     recipe = mock_ditllama5b_8k()
     recipe.model.config = run.Config(DiT7BConfig, max_frames=1)
     recipe.data.model_config = recipe.model.config
6 changes: 3 additions & 3 deletions nemo/collections/multimodal/data/__init__.py
@@ -14,7 +14,7 @@

 from nemo.utils.import_utils import safe_import_from

-SimpleMultiModalDataModule, _ = safe_import_from(
-    "nemo.collections.multimodal.data.energon", "SimpleMultiModalDataModule"
+EnergonMultiModalDataModule, _ = safe_import_from(
+    "nemo.collections.multimodal.data.energon", "EnergonMultiModalDataModule"
 )
-__all__ = ["SimpleMultiModalDataModule"]
+__all__ = ["EnergonMultiModalDataModule"]
4 changes: 2 additions & 2 deletions nemo/collections/multimodal/data/energon/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.


-from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
 from nemo.collections.multimodal.data.energon.config import (
     ImageTextSample,
     ImageToken,
@@ -28,7 +28,7 @@
 )

 __all__ = [
-    "SimpleMultiModalDataModule",
+    "EnergonMultiModalDataModule",
     "ImageToken",
     "ImageTextSample",
     "MultiModalSampleConfig",
27 changes: 17 additions & 10 deletions nemo/collections/multimodal/data/energon/base.py
@@ -30,7 +30,7 @@
 from nemo.utils import logging


-class SimpleMultiModalDataModule(pl.LightningDataModule, IOMixin):
+class EnergonMultiModalDataModule(pl.LightningDataModule, IOMixin):
     """
     A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
@@ -70,7 +70,7 @@ def __init__(
         decoder_seq_length: Optional[int] = None,
     ) -> None:
         """
-        Initialize the SimpleMultiModalDataModule.
+        Initialize the EnergonMultiModalDataModule.
         Parameters:
         path (str): Path to the dataset.
@@ -80,8 +80,10 @@ def __init__(
         micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
         num_workers (int, optional): Number of workers for data loading. Defaults to 1.
         pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
-        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. Defaults to MultiModalSampleConfig().
-        task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
+        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
+            Defaults to MultiModalSampleConfig().
+        task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
+            If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
         """

         super().__init__()
@@ -113,7 +115,7 @@ def __init__(
         self.val_dataloader_object = None

     def io_init(self, **kwargs) -> fdl.Config[Self]:
-        # (pleasefixme) image_processor and task_encoder are problematic with Fiddle so we skip serializing them for now
+
         cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']}

         for val in cfg_kwargs.values():
@@ -168,15 +170,17 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
             return self.train_dataloader_object
         if not parallel_state.is_initialized():
             logging.info(
-                f"Muiltimodal data loader parallel state is not initialized, using default worker config with no_workers {self.num_workers}"
+                f"Muiltimodal data loader parallel state is not initialized,"
+                f"using default worker config with no_workers {self.num_workers}"
             )
             worker_config = WorkerConfig.default_worker_config(self.num_workers)
         else:
             rank = parallel_state.get_data_parallel_rank()
             world_size = parallel_state.get_data_parallel_world_size()
             data_parallel_group = parallel_state.get_data_parallel_group()
             logging.info(
-                f" Multimodal train dataloader initializing with rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** "
+                f" Multimodal train dataloader initializing with"
+                f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** "
             )
             worker_config = WorkerConfig(
                 rank=rank,
@@ -206,7 +210,7 @@ def val_dataloader(self) -> EVAL_DATALOADERS:

         if not parallel_state.is_initialized():
             logging.info(
-                f"Muiltimodal val data loader parallel state is not initialized, using default worker config with no_workers {self.num_workers}"
+                f"Muiltimodal val data loader parallel state is not initialized,"
+                "using default worker config with no_workers {self.num_workers}"
             )
             worker_config = WorkerConfig.default_worker_config(self.num_workers)
         else:
@@ -276,7 +281,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
         if not 'dataloader_state' in state_dict:
             logging.warning(
-                f"Data loader state cannot be resumed from state_dict, it does not have the required key dataloader_state. It has {state_dict.keys()}"
+                f"Data loader state cannot be resumed from state_dict,"
+                f"it does not have the required key dataloader_state. It has {state_dict.keys()}"
             )
             return
@@ -288,7 +294,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         else:
             logging.error(f"Cannot restore state from state_dict {state_dict}")
             raise ValueError(
-                f"Cannot restore state from state_dict: Is the trainer object is initialized and attached to datamodule???"
+                f"Cannot restore state from state_dict: "
+                f"Is the trainer object is initialized and attached to datamodule???"
             )
         except Exception as e:
             raise RuntimeError(f"Failed to dataloader restore state due to: {e}")
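
The load_state_dict hunks above show the resume contract: the incoming dict must carry a 'dataloader_state' key, and the module must already be attached to an initialized trainer, otherwise it warns and returns, or raises ValueError. A hedged sketch of a resume call (the checkpoint file and its layout are assumptions; only the required key comes from the diff):

import torch

from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule

# Placeholder construction; in practice this module is attached to a
# Lightning Trainer before restoring, per the ValueError branch above.
data_module = EnergonMultiModalDataModule(path="/path/to/energon/dataset")

state = torch.load("datamodule_state.pt")  # hypothetical file saved earlier
data_module.load_state_dict({"dataloader_state": state["dataloader_state"]})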
4 changes: 2 additions & 2 deletions scripts/vlm/llava_next_finetune.py
@@ -49,7 +49,7 @@ def main(args):
     from transformers import AutoProcessor

     from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-    from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
+    from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
     from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
     from nemo.collections.vlm import LlavaNextTaskEncoder
@@ -65,7 +65,7 @@ def main(args):
         image_processor=processor.image_processor,
         multimodal_sample_config=multimodal_sample_config,
     )
-    data = SimpleMultiModalDataModule(
+    data = EnergonMultiModalDataModule(
         path=data_path,
         tokenizer=tokenizer,
         image_processor=processor.image_processor,
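
Condensing the updated finetune script, construction after the rename looks roughly like this. The tokenizer argument to LlavaNextTaskEncoder and the task_encoder keyword are not visible in the truncated hunks and are assumptions; the model name and dataset path are placeholders:

from transformers import AutoProcessor

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
from nemo.collections.vlm import LlavaNextTaskEncoder

model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"  # placeholder model name
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer(model_id)

multimodal_sample_config = MultiModalSampleConfig()
task_encoder = LlavaNextTaskEncoder(
    tokenizer=tokenizer,  # assumption: mirrors the data module's tokenizer
    image_processor=processor.image_processor,
    multimodal_sample_config=multimodal_sample_config,
)
data = EnergonMultiModalDataModule(
    path="/path/to/energon/dataset",  # placeholder dataset path
    tokenizer=tokenizer,
    image_processor=processor.image_processor,
    multimodal_sample_config=multimodal_sample_config,
    task_encoder=task_encoder,  # assumption: passed through as in the script
)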
4 changes: 2 additions & 2 deletions scripts/vlm/llava_next_pretrain.py
@@ -49,7 +49,7 @@ def main(args):
     from transformers import AutoProcessor

     from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-    from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
+    from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
     from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
     from nemo.collections.vlm import LlavaNextTaskEncoder
@@ -67,7 +67,7 @@ def main(args):
         image_processor=processor.image_processor,
         multimodal_sample_config=multimodal_sample_config,
     )
-    data = SimpleMultiModalDataModule(
+    data = EnergonMultiModalDataModule(
         path=data_path,
         tokenizer=tokenizer,
         image_processor=processor.image_processor,
6 changes: 3 additions & 3 deletions tests/collections/multimodal/data/energon/test_data_module.py
@@ -25,10 +25,10 @@
 from PIL import Image
 from transformers import AutoProcessor

-from nemo.collections.multimodal.data.energon import ImageToken, MultiModalSampleConfig, SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule, ImageToken, MultiModalSampleConfig


-class TestSimpleMultiModalDataModuleWithDummyData(unittest.TestCase):
+class TestEnergonMultiModalDataModuleWithDummyData(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
@@ -47,7 +47,7 @@ def setUp(self):

         self.create_vqa_test_dataset(self.dataset_path, 10)

-        self.data_module = SimpleMultiModalDataModule(
+        self.data_module = EnergonMultiModalDataModule(
             path=str(self.dataset_path),
             tokenizer=self.tokenizer,
             image_processor=self.image_processor,
