Merge pull request #39 from aai-institute/feature/grad
Feature: grad, div, ... operators
Showing 4 changed files with 226 additions and 47 deletions.
@@ -0,0 +1,115 @@
""" | ||
`continuity.pde.grad` | ||
Functional gradients in Continuity. | ||
Derivatives are function operators, so it is natural to define them as operators | ||
within Continuity. | ||
The following gradients define several derivation operators (e.g., grad, div) | ||
that simplify the definition of PDEs in physics-informed losses. | ||
""" | ||

import torch
from torch import Tensor
from typing import Optional, Callable
from continuity.operators.operator import Operator


class Grad(Operator):
    """Gradient operator.

    The gradient is a function operator that maps a function to its gradient.
    """

    def forward(self, x: Tensor, u: Tensor, y: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the operator.

        Args:
            x: Tensor of sensor positions of shape (batch_size, num_sensors, input_coordinate_dim)
            u: Tensor of sensor values of shape (batch_size, num_sensors, input_channels)
            y: Tensor of evaluation positions of shape (batch_size, y_size, output_coordinate_dim)

        Returns:
            Tensor of evaluations of the mapped function of shape (batch_size, y_size, output_channels)
        """
        if y is not None:
            assert torch.equal(x, y), "x and y must be equal for gradient operator"

        assert x.requires_grad, "x must require gradients for gradient operator"

        # Compute gradients
        gradients = torch.autograd.grad(
            u,
            x,
            grad_outputs=torch.ones_like(u),
            create_graph=True,
            retain_graph=True,
        )[0]

        return gradients


def grad(u: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
    """Compute the gradient of a function.

    Example:
        Computing the gradient of the output function of an operator:
        ```python
        v = lambda y: operator(x, u, y)
        g = grad(v)(y)
        ```

    Args:
        u: Function to compute the gradient of.

    Returns:
        Function that computes the gradient of the input function.
    """
    return lambda x: Grad()(x, u(x))


class Div(Operator):
    """Divergence operator.

    The divergence is a function operator that maps a function to its divergence.
    """

    def forward(self, x: Tensor, u: Tensor, y: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the operator.

        Args:
            x: Tensor of sensor positions of shape (batch_size, num_sensors, input_coordinate_dim)
            u: Tensor of sensor values of shape (batch_size, num_sensors, input_channels)
            y: Tensor of evaluation positions of shape (batch_size, y_size, output_coordinate_dim)

        Returns:
            Tensor of evaluations of the mapped function of shape (batch_size, y_size, output_channels)
        """
        if y is not None:
            assert torch.equal(x, y), "x and y must be equal for divergence operator"

        assert x.requires_grad, "x must require gradients for divergence operator"

        # Compute divergence
        gradients = Grad()(x, u)
        return torch.sum(gradients, dim=-1, keepdim=True)


def div(u: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]:
    """Compute the divergence of a function.

    Example:
        Computing the divergence of the output function of an operator:
        ```python
        v = lambda y: operator(x, u, y)
        d = div(v)(y)
        ```

    Args:
        u: Function to compute the divergence of.

    Returns:
        Function that computes the divergence of the input function.
    """
    return lambda x: Div()(x, u(x))
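As a usage illustration beyond this diff, the sketch below shows how `grad` could enter a physics-informed loss, here penalizing the residual of a steady advection equation c · ∇u = 0 for a scalar-valued operator output. The names `advection_loss` and `scalar_operator` and the constant velocity `c` are hypothetical, not part of this PR; the `v = lambda y: operator(x, u, y)` pattern follows the docstring example above.

```python
import torch

from continuity.pde.grad import grad


def advection_loss(scalar_operator, x, u, y, c):
    """Hypothetical physics-informed penalty for the steady advection equation c . grad(u) = 0."""
    y = y.requires_grad_(True)  # evaluation points must track gradients
    v = lambda z: scalar_operator(x, u, z)  # output function of the operator
    g = grad(v)(y)  # gradient of the prediction w.r.t. the evaluation points
    residual = (g * c).sum(dim=-1)  # c . grad(u) at every evaluation point
    return torch.mean(residual**2)
```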
@@ -0,0 +1,52 @@
import torch
from continuity.pde.grad import grad, Grad, div, Div

# Set random seed
torch.manual_seed(0)


def test_grad():
    # f(x) = x_0^2 + x_1^3
    def f(x):
        return (x[:, 0] ** 2 + x[:, 1] ** 3).unsqueeze(1)

    # df(x) = [2 * x_0, 3 * x_1^2]
    def df(x):
        return torch.stack([2 * x[:, 0], 3 * x[:, 1] ** 2], dim=1)

    x = torch.rand(100, 2).requires_grad_(True)
    u = f(x)

    # Test gradient of function
    gf = grad(f)
    assert torch.norm(gf(x) - df(x)) < 1e-6

    # Test gradient operator
    du = Grad()(x, u, x)
    assert torch.norm(du - df(x)) < 1e-6


def test_div():
    # f(x) = x_0^2 + x_1^3
    def f(x):
        return (x[:, 0] ** 2 + x[:, 1] ** 3).unsqueeze(1)

    # div_f(x) = 2 * x_0 + 3 * x_1^2
    def div_f(x):
        return (2 * x[:, 0] + 3 * x[:, 1] ** 2).unsqueeze(1)

    x = torch.rand(100, 2).requires_grad_(True)
    u = f(x)

    # Test divergence of function
    df = div(f)
    assert torch.norm(df(x) - div_f(x)) < 1e-6

    # Test divergence operator
    div_u = Div()(x, u, x)
    assert torch.norm(div_u - div_f(x)) < 1e-6


if __name__ == "__main__":
    test_grad()
    test_div()
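A further sanity check one might add, not included in this diff: the same gradient test with batched 3D tensors matching the (batch_size, num_sensors, coordinate_dim) shapes documented in the operator docstrings. The batch size, point count, and tolerance below are arbitrary assumptions.

```python
import torch

from continuity.pde.grad import Grad


def test_grad_batched():
    # f(x) = x_0^2 + x_1^3, applied pointwise along the last dimension
    def f(x):
        return (x[..., 0] ** 2 + x[..., 1] ** 3).unsqueeze(-1)

    # df(x) = [2 * x_0, 3 * x_1^2]
    def df(x):
        return torch.stack([2 * x[..., 0], 3 * x[..., 1] ** 2], dim=-1)

    x = torch.rand(8, 100, 2).requires_grad_(True)
    du = Grad()(x, f(x), x)
    assert torch.allclose(du, df(x), atol=1e-6)
```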