From 7fd7ab1cff3a1dd10a439f2ceee9a0f9636a0114 Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Tue, 4 Jun 2024 22:28:49 +0300 Subject: [PATCH] generalize dataloader --- tests/test_utils.py | 24 +++++++++--------- tsgm/utils/datasets.py | 56 +++++++++++++----------------------------- 2 files changed, 28 insertions(+), 52 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index bdc0562..616633d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -189,7 +189,7 @@ def test_mmd_3_test(): Y = np.random.normal(10, 100, 100)[:, None] Z = np.random.normal(0, 1, 100)[:, None] - # Use custome kernels with this (TF-sklearn compatibility) + # Use custom kernels with this (TF-sklearn compatibility) # sigma_XY = tsgm.utils.kernel_median_heuristic(X, Y); # sigma_XZ = tsgm.utils.kernel_median_heuristic(X, Z); # sigma = (sigma_XY + sigma_XZ) / 2 @@ -201,16 +201,14 @@ def test_mmd_3_test(): @pytest.mark.parametrize("dataset_name", [ - "beef", - "coffee", - "ecg200", - "electric", - "freezer", - "gunpoint", - "insect", - "mixed_shapes", - "starlight", - "wafer" + "Beef", + "Coffee", + "ECG200", + "ElectricDevices", + "GunPoint", + "MixedShapesRegularTrain", + "StarLightCurves", + "Wafer" ]) def test_ucr_loadable(dataset_name): ucr_data_manager = tsgm.utils.UCRDataManager(ds=dataset_name) @@ -222,11 +220,11 @@ def test_ucr_loadable(dataset_name): def test_ucr_raises(): with pytest.raises(ValueError) as excinfo: ucr_data_manager = tsgm.utils.UCRDataManager(ds="does not exist") - assert "ds should be in" in str(excinfo.value) + assert "ds should be listed at UCR website" in str(excinfo.value) def test_get_wafer(): - dataset = "wafer" + dataset = "Wafer" ucr_data_manager = tsgm.utils.UCRDataManager(ds=dataset) assert ucr_data_manager.summary() is None X_train, y_train, X_test, y_test = ucr_data_manager.get() diff --git a/tsgm/utils/datasets.py b/tsgm/utils/datasets.py index 37ef247..0f47639 100644 --- a/tsgm/utils/datasets.py +++ b/tsgm/utils/datasets.py @@ -130,6 +130,16 @@ def gen_sine_vs_const_dataset(N: int, T: int, D: int, max_value: int = 10, const class UCRDataManager: """ A manager for UCR collection of time series datasets. + If you find these datasets useful, please cite: + @misc{UCRArchive2018, + title = {The UCR Time Series Classification Archive}, + author = {Dau, Hoang Anh and Keogh, Eamonn and Kamgar, Kaveh and Yeh, Chin-Chia Michael and Zhu, Yan + and Gharghabi, Shaghayegh and Ratanamahatana, Chotirat Ann and Yanping and Hu, Bing + and Begum, Nurjahan and Bagnall, Anthony and Mueen, Abdullah and Batista, Gustavo, and Hexagon-ML}, + year = {2018}, + month = {October}, + note = {\\url{https://www.cs.ucr.edu/~eamonn/time_series_data_2018/}} + } """ mirrors = ["https://www.cs.ucr.edu/~eamonn/time_series_data_2018/"] resources = [("UCRArchive_2018.zip", 0)] @@ -140,7 +150,7 @@ def __init__(self, path: str = default_path, ds: str = "gunpoint") -> None: """ :param path: a relative path to the stored UCR dataset. :type path: str - :param ds: Name of the dataset. Should be in (beef | coffee | ecg200 | freezer | gunpoint | insect | mixed_shapes | starlight). + :param ds: Name of the dataset. The list of names is available at https://www.cs.ucr.edu/~eamonn/time_series_data_2018/ (case sensitive!). :type ds: str :raises ValueError: When there is no stored UCR archive, or the name of the dataset is incorrect. @@ -150,48 +160,16 @@ def __init__(self, path: str = default_path, ds: str = "gunpoint") -> None: self.ds = ds.strip().lower() self.y_all: T.Optional[T.Collection[T.Hashable]] = None + path = os.path.join(path, ds) + train_files = glob.glob(os.path.join(path, "*TRAIN.tsv")) - if ds == "beef": - self.regular_train_path = os.path.join(path, "Beef") - self.small_train_path = os.path.join(path, "Beef") - elif ds == "coffee": - self.regular_train_path = os.path.join(path, "Coffee") - self.small_train_path = os.path.join(path, "Coffee") - elif ds == "ecg200": - self.regular_train_path = os.path.join(path, "ECG200") - self.small_train_path = os.path.join(path, "ECG200") - elif ds == "electric": - self.regular_train_path = os.path.join(path, "ElectricDevices") - self.small_train_path = os.path.join(path, "ElectricDevices") - elif ds == "freezer": - self.regular_train_path = os.path.join(path, "FreezerRegularTrain") - self.small_train_path = os.path.join(path, "FreezerSmallTrain") - elif ds == "gunpoint": - self.regular_train_path = os.path.join(path, "GunPoint") - self.small_train_path = os.path.join(path, "GunPoint") - elif ds == "insect": - self.regular_train_path = os.path.join(path, "InsectEPGRegularTrain") - self.small_train_path = os.path.join(path, path, "InsectEPGSmallTrain") - elif ds == "mixed_shapes": - self.regular_train_path = os.path.join(path, path, "MixedShapesRegularTrain") - self.small_train_path = os.path.join(path, path, "MixedShapesSmallTrain") - elif ds == "starlight": - self.regular_train_path = os.path.join(path, path, "StarLightCurves") - self.small_train_path = os.path.join(path, path, "StarLightCurves") - elif ds == "wafer": - self.regular_train_path = os.path.join(path, path, "Wafer") - self.small_train_path = os.path.join(path, path, "Wafer") - else: - raise ValueError("ds should be in (beef | coffee | ecg200 | freezer | gunpoint | insect | mixed_shapes | starlight)") - - self.small_train_df = pd.read_csv( - glob.glob(os.path.join(self.small_train_path, "*TRAIN.tsv"))[0], - sep='\t', header=None) + if len(train_files) == 0: + raise ValueError("ds should be listed at UCR website") self.train_df = pd.read_csv( - glob.glob(os.path.join(self.regular_train_path, "*TRAIN.tsv"))[0], + glob.glob(os.path.join(path, "*TRAIN.tsv"))[0], sep='\t', header=None) self.test_df = pd.read_csv( - glob.glob(os.path.join(self.regular_train_path, "*TEST.tsv"))[0], + glob.glob(os.path.join(path, "*TEST.tsv"))[0], sep='\t', header=None) self.X_train, self.y_train = self.train_df[self.train_df.columns[1:]].to_numpy(), self.train_df[self.train_df.columns[0]].to_numpy()