diff --git a/simulai/models/_pytorch_models/_deeponet.py b/simulai/models/_pytorch_models/_deeponet.py index 46f11fbe..b26a87ed 100644 --- a/simulai/models/_pytorch_models/_deeponet.py +++ b/simulai/models/_pytorch_models/_deeponet.py @@ -126,6 +126,12 @@ def __init__( else: self.bias_wrapper = self._wrapper_bias_inactive + # Using a decoder on top of the model or not + if self.decoder_network is not None: + self.decoder_wrapper = self._wrapper_decoder_active + else: + self.decoder_wrapper = self._wrapper_decoder_inactive + # Checking the compatibility of the subnetworks outputs for each kind of product being employed. if self.product_type != "dense": output_branch = self.branch_network.output_size @@ -208,33 +214,6 @@ def _bias_compatibility_is_correct(self, dim_trunk: Union[int, tuple], "of the branch output should be" + "trunk output + var_dim.") - def _forward_decoder( - self, output_trunk: torch.Tensor = None, output_branch: torch.Tensor = None - ) -> torch.Tensor: - """ - - Forward method used when a decoder networks is present in the system. - - Parameters - ---------- - - output_trunk: torch.Tensor - The embedding generated by the trunk network. - output_branch: torch.Tensor - The embedding generated by the branch network. - - Returns - ------- - torch.Tensor - The product between the two embeddings. - - """ - - output_encoder = torch.sum(output_trunk * output_branch, dim=-1, keepdim=True) - output = self.decoder_network.forward(output_encoder) - - return output - def _forward_dense( self, output_trunk: torch.Tensor = None, output_branch: torch.Tensor = None ) -> torch.Tensor: @@ -344,12 +323,9 @@ def _forward_selector_(self) -> callable: """ if self.var_dim > 1: - # The decoder network can be used for producing multidimensional outputs - if self.decoder_network is not None: - return self._forward_decoder - # In contrast, a simple reshaping operation also can be used - elif self.product_type == "dense": + # It operates as a dense layer + if self.product_type == "dense": return self._forward_dense else: @@ -404,6 +380,20 @@ def _wrapper_bias_active( return output + def _wrapper_decoder_active( + self, + input_data: Union[np.ndarray, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.decoder.forward(input_data=input_data) + + def _wrapper_decoder_inactive( + self, + input_data: Union[np.ndarray, torch.Tensor] = None, + ) -> torch.Tensor: + + return input_data + def forward( self, input_trunk: Union[np.ndarray, torch.Tensor] = None, @@ -436,7 +426,7 @@ def forward( output = self.bias_wrapper(output_trunk=output_trunk, output_branch=output_branch) - return output * self.rescale_factors + return self.decoder_wrapper(input_data=output) * self.rescale_factors @guarantee_device def eval(