Skip to content

Commit

Permalink
docs: ✏️ support partial function and update metric_forward
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 11, 2023
1 parent 9129716 commit f6d0aeb
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion basicts/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from easytorch.utils.registry import scan_modules

from .registry import SCALER_REGISTRY
from .dataset import TimeSeriesForecastingDataset
from .dataset_zoo.simple_tsf_dataset import TimeSeriesForecastingDataset

__all__ = ["SCALER_REGISTRY", "TimeSeriesForecastingDataset"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.utils.data import Dataset

from ..utils import load_pkl
from ...utils import load_pkl


class TimeSeriesForecastingDataset(Dataset):
Expand Down
8 changes: 6 additions & 2 deletions basicts/runners/base_tsf_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import inspect
import functools
from typing import Tuple, Union, Optional, Dict

Expand Down Expand Up @@ -300,16 +301,19 @@ def metric_forward(self, metric_func, args) -> torch.Tensor:
Returns:
torch.Tensor: metric value.
"""
# filter out keys that are not in function arguments
args = {k: v for k, v in args.items() if k in metric_func.__code__.co_varnames}
covariate_names = inspect.signature(metric_func).parameters.keys()
args = {k: v for k, v in args.items() if k in covariate_names}

if isinstance(metric_func, functools.partial):
# support partial function
# users can define their partial function in the config file
# e.g., functools.partial(masked_mase, freq="4", null_val=np.nan)
if "null_val" in covariate_names and "null_val" not in metric_func.keywords: # if null_val is required but not provided
args["null_val"] = self.null_val
metric_item = metric_func(**args)
elif callable(metric_func):
# is a function
# filter out keys that are not in function arguments
metric_item = metric_func(**args, null_val=self.null_val)
else:
raise TypeError("Unknown metric type: {0}".format(type(metric_func)))
Expand Down
6 changes: 5 additions & 1 deletion basicts/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .serialization import load_adj, load_pkl, dump_pkl, load_node2vec_emb
from .misc import clock, check_nan_inf, remove_nan_inf
from .misc import partial_func as partial
from .xformer import data_transformation_4_xformer

__all__ = ["load_adj", "load_pkl", "dump_pkl", "load_node2vec_emb", "clock", "check_nan_inf", "remove_nan_inf", "data_transformation_4_xformer"]
__all__ = ["load_adj", "load_pkl", "dump_pkl",
"load_node2vec_emb", "clock", "check_nan_inf",
"remove_nan_inf", "data_transformation_4_xformer",
"partial"]
9 changes: 9 additions & 0 deletions basicts/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import time
from functools import partial

import torch


class partial_func(partial):
"""partial class.
__str__ in functools.partial contains the address of the function, which changes randomly and will disrupt easytorch's md5 calculation.
"""

def __str__(self):
return "partial({}, {})".format(self.func.__name__, self.keywords)

def clock(func):
"""clock decorator"""
def clocked(*args, **kw):
Expand Down

0 comments on commit f6d0aeb

Please sign in to comment.