Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
suiyoubi committed Jan 3, 2025
1 parent 5122ded commit 8fe6533
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/bert/data/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class FineTuningDataModule(pl.LightningDataModule):
"""Base class for fine-tuning an LLM.
"""Base class for fine-tuning an Bert.
This class provides a foundation for building custom data modules for fine-tuning Nemo NLP models. It inherits from
`pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/llm/bert/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Literal

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
scale: float = 20,
label_smoothing: float = 0.0,
global_in_batch_negatives: bool = False,
backprop_type: str = 'local',
backprop_type: Literal["local", "global"] = 'local',
) -> None:
super().__init__()
self.validation_step = validation_step
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/llm/bert/model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import sys
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Literal

import lightning.pytorch as L
import torch
Expand Down Expand Up @@ -78,15 +78,15 @@ def bert_embedding_forward_step(model: L.LightningModule, batch: Dict[str, torch
class BertEmbeddingConfig(BertConfig):
"""Bert Embedding Config"""

bert_type: str = 'huggingface'
bert_type: Literal["huggingface", "megatron"] = 'huggingface'
ce_loss_scale: float = 20
label_smoothing: float = 0.0
add_lm_head: bool = False
bert_binary_head: bool = False
num_hard_negatives: int = 1
num_tokentypes: int = 2
global_in_batch_negatives: bool = True
backprop_type: str = 'local'
backprop_type: Literal["local", "global"] = 'local'
forward_step_fn: Callable = bert_embedding_forward_step
data_step_fn: Callable = bert_embedding_data_step

Expand Down

0 comments on commit 8fe6533

Please sign in to comment.