Skip to content

Commit

Permalink
Minor refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Aug 19, 2024
1 parent 228802b commit ecf73d4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/continuiti/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .fno import FourierNeuralOperator
from .shape import OperatorShapes
from .cnn import ConvolutionalNeuralNetwork
from .deep_cat_operator import DeepCatOperator
from .dco import DeepCatOperator

__all__ = [
"Operator",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
`continuiti.operators.deep_cat_operator`
`continuiti.operators.dco`
The DeepCatOperator architecture.
The DeepCatOperator (DCO) architecture.
"""

import torch
Expand All @@ -15,35 +15,40 @@
class DeepCatOperator(Operator):
"""Deep Cat Operator.
This class implements the DeepCatOperator, a neural operator inspired by the DeepONet. It consists of three main
parts:
This class implements the DeepCatOperator, a neural operator inspired by the DeepONet.
It consists of three main parts:
1. **Branch Network**: Processes the sensor inputs (`u`).
2. **Trunk Network**: Processes the evaluation locations (`y`).
3. **Cat Network**: Combines the outputs from the Branch- and Trunk-Network to produce the final output.
The architecture has the following structure:
┌─────────────────────┐ ┌────────────────────┐
│ *Branch Network* │ │ *Trunk Network* │
│ Input (u) │ │ Input (y) │
│ Output (b) │ │ Output (t) │
└─────────────────┬───┘ └──┬─────────────────┘
┌ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ┐
│ *Concatenation* │
│ Input (b, t) │
│ Output (c) │
│ branch_cat_ratio = b.numel() / cat_net_width │
└ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
┌─────────┴────────┐
│ *Cat Network* │
│ Input (c) │
│ Output (v) │
└──────────────────┘
This allows the operator to integrate evaluation locations earlier, while ensuring that both the sensor inputs and
the evaluation location contribute in a predictable form to the flow of information. Directly stacking both the
sensors and evaluation location can lead to an imbalance in the number of features in the neural operator. The
arg `branch_cat_ratio` dictates how this fraction is set (defaults to 50/50). The cat-network does not require the
neural operator to learn good basis functions. The information from the input space and the evaluation locations
can be taken into account early, allowing for better abstraction.
┌─────────────────────┐ ┌────────────────────┐
│ *Branch Network* │ │ *Trunk Network* │
│ Input (u) │ │ Input (y) │
│ Output (b) │ │ Output (t) │
└─────────────────┬───┘ └──┬─────────────────┘
┌─────────────────┴──────────┴─────────────────┐
│ *Concatenation* │
│ Input (b, t) │
│ Output (c) │
│ b.numel() / cat_net_width = branch_cat_ratio │
└────────────────────┬─────────────────────────┘
┌────────┴─────────┐
│ *Cat Network* │
│ Input (c) │
│ Output (v) │
└──────────────────┘
neural operator to learn good basis functions with the trunk network only. The information from the input space and
the evaluation locations can be taken into account early, allowing for better abstraction.
Args:
shapes: Operator shapes.
Expand Down Expand Up @@ -80,7 +85,7 @@ def __init__(

assert (
0.0 < branch_cat_ratio < 1.0
), f"Ratio has to be in [0, 1], but found {branch_cat_ratio}"
), f"Ratio has to be in (0, 1), but found {branch_cat_ratio}"
branch_out_width = ceil(cat_net_width * branch_cat_ratio)
assert (
branch_out_width != cat_net_width
Expand Down
File renamed without changes.

0 comments on commit ecf73d4

Please sign in to comment.