From 8fe6533f553ae7993ed2795e2866a0dbda547d9a Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Fri, 3 Jan 2025 11:45:37 -0500 Subject: [PATCH] minor changes --- nemo/collections/llm/bert/data/fine_tuning.py | 2 +- nemo/collections/llm/bert/loss.py | 4 ++-- nemo/collections/llm/bert/model/embedding.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/llm/bert/data/fine_tuning.py b/nemo/collections/llm/bert/data/fine_tuning.py index 4316ea35b67c..0edc9862f4f4 100644 --- a/nemo/collections/llm/bert/data/fine_tuning.py +++ b/nemo/collections/llm/bert/data/fine_tuning.py @@ -31,7 +31,7 @@ class FineTuningDataModule(pl.LightningDataModule): - """Base class for fine-tuning an LLM. + """Base class for fine-tuning an Bert. This class provides a foundation for building custom data modules for fine-tuning Nemo NLP models. It inherits from `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch diff --git a/nemo/collections/llm/bert/loss.py b/nemo/collections/llm/bert/loss.py index ca2e3ff87ad1..400d97418124 100644 --- a/nemo/collections/llm/bert/loss.py +++ b/nemo/collections/llm/bert/loss.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Literal import torch import torch.nn.functional as F @@ -118,7 +118,7 @@ def __init__( scale: float = 20, label_smoothing: float = 0.0, global_in_batch_negatives: bool = False, - backprop_type: str = 'local', + backprop_type: Literal["local", "global"] = 'local', ) -> None: super().__init__() self.validation_step = validation_step diff --git a/nemo/collections/llm/bert/model/embedding.py b/nemo/collections/llm/bert/model/embedding.py index de99624451de..47ba64002f4e 100644 --- a/nemo/collections/llm/bert/model/embedding.py +++ b/nemo/collections/llm/bert/model/embedding.py @@ -13,7 +13,7 @@ # limitations under the License. import sys from dataclasses import dataclass -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Literal import lightning.pytorch as L import torch @@ -78,7 +78,7 @@ def bert_embedding_forward_step(model: L.LightningModule, batch: Dict[str, torch class BertEmbeddingConfig(BertConfig): """Bert Embedding Config""" - bert_type: str = 'huggingface' + bert_type: Literal["huggingface", "megatron"] = 'huggingface' ce_loss_scale: float = 20 label_smoothing: float = 0.0 add_lm_head: bool = False @@ -86,7 +86,7 @@ class BertEmbeddingConfig(BertConfig): num_hard_negatives: int = 1 num_tokentypes: int = 2 global_in_batch_negatives: bool = True - backprop_type: str = 'local' + backprop_type: Literal["local", "global"] = 'local' forward_step_fn: Callable = bert_embedding_forward_step data_step_fn: Callable = bert_embedding_data_step