Skip to content

Commit

Permalink
add requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEliasWagner committed Aug 14, 2024
1 parent d8feabb commit 228802b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 43 deletions.
98 changes: 57 additions & 41 deletions src/continuiti/operators/deep_cat_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,57 @@ class DeepCatOperator(Operator):
This class implements the DeepCatOperator, a neural operator inspired by the DeepONet. It consists of three main
parts:
1. **Input Network**: Analogous to the "branch network," it processes the sensor inputs (`u`).
2. **Eval Network**: Analogous to the "trunk network," it processes the evaluation locations (`y`).
3. **Cat Network**: Combines the outputs from the Input and Eval Networks to produce the final output.
The architecture offers three potential advantages:
1. It allows the operator to integrate evaluation locations earlier, enabling a higher level of adaptive
abstraction.
2. The hyperparameters can be thought of as a control mechanism, dictating the flow of information. The
`input_cat_ratio` hyperparameter provides a control mechanism for the information flow, allowing fine-tuning of
the contributions from the Input and Eval Networks.
3. It can achieve a high level of abstraction without relying on learning basis functions, evaluated in a single
operation (dot product).
1. **Branch Network**: Processes the sensor inputs (`u`).
2. **Trunk Network**: Processes the evaluation locations (`y`).
3. **Cat Network**: Combines the outputs from the Branch- and Trunk-Network to produce the final output.
This allows the operator to integrate evaluation locations earlier, while ensuring that both the sensor inputs and
the evaluation location contribute in a predictable form to the flow of information. Directly stacking both the
sensors and evaluation location can lead to an imbalance in the number of features in the neural operator. The
`branch_cat_ratio` argument dictates which fraction of the cat-network input originates from the branch network
(defaults to 50/50). The cat-network does not require the
neural operator to learn good basis functions. The information from the input space and the evaluation locations
can be taken into account early, allowing for better abstraction.
┌─────────────────────┐ ┌────────────────────┐
│ *Branch Network* │ │ *Trunk Network* │
│ Input (u) │ │ Input (y) │
│ Output (b) │ │ Output (t) │
└─────────────────┬───┘ └──┬─────────────────┘
┌─────────────────┴──────────┴─────────────────┐
│ *Concatenation* │
│ Input (b, t) │
│ Output (c) │
│ b.numel() / cat_net_width = branch_cat_ratio │
└────────────────────┬─────────────────────────┘
┌────────┴─────────┐
│ *Cat Network* │
│ Input (c) │
│ Output (v) │
└──────────────────┘
Args:
shapes: Operator shapes.
input_net_width: Width of the input net (deep residual network). Defaults to 32.
input_net_depth: Depth of the input net (deep residual network). Defaults to 4.
eval_net_width: Width of the eval net (deep residual network). Defaults to 32.
eval_net_depth: Depth of the eval net (deep residual network). Defaults to 4.
input_cat_ratio: Ratio indicating how many values of the concatenated tensor originates from the input net.
Controls flow of information into input- and eval-net. Defaults to 0.5.
branch_width: Width of the branch net (deep residual network). Defaults to 32.
branch_depth: Depth of the branch net (deep residual network). Defaults to 4.
trunk_width: Width of the trunk net (deep residual network). Defaults to 32.
trunk_depth: Depth of the trunk net (deep residual network). Defaults to 4.
branch_cat_ratio: Ratio indicating which fraction of the concatenated tensor originates from the branch net.
Controls flow of information into branch- and trunk-net. Defaults to 0.5.
cat_net_width: Width of the cat net (deep residual network). Defaults to 32.
cat_net_depth: Depth of the cat net (deep residual network). Defaults to 4.
act: Activation function. Defaults to Tanh.
device: Device.
"""

def __init__(
self,
shapes: OperatorShapes,
input_net_width: int = 32,
input_net_depth: int = 4,
eval_net_width: int = 32,
eval_net_depth: int = 4,
input_cat_ratio: float = 0.5,
branch_width: int = 32,
branch_depth: int = 4,
trunk_width: int = 32,
trunk_depth: int = 4,
branch_cat_ratio: float = 0.5,
cat_net_width: int = 32,
cat_net_depth: int = 4,
act: Optional[nn.Module] = None,
Expand All @@ -63,29 +79,29 @@ def __init__(
act = nn.Tanh()

assert (
1.0 > input_cat_ratio > 0.0
), f"Ratio has to be in [0, 1], but found {input_cat_ratio}"
input_out_width = ceil(cat_net_width * input_cat_ratio)
0.0 < branch_cat_ratio < 1.0
), f"Ratio has to be in (0, 1), but found {branch_cat_ratio}"
branch_out_width = ceil(cat_net_width * branch_cat_ratio)
assert (
input_out_width != cat_net_width
), f"Input cat ratio {input_cat_ratio} results in eval net width equal zero."
branch_out_width != cat_net_width
), f"Branch cat ratio {branch_cat_ratio} results in a trunk net width of zero."

input_in_width = prod(shapes.u.size) * shapes.u.dim
self.input_net = DeepResidualNetwork(
self.branch_net = DeepResidualNetwork(
input_size=input_in_width,
output_size=input_out_width,
width=input_net_width,
depth=input_net_depth,
output_size=branch_out_width,
width=branch_width,
depth=branch_depth,
act=act,
device=device,
)

eval_out_width = cat_net_width - input_out_width
self.eval_net = DeepResidualNetwork(
eval_out_width = cat_net_width - branch_out_width
self.trunk_net = DeepResidualNetwork(
input_size=shapes.y.dim,
output_size=eval_out_width,
width=eval_net_width,
depth=eval_net_depth,
width=trunk_width,
depth=trunk_depth,
act=act,
device=device,
)
Expand All @@ -108,18 +124,18 @@ def forward(
Args:
_: Tensor containing sensor locations. Ignored.
u: Tensor containing values of sensors. Of shape (batch_size, u_dim, num_sensors...).
y: Tensor containing evaluation locations. Of shape (batch_size, y_dim, num_evaluations...).
u: Tensor containing values of sensors of shape (batch_size, u_dim, num_sensors...).
y: Tensor containing evaluation locations of shape (batch_size, y_dim, num_evaluations...).
Returns:
Tensor of predicted evaluation values. Of shape (batch_size, v_dim, num_evaluations...).
Tensor of predicted evaluation values of shape (batch_size, v_dim, num_evaluations...).
"""
ipt = torch.flatten(u, start_dim=1)
ipt = self.input_net(ipt)
ipt = self.branch_net(ipt)

y_num = y.shape[2:]
eval = y.flatten(start_dim=2).transpose(1, -1)
eval = self.eval_net(eval)
eval = self.trunk_net(eval)

ipt = ipt.unsqueeze(1).expand(-1, eval.size(1), -1)
cat = torch.cat([ipt, eval], dim=-1)
Expand Down
4 changes: 2 additions & 2 deletions tests/operators/test_deep_cat_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def test_can_initialize(self, random_operator_dataset):
def test_can_initialize_default_networks(self, random_operator_dataset):
operator = DeepCatOperator(shapes=random_operator_dataset.shapes)

assert isinstance(operator.input_net, DeepResidualNetwork)
assert isinstance(operator.eval_net, DeepResidualNetwork)
assert isinstance(operator.branch_net, DeepResidualNetwork)
assert isinstance(operator.trunk_net, DeepResidualNetwork)
assert isinstance(operator.cat_net, DeepResidualNetwork)

def test_forward_shapes_correct(self, dcos, random_shape_operator_datasets):
Expand Down

0 comments on commit 228802b

Please sign in to comment.