Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support/alternative expressions #195

Merged
merged 9 commits into from
Aug 2, 2024
135 changes: 103 additions & 32 deletions simulai/residuals/_pytorch_residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
device: str = "cpu",
engine: str = "torch",
auxiliary_expressions: list = None,
special_expressions: list = None,
) -> None:
if engine == "torch":
super(SymbolicOperator, self).__init__()
Expand All @@ -72,8 +73,17 @@ def __init__(
self.processing = processing
self.periodic_bc_protected_key = "periodic"

self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh", "cosh", "sech", "sinh"]
self.protected_operators = ["L", "Div", "Identity", "Kronecker"]
self.protected_funcs = [
"cos",
"sin",
"sqrt",
"exp",
"tanh",
"cosh",
"sech",
"sinh",
]
self.protected_operators = ["L", "Div", "Grad", "Identity", "Kronecker"]

self.protected_funcs_subs = self._construct_protected_functions()
self.protected_operators_subs = self._construct_implict_operators()
Expand Down Expand Up @@ -103,6 +113,8 @@ def __init__(
else:
self.auxiliary_expressions = auxiliary_expressions

self.special_expressions = special_expressions

self.input_vars = [self._parse_variable(var=var) for var in input_vars]
self.output_vars = [self._parse_variable(var=var) for var in output_vars]

Expand All @@ -127,8 +139,9 @@ def __init__(

self.output = None

self.f_expressions = list()
self.g_expressions = dict()
self.f_expressions = list() # Main expressions, as PDEs and ODEs
self.g_expressions = dict() # Auxiliary expressions, as boundary conditions
self.h_expressions = list() # Others auxiliary expressions, as those used to evaluate special loss functions

self.feed_vars = None

Expand All @@ -152,9 +165,11 @@ def __init__(
else:
gradient_function = gradient

# Diff symbol is related to automatic differentiation
subs = {self.diff_symbol.name: gradient_function}
subs.update(self.external_functions)
subs.update(self.protected_funcs_subs)
subs.update(self.protected_operators_subs)

for expr in self.expressions:
if not callable(expr):
Expand All @@ -164,6 +179,7 @@ def __init__(

self.f_expressions.append(f_expr)

# auxiliary expressions (usually boundary conditions)
if self.auxiliary_expressions is not None:
for key, expr in self.auxiliary_expressions.items():
if not callable(expr):
Expand All @@ -173,12 +189,38 @@ def __init__(

self.g_expressions[key] = g_expr

# special expressions (usually employed for certain kinds of loss functions)
if special_expressions is not None:
for expr in self.special_expressions:
if not callable(expr):
h_expr = sympy.lambdify(self.all_vars, expr, subs)
else:
h_expr = expr

self.h_expressions.append(h_expr)

self.process_special_expression = self._factory_process_expression_serial(
expressions=self.h_expressions
)

# Method for executing the expressions evaluation
if self.processing == "serial":
self.process_expression = self._process_expression_serial
self.process_expression = self._factory_process_expression_serial(
expressions=self.f_expressions
)
else:
raise Exception(f"Processing case {self.processing} not supported.")

def _subs_expr(self, expr=None, constants=None):

if isinstance(expr, list):
for j, e in enumerate(expr):
expr[j] = e.subs(constants)
else:
expr = expr.subs(constants)

return expr

def _construct_protected_functions(self):
"""This function creates a dictionary of protected functions from the engine object attribute.

Expand Down Expand Up @@ -303,9 +345,10 @@ def _parse_expression(self, expr=Union[sympy.Expr, str]) -> sympy.Expr:
)

if self.constants is not None:
expr_ = expr_.subs(self.constants)
expr_ = self._subs_expr(expr=expr_, constants=self.constants)
if self.trainable_parameters is not None:
expr_ = expr_.subs(self.trainable_parameters)
expr_ = self._subs_expr(expr=expr_, constants=self.trainable_parameters)

except ValueError:
if self.constants is not None:
_expr = expr
Expand Down Expand Up @@ -364,36 +407,40 @@ def _forward_dict(self, input_data: dict = None) -> torch.Tensor:
"""
return self.function.forward(**input_data)

def _process_expression_serial(self, feed_vars: dict = None) -> List[torch.Tensor]:
"""Process the expression list serially using the given feed variables.
def _factory_process_expression_serial(self, expressions: list = None):
def _process_expression_serial(feed_vars: dict = None) -> List[torch.Tensor]:
"""Process the expression list serially using the given feed variables.

Args:
feed_vars (dict, optional): The feed variables. (Default value = None)
Args:
feed_vars (dict, optional): The feed variables. (Default value = None)

Returns:
List[torch.Tensor]: A list of tensors after evaluating the expressions serially.
Returns:
List[torch.Tensor]: A list of tensors after evaluating the expressions serially.

"""
return [f(**feed_vars).to(self.device) for f in self.f_expressions]
"""
return [f(**feed_vars).to(self.device) for f in expressions]

def _process_expression_individual(
self, index: int = None, feed_vars: dict = None
) -> torch.Tensor:
"""Evaluates a single expression specified by index from the f_expressions list with given feed variables.
return _process_expression_serial

Args:
index (int, optional): Index of the expression to be evaluated, by default None
feed_vars (dict, optional): Dictionary of feed variables, by default None
def _factory_process_expression_individual(self, expressions: list = None):
def _process_expression_individual(
index: int = None, feed_vars: dict = None
) -> torch.Tensor:
"""Evaluates a single expression specified by index from the f_expressions list with given feed variables.

Returns:
torch.Tensor: Result of evaluating the specified expression with given feed variables
Args:
index (int, optional): Index of the expression to be evaluated, by default None
feed_vars (dict, optional): Dictionary of feed variables, by default None

"""
return self.f_expressions[index](**feed_vars).to(self.device)
Returns:
torch.Tensor: Result of evaluating the specified expression with given feed variables

def __call__(
self, inputs_data: Union[np.ndarray, dict] = None
) -> List[torch.Tensor]:
"""
return self.expressions[index](**feed_vars).to(self.device)

return _process_expression_individual

def _create_input_for_eval(self, inputs_data: Union[np.ndarray, dict]=None) -> List[torch.Tensor]:
"""Evaluate the symbolic expression.

This function takes either a numpy array or a dictionary of numpy arrays as input.
Expand All @@ -410,6 +457,7 @@ def __call__(
does: not match with the inputs_key attribute

"""

constructor = MakeTensor(
input_names=self.input_names, output_names=self.output_names
)
Expand Down Expand Up @@ -443,6 +491,29 @@ def __call__(
for inputs_list"
)

return outputs, inputs

def __call__(
self, inputs_data: Union[np.ndarray, dict] = None
) -> List[torch.Tensor]:
"""Evaluate the symbolic expression.

This function takes either a numpy array or a dictionary of numpy arrays as input.

Args:
inputs_data (Union[np.ndarray, dict], optional): Union (Default value = None)

Returns:
List[torch.Tensor]: List[torch.Tensor]: A list of tensors containing the evaluated expressions.

Raises:

Raises:
does: not match with the inputs_key attribute

"""
outputs, inputs = self._create_input_for_eval(inputs_data=inputs_data)

feed_vars = {**outputs, **inputs}

# It returns a list of tensors containing the expressions
Expand Down Expand Up @@ -631,20 +702,20 @@ def sech(self, x):

cosh = getattr(self.engine, "cosh")

return 1/cosh(x)
return 1 / cosh(x)

def csch(self, x):

sinh = getattr(self.engine, "sinh")

return 1/sinh(x)
return 1 / sinh(x)

def coth(self, x):

cosh = getattr(self.engine, "cosh")
sinh = getattr(self.engine, "sinh")

return cosh(x)/sinh(x)
return cosh(x) / sinh(x)


def diff(feature: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
Expand Down
32 changes: 32 additions & 0 deletions simulai/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,38 @@ def Div(u: sympy.Symbol, vars: tuple) -> callable:

return l

def Grad(u: sympy.Symbol, vars: tuple) -> callable:
"""
Generate a callable object to compute the gradient operator.

The gradient operator is a first-order differential operator that measures the
magnitude and direction of a flow of a vector field from its source and
convergence to a point.

Parameters
----------
u : sympy.Symbol
The vector field to compute the divergence of.
vars : tuple
A tuple of variables to compute the divergence with respect to.

Returns
-------
callable
A callable object that computes the divergence of a vector field with respect
to the given variables.

Examples
--------
>>> x, y, z = sympy.symbols('x y z')
>>> u = sympy.Matrix([x**2, y**2, z**2])
>>> Grad(u, (x, y, z))
2*x + 2*y + 2*z
"""
g = [D(u, var) for var in vars]

return g


def Gp(
g0: Union[torch.tensor, float], r: Union[torch.tensor, float], n: int
Expand Down
36 changes: 36 additions & 0 deletions tests/residuals/test_symbolicoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,42 @@ def test_symbolic_operator_diff_operators(self):

assert all([isinstance(item, torch.Tensor) for item in residual(data)])

def test_symbolic_operator_grad_operator(self):

f = "D(u, x) - D(u, y)"
s_1 = "D(D(u, x) - D(u, y), x)"
s_2 = "D(D(u, x) - D(u, y), y)"

input_labels = ["x", "y"]
output_labels = ["u"]

L_x = 1
L_y = 1
N_x = 100
N_y = 100
dx = L_x / N_x
dy = L_y / N_y

grid = np.mgrid[0:L_x:dx, 0:L_y:dy]

data = np.hstack([grid[1].flatten()[:, None], grid[0].flatten()[:, None]])

net = model(n_inputs=len(input_labels), n_outputs=len(output_labels))

residual = SymbolicOperator(
expressions=[f],
special_expressions=[s_1, s_2],
input_vars=input_labels,
output_vars=output_labels,
function=net,
engine="torch",
)
u = net(input_data=data)
outputs, inputs = residual._create_input_for_eval(inputs_data=data)
feed_vars = {**outputs, **inputs}

all(isinstance(item, torch.Tensor) for item in residual.process_special_expression(feed_vars))

def test_symbolic_operator_1d_pde(self):
# Allen-Cahn equation
f_0 = "D(u, t) - mu*D(D(u, x), x) + alpha*(u**3) + beta*u"
Expand Down
Loading