diff --git a/basicts/data/__init__.py b/basicts/data/__init__.py index f8af8648..5c6e48c1 100644 --- a/basicts/data/__init__.py +++ b/basicts/data/__init__.py @@ -3,7 +3,7 @@ from easytorch.utils.registry import scan_modules from .registry import SCALER_REGISTRY -from .dataset import TimeSeriesForecastingDataset +from .dataset_zoo.simple_tsf_dataset import TimeSeriesForecastingDataset __all__ = ["SCALER_REGISTRY", "TimeSeriesForecastingDataset"] diff --git a/basicts/data/dataset.py b/basicts/data/dataset_zoo/simple_tsf_dataset.py similarity index 98% rename from basicts/data/dataset.py rename to basicts/data/dataset_zoo/simple_tsf_dataset.py index 42731a2c..2fa38d7a 100644 --- a/basicts/data/dataset.py +++ b/basicts/data/dataset_zoo/simple_tsf_dataset.py @@ -3,7 +3,7 @@ import torch from torch.utils.data import Dataset -from ..utils import load_pkl +from ...utils import load_pkl class TimeSeriesForecastingDataset(Dataset): diff --git a/basicts/runners/base_tsf_runner.py b/basicts/runners/base_tsf_runner.py index 9ca5168b..8c376d6f 100644 --- a/basicts/runners/base_tsf_runner.py +++ b/basicts/runners/base_tsf_runner.py @@ -1,4 +1,5 @@ import math +import inspect import functools from typing import Tuple, Union, Optional, Dict @@ -300,16 +301,19 @@ def metric_forward(self, metric_func, args) -> torch.Tensor: Returns: torch.Tensor: metric value. """ - # filter out keys that are not in function arguments - args = {k: v for k, v in args.items() if k in metric_func.__code__.co_varnames} + covariate_names = inspect.signature(metric_func).parameters.keys() + args = {k: v for k, v in args.items() if k in covariate_names} if isinstance(metric_func, functools.partial): # support partial function # users can define their partial function in the config file # e.g., functools.partial(masked_mase, freq="4", null_val=np.nan) + if "null_val" in covariate_names and "null_val" not in metric_func.keywords: # if null_val is required but not provided + args["null_val"] = self.null_val metric_item = metric_func(**args) elif callable(metric_func): # is a function + # filter out keys that are not in function arguments metric_item = metric_func(**args, null_val=self.null_val) else: raise TypeError("Unknown metric type: {0}".format(type(metric_func))) diff --git a/basicts/utils/__init__.py b/basicts/utils/__init__.py index e5141a90..ab892012 100644 --- a/basicts/utils/__init__.py +++ b/basicts/utils/__init__.py @@ -1,5 +1,9 @@ from .serialization import load_adj, load_pkl, dump_pkl, load_node2vec_emb from .misc import clock, check_nan_inf, remove_nan_inf +from .misc import partial_func as partial from .xformer import data_transformation_4_xformer -__all__ = ["load_adj", "load_pkl", "dump_pkl", "load_node2vec_emb", "clock", "check_nan_inf", "remove_nan_inf", "data_transformation_4_xformer"] +__all__ = ["load_adj", "load_pkl", "dump_pkl", + "load_node2vec_emb", "clock", "check_nan_inf", + "remove_nan_inf", "data_transformation_4_xformer", + "partial"] diff --git a/basicts/utils/misc.py b/basicts/utils/misc.py index 35387c53..2abf9a89 100644 --- a/basicts/utils/misc.py +++ b/basicts/utils/misc.py @@ -1,8 +1,17 @@ import time +from functools import partial import torch +class partial_func(partial): + """partial class. + __str__ in functools.partial contains the address of the function, which changes randomly and will disrupt easytorch's md5 calculation. + """ + + def __str__(self): + return "partial({}, {})".format(self.func.__name__, self.keywords) + def clock(func): """clock decorator""" def clocked(*args, **kw):