Skip to content

Commit

Permalink
Decoder network works better as an operation on top of DeepONet
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Sep 27, 2023
1 parent 2e9b987 commit a9ee18c
Showing 1 changed file with 23 additions and 33 deletions.
56 changes: 23 additions & 33 deletions simulai/models/_pytorch_models/_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ def __init__(
else:
self.bias_wrapper = self._wrapper_bias_inactive

# Using a decoder on top of the model or not
if self.decoder_network is not None:
self.decoder_wrapper = self._wrapper_decoder_active
else:
self.decoder_wrapper = self._wrapper_decoder_inactive

# Checking the compatibility of the subnetworks outputs for each kind of product being employed.
if self.product_type != "dense":
output_branch = self.branch_network.output_size
Expand Down Expand Up @@ -208,33 +214,6 @@ def _bias_compatibility_is_correct(self, dim_trunk: Union[int, tuple],
"of the branch output should be" +
"trunk output + var_dim.")

def _forward_decoder(
self, output_trunk: torch.Tensor = None, output_branch: torch.Tensor = None
) -> torch.Tensor:
"""
Forward method used when a decoder networks is present in the system.
Parameters
----------
output_trunk: torch.Tensor
The embedding generated by the trunk network.
output_branch: torch.Tensor
The embedding generated by the branch network.
Returns
-------
torch.Tensor
The product between the two embeddings.
"""

output_encoder = torch.sum(output_trunk * output_branch, dim=-1, keepdim=True)
output = self.decoder_network.forward(output_encoder)

return output

def _forward_dense(
self, output_trunk: torch.Tensor = None, output_branch: torch.Tensor = None
) -> torch.Tensor:
Expand Down Expand Up @@ -344,12 +323,9 @@ def _forward_selector_(self) -> callable:
"""

if self.var_dim > 1:
# The decoder network can be used for producing multidimensional outputs
if self.decoder_network is not None:
return self._forward_decoder

# In contrast, a simple reshaping operation also can be used
elif self.product_type == "dense":
# It operates as a dense layer
if self.product_type == "dense":
return self._forward_dense

else:
Expand Down Expand Up @@ -404,6 +380,20 @@ def _wrapper_bias_active(

return output

def _wrapper_decoder_active(
self,
input_data: Union[np.ndarray, torch.Tensor] = None,
) -> torch.Tensor:

return self.decoder.forward(input_data=input_data)

def _wrapper_decoder_inactive(
self,
input_data: Union[np.ndarray, torch.Tensor] = None,
) -> torch.Tensor:

return input_data

def forward(
self,
input_trunk: Union[np.ndarray, torch.Tensor] = None,
Expand Down Expand Up @@ -436,7 +426,7 @@ def forward(

output = self.bias_wrapper(output_trunk=output_trunk, output_branch=output_branch)

return output * self.rescale_factors
return self.decoder_wrapper(input_data=output) * self.rescale_factors

@guarantee_device
def eval(
Expand Down

0 comments on commit a9ee18c

Please sign in to comment.