fix: Changes to support current implementation
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU authored and dushyantbehl committed Nov 22, 2024
1 parent 5e6fbf1 commit 0469db1
Showing 4 changed files with 123 additions and 4 deletions.
116 changes: 116 additions & 0 deletions tests/utils/test_preprocessing_utils.py
@@ -19,6 +19,7 @@

# Local
from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs
from tuning.utils.preprocessing_utils import (
combine_sequence,
format_dataset,
@@ -441,6 +442,78 @@ def test_format_dataset(data_args):
assert dataset_text_field in eval_set.column_names


@pytest.mark.parametrize(
"data_args",
[
# single sequence JSON and response template
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
response_template="\n### Label:",
)
),
# single sequence JSONL and response template
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
dataset_text_field="output",
response_template="\n### Label:",
)
),
# data formatter template with input/output JSON
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
dataset_text_field="formatted_field",
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
)
),
# data formatter template with input/output JSONL
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
dataset_text_field="formatted_field",
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
)
),
# input/output JSON with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
)
),
# input/output JSONL with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
)
),
],
)
def test_process_dataargs(data_args):
"""Ensure that the train/eval data are properly formatted based on the data args / text field"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_set, eval_set, dataset_text_field = process_dataargs(
data_args, tokenizer, max_seq_length=1024
)
assert isinstance(train_set, Dataset)
assert isinstance(eval_set, Dataset)
if dataset_text_field is None:
column_names = set(["input_ids", "attention_mask", "labels"])
assert set(eval_set.column_names) == column_names
assert set(train_set.column_names) == column_names
else:
assert dataset_text_field in train_set.column_names
assert dataset_text_field in eval_set.column_names


@pytest.mark.parametrize(
"data_args",
[
@@ -482,3 +555,46 @@ def test_format_dataset_pretokenized(data_args):
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
if eval_set:
assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))


@pytest.mark.parametrize(
"data_args",
[
# JSON pretokenized train and validation datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSON,
validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSON,
)
),
# JSONL pretokenized train and validation datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
),
# JSON pretokenized train datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSON,
)
),
# JSONL pretokenized train datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
),
],
)
def test_process_dataargs_pretokenized(data_args):
"""Ensure that pretokenized datasets are loaded and returned as is"""
train_set, eval_set, _ = process_dataargs(data_args, None, max_seq_length=1024)
assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
if eval_set:
assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
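
The new tests above can be run on their own with pytest's keyword filter, for example:

    pytest tests/utils/test_preprocessing_utils.py -k "process_dataargs" -q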
7 changes: 5 additions & 2 deletions tuning/data/data_handlers.py
@@ -38,8 +38,11 @@ def tokenize_and_apply_input_masking(
# TODO: Eventually move the code here
combined = combine_sequence(input, output, eos_token=tokenizer.eos_token)

tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
tokenized_input = tokenizer(input, **tokenizer_kwargs)
fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {})

tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs)
tokenized_input = tokenizer(input, **tokenizer_inner_kwargs)

masked_labels = [-100] * len(
tokenized_input.input_ids
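
For context, a minimal sketch of the kwargs shape the updated handler reads. The caller-side structure here is an assumption for illustration, not part of this commit: tokenizer options are now nested under fn_kwargs, mirroring how datasets.Dataset.map forwards fn_kwargs to the mapped function.

    # Hypothetical caller-side kwargs (assumed shape, not shown in this commit).
    tokenizer_kwargs = {
        "fn_kwargs": {
            "tokenizer_kwargs": {"truncation": True, "max_length": 1024},
        }
    }
    fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
    tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {})
    # -> {"truncation": True, "max_length": 1024}, which is what the handler
    #    above forwards via tokenizer(combined, **tokenizer_inner_kwargs).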
2 changes: 1 addition & 1 deletion tuning/data/data_processors.py
@@ -157,7 +157,7 @@ def _process_dataset_configs(

kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)

logging.info("Applying Handler : {data_handler} Args : {kwargs}")
logging.info(f"Applying Handler : {data_handler} Args : {kwargs}")

raw_datasets = raw_datasets.map(handler, **kwargs)

2 changes: 1 addition & 1 deletion tuning/data/setup_dataprocessor.py
@@ -118,7 +118,7 @@ def process_dataargs(
kwargs = {
"fn_kwargs": fn_kwargs,
"batched": False,
"remove_columns": [JSON_INPUT_KEY, JSON_OUTPUT_KEY],
"remove_columns": "all",
}

handler = DataHandlerConfig(
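
Since datasets.Dataset.map expects concrete column names for remove_columns, a "remove_columns": "all" value presumably gets resolved by the data processor before mapping. A hedged sketch of that resolution, assuming the Hugging Face datasets API; the helper below is hypothetical, not from this commit:

    from datasets import Dataset


    def apply_handler(dataset: Dataset, handler, **kwargs):
        # Assumed behavior: expand the "all" sentinel into the dataset's actual
        # column names, since Dataset.map(remove_columns=...) takes names, not a keyword.
        if kwargs.get("remove_columns") == "all":
            kwargs["remove_columns"] = dataset.column_names
        return dataset.map(handler, **kwargs)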
