Commit

test: _process_dataconfig_file
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Nov 25, 2024
1 parent 6cad504 commit 086b8a1
Showing 3 changed files with 87 additions and 3 deletions.
examples/predefined_data_configs/apply_custom_template.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@ datasets:
     data_paths:
       - "FILE_PATH"
     data_handlers:
-      - name: tokenize_and_apply_instruction_masking
+      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
tests/data/test_data_preprocessing_utils.py (86 changes: 85 additions & 1 deletion)

@@ -1,9 +1,14 @@
# Standard
import os
import tempfile

# Third Party
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM
import datasets
import pytest
import yaml

# First Party
from tests.testdata import (
@@ -25,7 +30,23 @@
     validate_data_args,
 )
 from tuning.data.data_processors import get_dataprocessor
-from tuning.data.setup_dataprocessor import is_pretokenized_dataset, process_dataargs
+from tuning.data.setup_dataprocessor import (
+    _process_dataconfig_file,
+    is_pretokenized_dataset,
+    process_dataargs,
+)
+
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+PREDEFINED_DATA_CONFIGS = os.path.join(BASE_DIR, "examples", "predefined_data_configs")
+APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
+)
+PRETOKENIZE_JSON_DATA_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
+)
+TOKENIZE_AND_INSTRUCTION_MASKING_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "tokenize_and_instruction_masking.yaml"
+)


@pytest.mark.parametrize(
@@ -364,6 +385,69 @@ def test_validate_args_pretokenized(data_args, packing):
)


+@pytest.mark.parametrize(
+    "data_config_path, data_path",
+    [
+        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
+        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (
+            TOKENIZE_AND_INSTRUCTION_MASKING_YAML,
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+        ),
+        (
+            TOKENIZE_AND_INSTRUCTION_MASKING_YAML,
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+        ),
+    ],
+)
+def test_process_dataconfig_file(data_config_path, data_path):
+    """Ensure that the train/eval data are properly formatted based on the data args / text field"""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"][0] = data_path
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    packing = False
+    max_seq_length = 1024
+    (train_set, _, _, _, _, _) = _process_dataconfig_file(
+        data_args, tokenizer, packing, max_seq_length
+    )
+    assert isinstance(train_set, Dataset)
+    if datasets_name == "text_dataset_input_output_masking":
+        column_names = set(["input_ids", "attention_mask", "labels"])
+        assert set(train_set.column_names) == column_names
+    elif datasets_name == "pretokenized_dataset":
+        assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    elif datasets_name == "apply_custom_data_template":
+        assert formatted_dataset_field in set(train_set.column_names)

@pytest.mark.parametrize(
    "data_args",
    [
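For readers who want to exercise the new code path outside of pytest, here is a minimal sketch that mirrors the test above. It is illustrative only: the config path and model name are placeholder assumptions, the configs import path is an assumption based on how the test file uses configs.DataArguments, and _process_dataconfig_file is a private helper whose six-tuple return shape is taken from the test.

# Illustrative sketch (not part of the commit): drive a predefined data
# config through _process_dataconfig_file by hand. The YAML path and the
# tiny placeholder model are assumptions for demonstration purposes.
from transformers import AutoTokenizer

from tuning.config import configs
from tuning.data.setup_dataprocessor import _process_dataconfig_file

data_args = configs.DataArguments(
    data_config_path="examples/predefined_data_configs/apply_custom_template.yaml"
)
tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")  # placeholder

# Same calling convention as the test: packing flag and max sequence length.
train_set, *_ = _process_dataconfig_file(data_args, tokenizer, False, 1024)
print(train_set.column_names)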
tuning/data/data_config.py (2 changes: 1 addition & 1 deletion)

@@ -77,7 +77,7 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
     for p in data_paths:
         assert isinstance(p, str), f"path {p} should be of the type string"
         assert os.path.exists(p), f"data_paths {p} does not exist"
-        if not os.isabs(p):
+        if not os.path.isabs(p):
             _p = os.path.abspath(p)
             logging.warning(
                 " Provided path %s is not absolute changing it to %s", p, _p
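This one-line fix is load-bearing: the os module has no isabs attribute, so the old code raised AttributeError the first time a relative path reached this check; the function lives in os.path. A quick sketch of the corrected behavior, with made-up example paths:

# Minimal illustration of the fixed check; the paths are made-up examples.
import os

for p in ("/data/train.jsonl", "data/train.jsonl"):
    if not os.path.isabs(p):   # os.isabs would raise AttributeError
        p = os.path.abspath(p)  # resolves against the current working directory
    print(p)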
