diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 400c89940..5981bfeaa 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -457,19 +457,6 @@ def __init__( drop_missing=drop_missing, ) - # Training - self.config_train = configure.Train( - quantiles=quantiles, - learning_rate=learning_rate, - epochs=epochs, - batch_size=batch_size, - loss_func=loss_func, - optimizer=optimizer, - newer_samples_weight=newer_samples_weight, - newer_samples_start=newer_samples_start, - trend_reg_threshold=trend_reg_threshold, - ) - if isinstance(collect_metrics, list): log.info( DeprecationWarning( @@ -499,6 +486,19 @@ def __init__( trend_local_reg=trend_local_reg, ) + # Training + self.config_train = configure.Train( + quantiles=quantiles, + learning_rate=learning_rate, + epochs=epochs, + batch_size=batch_size, + loss_func=loss_func, + optimizer=optimizer, + newer_samples_weight=newer_samples_weight, + newer_samples_start=newer_samples_start, + trend_reg_threshold=self.config_trend.trend_reg_threshold, + ) + # Seasonality self.config_seasonality = configure.ConfigSeasonality( mode=seasonality_mode,