Skip to content

Commit

Permalink
Commentaries are welcome
Browse files Browse the repository at this point in the history
Signed-off-by: Joao Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Aug 7, 2024
1 parent f148259 commit 9d95d32
Showing 1 changed file with 36 additions and 57 deletions.
93 changes: 36 additions & 57 deletions simulai/residuals/_pytorch_residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@


class SymbolicOperator(torch.nn.Module):
"""The SymbolicOperatorClass is a class that constructs tensor operators using symbolic expressions written in PyTorch.
"""The SymbolicOperatorClass is a class that constructs tensor operators
using symbolic expressions written in PyTorch.
Returns:
object: An instance of the SymbolicOperatorClass.
"""
Expand Down Expand Up @@ -59,8 +58,11 @@ def __init__(
else:
pass

# The engine used to build the expressions.
# Usually PyTorch.
self.engine = importlib.import_module(engine)

# Basic attributes
self.constants = constants

if trainable_parameters is not None:
Expand All @@ -73,6 +75,7 @@ def __init__(
self.processing = processing
self.periodic_bc_protected_key = "periodic"

# Special funcions that must be replaced before the compilation
self.protected_funcs = [
"cos",
"sin",
Expand All @@ -83,8 +86,12 @@ def __init__(
"sech",
"sinh",
]

# Special operators that must be replaced before the compilation
self.protected_operators = ["L", "Div", "Grad", "Identity", "Kronecker"]

# Replacing special functions and operatorswith corresponding classes
# and objects
self.protected_funcs_subs = self._construct_protected_functions()
self.protected_operators_subs = self._construct_implict_operators()

Expand Down Expand Up @@ -139,9 +146,11 @@ def __init__(

self.output = None

self.f_expressions = list() # Main expressions, as PDEs and ODEs
self.g_expressions = dict() # Auxiliary expressions, as boundary conditions
self.h_expressions = list() # Others auxiliary expressions, as those used to evaluate special loss functions
self.f_expressions = list() # Main expressions, as PDEs and ODEs
self.g_expressions = dict() # Auxiliary expressions, as boundary conditions
self.h_expressions = (
list()
) # Others auxiliary expressions, as those used to evaluate special loss functions

self.feed_vars = None

Expand Down Expand Up @@ -171,6 +180,7 @@ def __init__(
subs.update(self.protected_funcs_subs)
subs.update(self.protected_operators_subs)

# Compiling expressions to tensor operators
for expr in self.expressions:
if not callable(expr):
f_expr = sympy.lambdify(self.all_vars, expr, subs)
Expand Down Expand Up @@ -219,12 +229,10 @@ def _subs_expr(self, expr=None, constants=None):
else:
expr = expr.subs(constants)

return expr
return expr

def _construct_protected_functions(self):
"""This function creates a dictionary of protected functions from the engine object attribute.
Returns:
dict: A dictionary of function names and their corresponding function objects.
"""
Expand All @@ -245,8 +253,6 @@ def _construct_protected_functions(self):

def _construct_implict_operators(self):
"""This function creates a dictionary of protected operators from the operators engine module.
Returns:
dict: A dictionary of operator names and their corresponding function objects.
"""
Expand Down Expand Up @@ -326,18 +332,15 @@ def _collect_data_from_inputs_list(self, inputs_list: dict = None) -> list:

def _parse_expression(self, expr=Union[sympy.Expr, str]) -> sympy.Expr:
"""Parses the input expression and returns a SymPy expression.
Args:
expr (Union[sympy.Expr, str], optional, optional): The expression to parse, by default None. It can either be a SymPy expression or a string.
expr (Union[sympy.Expr, str], optional, optional): The expression to parse, by default None.
It can either be a SymPy expression or a string.
Returns:
sympy.Expr: The parsed SymPy expression.
Raises:
Exception: If the `constants` attribute is not defined, and the input expression is a string.
"""

if isinstance(expr, str):
try:
expr_ = sympify(
Expand All @@ -347,7 +350,9 @@ def _parse_expression(self, expr=Union[sympy.Expr, str]) -> sympy.Expr:
if self.constants is not None:
expr_ = self._subs_expr(expr=expr_, constants=self.constants)
if self.trainable_parameters is not None:
expr_ = self._subs_expr(expr=expr_, constants=self.trainable_parameters)
expr_ = self._subs_expr(
expr=expr_, constants=self.trainable_parameters
)

except ValueError:
if self.constants is not None:
Expand All @@ -370,54 +375,47 @@ def _parse_expression(self, expr=Union[sympy.Expr, str]) -> sympy.Expr:

def _parse_variable(self, var=Union[sympy.Symbol, str]) -> sympy.Symbol:
"""Parse the input variable and return a SymPy Symbol.
Args:
var (Union[sympy.Symbol, str], optional, optional): The input variable, either a SymPy Symbol or a string. (Default value = Union[sympy.Symbol, str])
var (Union[sympy.Symbol, str], optional, optional): The input variable, either a SymPy Symbol or a string.
(Default value = Union[sympy.Symbol, str])
Returns:
sympy.Symbol: A SymPy Symbol representing the input variable.
"""

if isinstance(var, str):
return sympy.Symbol(var)
else:
return var

def _forward_tensor(self, input_data: torch.Tensor = None) -> torch.Tensor:
"""Forward the input tensor through the function.
Args:
input_data (torch.Tensor, optional): The input tensor. (Default value = None)
Returns:
torch.Tensor: The output tensor after forward pass.
"""

return self.function.forward(input_data=input_data)

def _forward_dict(self, input_data: dict = None) -> torch.Tensor:
"""Forward the input dictionary through the function.
Args:
input_data (dict, optional): The input dictionary. (Default value = None)
Returns:
torch.Tensor: The output tensor after forward pass.
"""

return self.function.forward(**input_data)

def _factory_process_expression_serial(self, expressions: list = None):
def _process_expression_serial(feed_vars: dict = None) -> List[torch.Tensor]:
"""Process the expression list serially using the given feed variables.
Args:
feed_vars (dict, optional): The feed variables. (Default value = None)
Returns:
List[torch.Tensor]: A list of tensors after evaluating the expressions serially.
"""

return [f(**feed_vars).to(self.device) for f in expressions]

return _process_expression_serial
Expand All @@ -427,35 +425,28 @@ def _process_expression_individual(
index: int = None, feed_vars: dict = None
) -> torch.Tensor:
"""Evaluates a single expression specified by index from the f_expressions list with given feed variables.
Args:
index (int, optional): Index of the expression to be evaluated, by default None
feed_vars (dict, optional): Dictionary of feed variables, by default None
Returns:
torch.Tensor: Result of evaluating the specified expression with given feed variables
"""

return self.expressions[index](**feed_vars).to(self.device)

return _process_expression_individual

def _create_input_for_eval(self, inputs_data: Union[np.ndarray, dict]=None) -> List[torch.Tensor]:
def _create_input_for_eval(
self, inputs_data: Union[np.ndarray, dict] = None
) -> List[torch.Tensor]:
"""Evaluate the symbolic expression.
This function takes either a numpy array or a dictionary of numpy arrays as input.
Args:
inputs_data (Union[np.ndarray, dict], optional): Union (Default value = None)
Returns:
List[torch.Tensor]: List[torch.Tensor]: A list of tensors containing the evaluated expressions.
Raises:
Raises:
does: not match with the inputs_key attribute
"""

constructor = MakeTensor(
Expand Down Expand Up @@ -497,21 +488,15 @@ def __call__(
self, inputs_data: Union[np.ndarray, dict] = None
) -> List[torch.Tensor]:
"""Evaluate the symbolic expression.
This function takes either a numpy array or a dictionary of numpy arrays as input.
Args:
inputs_data (Union[np.ndarray, dict], optional): Union (Default value = None)
Returns:
List[torch.Tensor]: List[torch.Tensor]: A list of tensors containing the evaluated expressions.
Raises:
Raises:
does: not match with the inputs_key attribute
"""

outputs, inputs = self._create_input_for_eval(inputs_data=inputs_data)

feed_vars = {**outputs, **inputs}
Expand All @@ -522,11 +507,9 @@ def __call__(

def eval_expression(self, key, inputs_list):
"""This function evaluates an expression stored in the class attribute 'g_expressions' using the inputs in 'inputs_list'. If the expression has a periodic boundary condition, the function evaluates the expression at the lower and upper boundaries and returns the difference. If the inputs are provided as a list, they are split into individual tensors and stored in a dictionary with the keys as the input names. If the inputs are provided as an np.ndarray, they are converted to tensors and split along the second axis. If the inputs are provided as a dict, they are extracted using the 'inputs_key' attribute. The inputs, along with the outputs obtained from running the function, are then passed as arguments to the expression using the 'g(**feed_vars)' syntax.
Args:
key (str): the key used to retrieve the expression from the 'g_expressions' attribute
inputs_list (list): either a list of arrays, an np.ndarray, or a dict containing the inputs to the function
Returns:
the result of evaluating the expression using the inputs.:
Expand Down Expand Up @@ -629,8 +612,8 @@ def eval_expression(self, key, inputs_list):
assert (
self.inputs_key is not None
), "If inputs_list is dict, \
it is necessary to provide\
a key."
it is necessary to provide\
a key."

inputs = {
key: value
Expand All @@ -652,11 +635,9 @@ def eval_expression(self, key, inputs_list):
@staticmethod
def gradient(feature, param):
"""Calculates the gradient of the given feature with respect to the given parameter.
Args:
feature (torch.Tensor): Tensor with the input feature.
param (torch.Tensor): Tensor with the parameter to calculate the gradient with respect to.
Returns:
torch.Tensor: Tensor with the gradient of the feature with respect to the given parameter.
Example:
Expand All @@ -679,10 +660,8 @@ def gradient(feature, param):

def jac(self, inputs):
"""Calculates the Jacobian of the forward function of the model with respect to its inputs.
Args:
inputs (torch.Tensor): Tensor with the input data to the forward function.
Returns:
torch.Tensor: Tensor with the Jacobian of the forward function with respect to its inputs.
Example:
Expand Down

0 comments on commit 9d95d32

Please sign in to comment.