Skip to content

Commit

Permalink
refactor: replace arg training with self attribute in newly added models;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 27, 2024
1 parent f674a3c commit 40e9602
Show file tree
Hide file tree
Showing 25 changed files with 39 additions and 37 deletions.
4 changes: 2 additions & 2 deletions pypots/imputation/grud/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
)
self.output_projection = nn.Linear(rnn_hidden_size, n_features)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
"""Forward processing of GRU-D.
Parameters
Expand Down Expand Up @@ -66,7 +66,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
results["loss"] = calc_mse(reconstruction, X, missing_mask)

return results
2 changes: 1 addition & 1 deletion pypots/imputation/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/imputeformer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
# apply SAITS loss function to Transformer on the imputation task
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
x, missing_mask = inputs["X"], inputs["missing_mask"]

# x: (batch_size, in_steps, num_nodes)
Expand Down Expand Up @@ -132,7 +132,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/imputeformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/koopa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/micn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
# for the imputation task, the output dim is the same as input dim
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

seasonal_init, trend_init = self.decomp_multi(X)
Expand All @@ -82,7 +82,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/micn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/moderntcn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
individual,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
Expand All @@ -88,7 +88,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/moderntcn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/reformer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self.output_projection = nn.Linear(d_model, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original Reformer paper isn't proposed for imputation task. Hence the model doesn't take
Expand All @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/reformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/revinscinet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
# for the imputation task, the output dim is the same as input dim
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]
X = self.revin(X, missing_mask, mode="norm")

Expand All @@ -80,7 +80,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/revinscinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/scinet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
# for the imputation task, the output dim is the same as input dim
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original SCINet paper isn't proposed for imputation task. Hence the model doesn't take
Expand All @@ -76,7 +76,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/scinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/stemgnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.output_projection = nn.Linear(d_model, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original StemGNN paper isn't proposed for imputation task. Hence the model doesn't take
Expand All @@ -69,7 +69,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/stemgnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/tcn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self.output_projection = nn.Linear(channel_sizes[-1], n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original TCN paper isn't proposed for imputation task. Hence the model doesn't take
Expand All @@ -68,7 +68,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/tcn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
5 changes: 3 additions & 2 deletions pypots/imputation/tefn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
n_fod,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
Expand All @@ -51,7 +51,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
"imputed_data": imputed_data,
}

if training:
# if in training mode, return results with losses
if self.training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(out, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/tefn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/tide/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
# self.output_projection = nn.Linear(d_model, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# # WDU: the original TiDE paper isn't proposed for imputation task. Hence the model doesn't take
Expand Down Expand Up @@ -112,7 +112,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/tide/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down
5 changes: 3 additions & 2 deletions pypots/imputation/timemixer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
use_future_temporal_feature=False,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
Expand All @@ -75,7 +75,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
"imputed_data": imputed_data,
}

if training:
# if in training mode, return results with losses
if self.training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/timemixer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down

0 comments on commit 40e9602

Please sign in to comment.