Skip to content

Commit

Permalink
change slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
MaiBe-ctrl committed Aug 5, 2024
1 parent 84ca0bb commit 989f163
Showing 1 changed file with 32 additions and 93 deletions.
125 changes: 32 additions & 93 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,31 +129,21 @@ def __init__(

# Construct index map
self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors)
self.time_offset = torch.tensor(datetime(1900, 1, 1).timestamp())
self.df_tensors["ds_seasonality"] = (self.df_tensors["ds"] - self.time_offset).float() / (3600 * 24.0)

if config_seasonality and hasattr(self.config_seasonality, "periods"):
# Precompute Fourier factors for all seasonalities
self.fourier_factors = {
name: 2.0 * np.pi / period.period for name, period in config_seasonality.periods.items()
}
else:
self.fourier_factors = {}

self.t_offset = torch.tensor(datetime(1900, 1, 1).timestamp()).float() / (3600 * 24.0)
self.precomputed_seasonality_terms = self.precompute_seasonality_terms()

# Precompute timestamps in days since 1900-01-01
self.time_tensors = self.precompute_time_tensors()
def precompute_seasonality_terms(self):
    """Precompute the Fourier harmonic multipliers of every active seasonality.

    For each seasonality whose ``resolution`` is greater than zero, the result holds the
    1-D float32 tensor ``(2 * pi / period) * [1, 2, ..., resolution]``.

    Returns
    -------
        OrderedDict
            Maps seasonality name to its tensor of harmonic multipliers. Empty when no
            seasonality periods are configured.
    """
    # NOTE(review): this method used to read the misspelled ``self.config_seasonlity_periods``.
    # That attribute is still honoured when present; otherwise the periods are taken from
    # ``self.config_seasonality.periods`` -- confirm which attribute is intended.
    periods = getattr(self, "config_seasonlity_periods", None)
    if periods is None:
        config = getattr(self, "config_seasonality", None)
        periods = getattr(config, "periods", None) or {}

    precomputed_terms = OrderedDict()
    for name, period in periods.items():
        if period.resolution > 0:
            factor = 2.0 * np.pi / period.period
            # torch.arange (torch has no "arrange") yields the harmonic orders 1..resolution.
            harmonics = torch.arange(1, period.resolution + 1, dtype=torch.float32)
            precomputed_terms[name] = factor * harmonics
    return precomputed_terms


def precompute_time_tensors(self):
    """Precompute the time values of the sample window belonging to every origin index.

    Each entry is ``dates.float() / 86400 - self.t_offset``, where ``dates`` is the single
    ``ds`` value at the origin index when ``self.max_lags == 0`` and otherwise the slice
    ``[origin_index - n_lags + 1, origin_index + n_forecasts]`` of ``self.df_tensors["ds"]``.

    Returns
    -------
        list
            One tensor per row of ``self.df_tensors["ds"]``.
    """
    # presumably ``ds`` holds epoch seconds and ``t_offset`` a reference date in days -- verify
    ds = self.df_tensors["ds"]
    time_tensors = []
    for origin_index in range(len(ds)):
        if self.max_lags == 0:
            dates = ds[origin_index].unsqueeze(0)
        else:
            # NOTE(review): for origin_index < n_lags - 1 the slice start is negative and is
            # counted from the end of the tensor -- presumably such indices are never sampled.
            dates = ds[origin_index - self.n_lags + 1 : origin_index + self.n_forecasts + 1]
        time_tensors.append((dates.float() / (3600 * 24.0)) - self.t_offset)
    return time_tensors

def __getitem__(self, index):
"""Overrides parent class method to get an item at index.
Expand Down Expand Up @@ -192,8 +182,6 @@ def __getitem__(self, index):
# Tabularize - extract features from dataframe at given target index position
inputs, target = tabularize_univariate_datetime_single_index(
df_tensors=self.df_tensors,
time_tensors=self.time_tensors,
period_factors=self.fourier_factors,
origin_index=df_index,
predict_mode=self.predict_mode,
n_lags=self.n_lags,
Expand Down Expand Up @@ -359,80 +347,35 @@ def get_sample_lagged_regressors(df_tensors, origin_index, config_lagged_regress
return lagged_regressors


def get_sample_seasonalities(
    df_tensors,
    origin_index,
    n_forecasts,
    max_lags,
    n_lags,
    config_seasonality,
    time_tensors=None,
    period_factors=None,
):
    """Compute the Fourier seasonality features of one sample.

    Both call forms are accepted: the short one, which derives the time values from
    ``df_tensors["ds_seasonality"]``, and the longer one, which passes precomputed
    ``time_tensors`` and ``period_factors``.

    Parameters
    ----------
        df_tensors : dict
            Column name to tensor; must contain ``"ds_seasonality"`` unless ``time_tensors``
            is given, plus every ``condition_name`` used by a seasonality.
        origin_index : int
            Row index of the sample origin.
        n_forecasts : int
            Number of steps after the origin included in the window.
        max_lags : int
            When 0 only the origin row is used, otherwise the lagged window is sliced.
        n_lags : int
            Number of lagged steps included in the window.
        config_seasonality
            Object exposing ``periods`` (name -> period with ``resolution``, ``period``,
            ``condition_name``) and ``computation``.
        time_tensors : sequence, optional
            Precomputed time values indexed by ``origin_index``; replaces the lookup in
            ``df_tensors["ds_seasonality"]``.
        period_factors : dict, optional
            Precomputed ``2 * pi / period`` per seasonality name.

    Returns
    -------
        OrderedDict
            Seasonality name to a tensor whose columns are the sine terms followed by the
            cosine terms (``2 * resolution`` columns).

    Raises
    ------
        NotImplementedError
            If ``config_seasonality.computation`` is not ``"fourier"``.
    """

    def _window(values):
        # Same row selection for the time axis and for condition columns.
        if max_lags == 0:
            return values[origin_index].unsqueeze(0)
        return values[origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

    if time_tensors is not None:
        t = time_tensors[origin_index]
    else:
        t = _window(df_tensors["ds_seasonality"])

    seasonalities = OrderedDict({})
    for name, period in config_seasonality.periods.items():
        if period.resolution <= 0:
            continue
        if config_seasonality.computation != "fourier":
            raise NotImplementedError

        # Keep the arithmetic order of each original variant so results stay bit-identical.
        if period_factors is not None:
            factor = period_factors[name] * t[:, None]
        else:
            factor = 2.0 * np.pi * t[:, None] / period.period
        phase = factor * torch.arange(1, period.resolution + 1)
        features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=1)

        if period.condition_name is not None:
            # Broadcast the per-row condition over all sine/cosine columns.
            features = features * _window(df_tensors[period.condition_name]).unsqueeze(1)
        seasonalities[name] = features

    return seasonalities


def get_condition_values(df_tensors, origin_index, n_forecasts, max_lags, n_lags, period):
    """Select the condition column values covering one sample, as a column vector.

    Parameters
    ----------
        df_tensors : dict
            Column name to tensor; read at ``period.condition_name``.
        origin_index : int
            Row index of the sample origin.
        n_forecasts : int
            Number of steps after the origin included in the window.
        max_lags : int
            When 0 only the origin row is returned.
        n_lags : int
            Number of lagged steps included in the window.
        period
            Object exposing ``condition_name``.

    Returns
    -------
        torch.Tensor
            The selected values with a trailing dimension of size 1 added.
    """
    column = df_tensors[period.condition_name]
    if max_lags != 0:
        first = origin_index - n_lags + 1
        stop = origin_index + n_forecasts + 1
        return column[first:stop].unsqueeze(1)
    # No lags: a single row, expanded to shape (1, 1).
    return column[origin_index].unsqueeze(0).unsqueeze(1)


def compute_seasonal_features(name, t, period_factor, period):
    """Build the sine and cosine Fourier features of one seasonality.

    Parameters
    ----------
        name : str
            Seasonality name (not used in the computation).
        t : torch.Tensor
            1-D tensor of time values.
        period_factor : float
            Multiplier applied to ``t`` before the harmonic orders.
        period
            Object exposing ``resolution``, the number of harmonics.

    Returns
    -------
        torch.Tensor
            Tensor of shape ``(len(t), 2 * resolution)``: sine columns, then cosine columns.
    """
    harmonic_orders = torch.arange(1, period.resolution + 1)
    # Column vector of scaled times, broadcast against the harmonic orders.
    phase = (period_factor * t[:, None]) * harmonic_orders
    sines = torch.sin(phase)
    cosines = torch.cos(phase)
    return torch.cat((sines, cosines), dim=1)


# def get_sample_seasonalities(
# df_tensors, timestamp_days, fourier_factors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality
# ):
# # TODO
# # 1. Speedup: t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0) -> save it in init
# # separate into small functions and track what is causing most of the overhead
# # experiment with the number of seasonalities and see how it affects the performance
# # check pytorch and pytorch lightning, lightning fabric to do multiprocessing

# #

# seasonalities = OrderedDict({})
# if max_lags == 0:
# dates = df_tensors["ds"][origin_index].unsqueeze(0)
# else:
# dates = df_tensors["ds"][origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

# t = timestamp_days[origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

# for name, period in config_seasonality.periods.items():
# if period.resolution > 0:
# if config_seasonality.computation == "fourier":
# factor = fourier_factors[name]
# sin_terms = torch.sin(t[:, None] * factor)
# cos_terms = torch.cos(t[:, None] * factor)
# features = torch.cat((sin_terms, cos_terms), dim=1)
# else:
# raise NotImplementedError

# if period.condition_name is not None:
# if max_lags == 0:
# condition_values = df_tensors[period.condition_name][origin_index].unsqueeze(0).unsqueeze(1)
# else:
# condition_values = df_tensors[period.condition_name][
# origin_index - n_lags + 1 : origin_index + n_forecasts + 1
# ].unsqueeze(1)
# features = features * condition_values
# seasonalities[name] = features
# return seasonalities


def get_sample_future_regressors(
df_tensors, origin_index, n_forecasts, max_lags, n_lags, additive_regressors_names, multiplicative_regressors_names
):
Expand Down Expand Up @@ -502,8 +445,6 @@ def log_input_shapes(inputs):

def tabularize_univariate_datetime_single_index(
df_tensors: dict,
time_tensors,
period_factors,
origin_index: int,
predict_mode: bool = False,
n_lags: int = 0,
Expand Down Expand Up @@ -610,8 +551,6 @@ def tabularize_univariate_datetime_single_index(
max_lags=max_lags,
n_lags=n_lags,
config_seasonality=config_seasonality,
time_tensors=time_tensors,
period_factors=period_factors,
)

# FUTURE REGRESSORS: get the future regressors features
Expand Down

0 comments on commit 989f163

Please sign in to comment.