RT training - option to save dataset (#170)
jannisborn authored Dec 5, 2022
1 parent 4732d93 commit cc23dc8
Showing 4 changed files with 35 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/gt4sd/properties/tests/test_properties.py
@@ -50,7 +50,7 @@
     "molecular_weight": 186.05199999999996,
     "lipinski": 1,
     "esol": -2.6649954522215555,
-    "scscore": 4.391860681753299,
+    "scscore": 1.6081393182467014,
     "sas": 1.6564993918409403,
     "bertz": 197.86256853719752,
     "tpsa": 26.02,
7 changes: 7 additions & 0 deletions src/gt4sd/training_pipelines/regression_transformer/core.py
@@ -161,6 +161,13 @@ class RegressionTransformerDataArguments(TrainingPipelineArguments):
             "help": "Whether lines of text in the dataset are to be handled as distinct samples."
         },
     )
+    save_datasets: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to save the datasets to disk. Datasets will be saved as `.txt` file to "
+            "the same location where `train_data_path` and `test_data_path` live. Defaults to False."
+        },
+    )
 
 
 @dataclass
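The new field follows the `dataclasses.field` pattern of the surrounding argument classes. Since utils.py (below) imports `string_to_bool` from `transformers.hf_argparser`, these arguments are presumably consumed through Hugging Face's dataclass argument parser, in which case a boolean flag like this parses from the command line roughly as in the following sketch; `DataArgs` here is an illustrative stand-in, not the gt4sd class:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser  # assumed entry point for these argument classes


@dataclass
class DataArgs:
    # Illustrative stand-in for RegressionTransformerDataArguments above.
    save_datasets: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to save the datasets to disk."},
    )


parser = HfArgumentParser(DataArgs)
# Equivalent to passing `--save_datasets true` on the command line:
(args,) = parser.parse_args_into_dataclasses(["--save_datasets", "true"])
assert args.save_datasets is True
```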
29 changes: 23 additions & 6 deletions src/gt4sd/training_pipelines/regression_transformer/implementation.py
@@ -25,6 +25,7 @@
 import json
 import logging
 import os
+import shutil
 import tempfile
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
@@ -268,6 +269,7 @@ def setup_dataset(
         test_data_path: str,
         line_by_line: Optional[bool],
         augment: int = 0,
+        save_datasets: bool = False,
         *args,
         **kwargs,
     ):
@@ -282,6 +284,8 @@
                 at least one column of numerical properties.
             line_by_line: Whether the data can be read line-by-line from disk.
             augment: How many times each training sample is augmented.
+            save_datasets: Whether to save the datasets to disk (will be stored in
+                same location as `train_data_path` and `test_data_path`).
         """
 
         logger.info("Preparing/reading data...")
@@ -290,20 +294,31 @@
             self.tokenizer, train_data_path, test_data_path, augment=augment
         )
         self.tokenizer, self.properties = tokenizer, properties
-        datasets = [
-            self.create_dataset_from_list(data) for data in [train_data, test_data]
-        ]
 
+        train_dataset = self.create_dataset_from_list(
+            train_data,
+            save_path=train_data_path.replace(".csv", ".txt")
+            if save_datasets
+            else None,
+        )
+        test_dataset = self.create_dataset_from_list(
+            test_data,
+            save_path=test_data_path.replace(".csv", ".txt") if save_datasets else None,
+        )
         logger.info("Finished data setup.")
-        return datasets
+        return [train_dataset, test_dataset]
 
-    def create_dataset_from_list(self, data: List[str]) -> LineByLineTextDataset:
+    def create_dataset_from_list(
+        self, data: List[str], save_path: Optional[str] = None
+    ) -> LineByLineTextDataset:
         """
         Creates a LineByLineTextDataset from a List of strings.
         Args:
             data: List of strings with the samples.
+            save_path: Path to save the dataset to. Defaults to None, meaning
+                the dataset will not be saved.
         """
 
         # Write files to temporary location and create data
         with tempfile.TemporaryDirectory() as temp:
             f_name = os.path.join(temp, "tmp_data.txt")
@@ -317,6 +332,8 @@ def create_dataset_from_list(self, data: List[str]) -> LineByLineTextDataset:
             dataset = LineByLineTextDataset(
                 file_path=f_name, tokenizer=self.tokenizer, block_size=2**64
             )
+            if save_path:
+                shutil.copyfile(f_name, save_path)
             return dataset
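The persistence mechanism is minimal: the samples are already written to a temporary file in order to build the `LineByLineTextDataset`, so saving them is just a conditional `shutil.copyfile` before the temporary directory is torn down. A self-contained sketch of that pattern (the `build_dataset_text` helper is hypothetical, not the gt4sd method):

```python
import os
import shutil
import tempfile
from typing import List, Optional


def build_dataset_text(data: List[str], save_path: Optional[str] = None) -> str:
    """Write samples to a temp file, optionally keep a copy, return the contents."""
    with tempfile.TemporaryDirectory() as temp:
        f_name = os.path.join(temp, "tmp_data.txt")
        with open(f_name, "w") as f:
            f.write("\n".join(data))
        if save_path:
            # The copy must happen inside the `with` block: the temp dir
            # (and f_name with it) is deleted when the context manager exits.
            shutil.copyfile(f_name, save_path)
        with open(f_name) as f:
            return f.read()


print(build_dataset_text(["<qed>0.404|CCO", "<qed>0.385|CCN"]))
```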
7 changes: 4 additions & 3 deletions src/gt4sd/training_pipelines/regression_transformer/utils.py
@@ -29,6 +29,7 @@
 import pandas as pd
 from pytoda.smiles.transforms import Augment
 from pytoda.transforms import AugmentByReversing
+from sklearn.utils import shuffle
 from terminator.selfies import encoder
 from terminator.tokenization import ExpressionBertTokenizer
 from transformers.hf_argparser import string_to_bool
@@ -126,7 +127,7 @@ def prepare_datasets_from_files(
             raise TypeError(f"Please provide a csv file not {path}.")
 
         # Load data
-        df = pd.read_csv(path)
+        df = shuffle(pd.read_csv(path))
         if "text" not in df.columns:
             raise ValueError("Please provide text in the `text` column.")
 
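For reference, `sklearn.utils.shuffle` accepts a DataFrame directly and returns a row-permuted copy whose original index labels travel with the rows, so the `iterrows()` calls below see the samples in random order. A quick check on toy data (the commit itself passes no `random_state`):

```python
import pandas as pd
from sklearn.utils import shuffle

df = pd.DataFrame({"text": ["CCO", "CCN", "CCC"], "qed": [0.40, 0.39, 0.38]})
shuffled = shuffle(df, random_state=0)  # random_state only for a reproducible demo
print(shuffled.index.tolist())  # a permutation such as [2, 0, 1]
```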
@@ -139,7 +140,7 @@
     properties.remove("text")
 
         # Parse data and create RT-compatible format
-        for i, row in df.iterrows():
+        for j, row in df.iterrows():
            line = "".join(
                [
                    f"<{p}>{row[p]:.3f}{tokenizer.expression_separator}"
@@ -152,7 +153,7 @@
         # Perform augmentation on training data if applicable
         if i == 0 and augment is not None and augment > 1:
             for _ in range(augment):
-                for i, row in df.iterrows():
+                for j, row in df.iterrows():
                     line = "".join(
                         [
                             f"<{p}>{row[p]:.3f}{tokenizer.expression_separator}"
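The `i` → `j` rename is the substantive fix in this file: both `iterrows()` loops were shadowing what is presumably an outer split index `i`, so by the time the `if i == 0` guard (meant to restrict augmentation to the first, i.e. training, file) ran, `i` held the last row index instead of the split index; with the shuffle added above, that stale row index would even vary from run to run. A minimal reproduction of the failure mode (names hypothetical, not the gt4sd code):

```python
splits = ["train", "test"]
rows = [10, 20, 30]

# Buggy: the inner loop variable clobbers the outer index.
for i, split in enumerate(splits):
    for i, row in enumerate(rows):
        pass
    print(split, i == 0)  # train False, test False -- `i` is now 2, the last row index

# Fixed: a distinct inner variable leaves the split index intact.
for i, split in enumerate(splits):
    for j, row in enumerate(rows):
        pass
    print(split, i == 0)  # train True, test False
```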
