docs: add doc strings
lllAlexanderlll committed Aug 13, 2024
1 parent 4f53f0c commit ed50d2f
Showing 4 changed files with 63 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/modalities/dataloader/apply_chat_template.py
@@ -14,6 +14,18 @@


def apply_chat_template(config_file_path: Path):
"""
Applies a chat template to the given configuration file.
Args:
config_file_path (Path): The path to the configuration file.
Returns:
None
Raises:
None
"""
config_dict = load_app_config_dict(config_file_path=config_file_path)
config = InstructionTuningInstantiationModel(**config_dict)
instruction_data = _stream_jsonl(config.settings.src_path)
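For context, a minimal invocation sketch; the config path below is a placeholder for illustration, not a file from the repository:

from pathlib import Path

from modalities.dataloader.apply_chat_template import apply_chat_template

# Hypothetical config path.
apply_chat_template(config_file_path=Path("configs/instruction_tuning_config.yaml"))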
26 changes: 26 additions & 0 deletions src/modalities/dataloader/collate_fns/collator_wrapper.py
@@ -64,6 +64,18 @@ def __init__(
)

def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch:
"""
Collates a batch of data using the wrapped collate function, then applies target masking.
Args:
batch (List[Dict[str, torch.Tensor]]): A list of dictionaries, where each dictionary represents a sample
in the batch. Each dictionary contains keys corresponding to different data modalities and their
respective tensors.
Returns:
DatasetBatch: A batch of collated data with masked targets.
"""
dataset_batch = self.wrapped_collate_fn(batch)
for target_key_to_mask in self.target_keys_to_mask:
target = dataset_batch.targets[target_key_to_mask]
@@ -96,6 +108,20 @@ def _mask_target(
cumsum would include the begin mask token. Example without shift:
mask_no_shift_2 [0,1,0,0,-1,0,0,0]
cumsum_no_shift [0,1,1,1, 0,0,0,0]
Args:
target (torch.Tensor): The target tensor to be masked.
b_mask_token_id (int): The token ID indicating the beginning of the mask.
e_mask_token_id (int): The token ID indicating the end of the mask.
loss_ignore_index (int): The index to replace masked tokens with.
Returns:
torch.Tensor: The masked target tensor.
Raises:
ValueError: If the b_mask_token_id or e_mask_token_id is not found in the target tensor.
ValueError: If the end mask token indicator is before the begin mask token indicator in the target tensor.
ValueError: If the masking tokens are not alternating in the target tensor.
"""
error_msg = ""
if b_mask_token_id not in target:
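To make the cumsum logic described in the docstring concrete, here is a minimal sketch with hypothetical token IDs; the real _mask_target also validates marker ordering and alternation, and may treat the end marker differently:

import torch

# Hypothetical IDs for illustration; real values come from the tokenizer setup.
b_mask_token_id, e_mask_token_id, loss_ignore_index = 101, 102, -100
target = torch.tensor([7, 101, 5, 6, 102, 8, 9, 3])

# +1 at begin markers, -1 at end markers; the running sum marks the span.
mask = torch.zeros_like(target)
mask[target == b_mask_token_id] = 1
mask[target == e_mask_token_id] = -1
mask.tolist()             # [0, 1, 0, 0, -1, 0, 0, 0]
mask.cumsum(-1).tolist()  # [0, 1, 1, 1, 0, 0, 0, 0] -> would keep the begin marker

# Shifting the mask one position to the right excludes the begin marker.
shifted = torch.roll(mask, shifts=1, dims=-1)
shifted[0] = 0  # clear any value wrapped around by roll
keep = shifted.cumsum(-1).bool()
masked_target = target.masked_fill(~keep, loss_ignore_index)
masked_target.tolist()    # [-100, -100, 5, 6, 102, -100, -100, -100]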
11 changes: 11 additions & 0 deletions src/modalities/dataloader/dataset.py
@@ -181,6 +181,17 @@ def __getitem__(self, idx: int) -> BatchEncoding:

class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase):
def __init__(self, raw_data_path: Path, sample_key: str, block_size: int, reuse_last_target: bool = True):
"""
Initializes a Dataset object. In case `reuse_last_target` is True,
we reuse the last target token as the first one for the next sample. If `reuse_last_target` is False,
we don't reuse the last target in the next sample but never have the the first token of a sample as the target.
Args:
raw_data_path (Path): The path to the raw data.
sample_key (str): The key to access the sample data.
block_size (int): The size of each data block.
reuse_last_target (bool, optional): Whether to reuse the last target. Defaults to True.
"""
self.block_size = block_size
self.reuse_last_target = reuse_last_target
super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)
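A pure-Python sketch of the two modes on a toy token stream (the actual class slices blocks out of a memory-mapped file):

# Hypothetical packed token stream.
tokens = list(range(10))
block_size = 4

# reuse_last_target=True: blocks overlap by one token, so the last target
# of block i doubles as the first input token of block i+1.
overlapping = [tokens[i : i + block_size] for i in range(0, len(tokens) - block_size + 1, block_size - 1)]
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]

# reuse_last_target=False: blocks are disjoint, so the first token of a
# block never appears as a target.
disjoint = [tokens[i : i + block_size] for i in range(0, len(tokens) - block_size + 1, block_size)]
# [[0, 1, 2, 3], [4, 5, 6, 7]]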
14 changes: 14 additions & 0 deletions src/modalities/dataloader/dataset_factory.py
@@ -58,6 +58,20 @@ def get_mem_map_dataset(
def get_packed_mem_map_dataset_continuous(
raw_data_path: Path, sequence_length: int, sample_key: str, reuse_last_target: Optional[bool] = True
) -> PackedMemMapDatasetContinuous:
"""
Initializes a Dataset object. In case `reuse_last_target` is True,
we reuse the last target token as the first one for the next sample. If `reuse_last_target` is False,
we don't reuse the last target in the next sample but never have the the first token of a sample as the target.
Args:
raw_data_path (Path): The path to the raw data.
sequence_length (int): The length of each sequence.
sample_key (str): The key to access the sample data.
reuse_last_target (Optional[bool], optional): Whether to reuse the last target. Defaults to True.
Returns:
PackedMemMapDatasetContinuous: The created dataset object.
"""
dataset = PackedMemMapDatasetContinuous(
raw_data_path=raw_data_path,
block_size=sequence_length + 1,
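The sequence_length + 1 above reflects the usual next-token shift: a block of N + 1 tokens yields exactly N input/target pairs. A minimal sketch, not the library's collation code:

sequence_length = 4
block = list(range(sequence_length + 1))  # [0, 1, 2, 3, 4]
inputs, targets = block[:-1], block[1:]   # [0, 1, 2, 3] / [1, 2, 3, 4]
assert len(inputs) == len(targets) == sequence_length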
