From ecf73d4deba5f5161552bdab186b9d5d815fb882 Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Mon, 19 Aug 2024 11:10:46 +0200 Subject: [PATCH] Minor refactoring. --- src/continuiti/operators/__init__.py | 2 +- .../{deep_cat_operator.py => dco.py} | 53 ++++++++++--------- ...{test_deep_cat_operator.py => test_dco.py} | 0 3 files changed, 30 insertions(+), 25 deletions(-) rename src/continuiti/operators/{deep_cat_operator.py => dco.py} (75%) rename tests/operators/{test_deep_cat_operator.py => test_dco.py} (100%) diff --git a/src/continuiti/operators/__init__.py b/src/continuiti/operators/__init__.py index 945e75d9..8d74b291 100644 --- a/src/continuiti/operators/__init__.py +++ b/src/continuiti/operators/__init__.py @@ -19,7 +19,7 @@ from .fno import FourierNeuralOperator from .shape import OperatorShapes from .cnn import ConvolutionalNeuralNetwork -from .deep_cat_operator import DeepCatOperator +from .dco import DeepCatOperator __all__ = [ "Operator", diff --git a/src/continuiti/operators/deep_cat_operator.py b/src/continuiti/operators/dco.py similarity index 75% rename from src/continuiti/operators/deep_cat_operator.py rename to src/continuiti/operators/dco.py index 5f3b2855..92868ccb 100644 --- a/src/continuiti/operators/deep_cat_operator.py +++ b/src/continuiti/operators/dco.py @@ -1,7 +1,7 @@ """ -`continuiti.operators.deep_cat_operator` +`continuiti.operators.dco` -The DeepCatOperator architecture. +The DeepCatOperator (DCO) architecture. """ import torch @@ -15,35 +15,40 @@ class DeepCatOperator(Operator): """Deep Cat Operator. - This class implements the DeepCatOperator, a neural operator inspired by the DeepONet. It consists of three main - parts: + This class implements the DeepCatOperator, a neural operator inspired by the DeepONet. + + It consists of three main parts: + 1. **Branch Network**: Processes the sensor inputs (`u`). 2. **Trunk Network**: Processes the evaluation locations (`y`). 3. **Cat Network**: Combines the outputs from the Branch- and Trunk-Network to produce the final output. + The architecture has the following structure: + + ┌─────────────────────┐ ┌────────────────────┐ + │ *Branch Network* │ │ *Trunk Network* │ + │ Input (u) │ │ Input (y) │ + │ Output (b) │ │ Output (t) │ + └─────────────────┬───┘ └──┬─────────────────┘ + ┌ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ┐ + │ *Concatenation* │ + │ Input (b, t) │ + │ Output (c) │ + │ branch_cat_ratio = b.numel() / cat_net_width │ + └ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ + ┌─────────┴────────┐ + │ *Cat Network* │ + │ Input (c) │ + │ Output (v) │ + └──────────────────┘ + + This allows the operator to integrate evaluation locations earlier, while ensuring that both the sensor inputs and the evaluation location contribute in a predictable form to the flow of information. Directly stacking both the sensors and evaluation location can lead to an imbalance in the number of features in the neural operator. The arg `branch_cat_ratio` dictates how this fraction is set (defaults to 50/50). The cat-network does not require the - neural operator to learn good basis functions. The information from the input space and the evaluation locations - can be taken into account early, allowing for better abstraction. - - ┌─────────────────────┐ ┌────────────────────┐ - │ *Branch Network* │ │ *Trunk Network* │ - │ Input (u) │ │ Input (y) │ - │ Output (b) │ │ Output (t) │ - └─────────────────┬───┘ └──┬─────────────────┘ - ┌─────────────────┴──────────┴─────────────────┐ - │ *Concatenation* │ - │ Input (b, t) │ - │ Output (c) │ - │ b.numel() / cat_net_width = branch_cat_ratio │ - └────────────────────┬─────────────────────────┘ - ┌────────┴─────────┐ - │ *Cat Network* │ - │ Input (c) │ - │ Output (v) │ - └──────────────────┘ + neural operator to learn good basis functions with the trunk network only. The information from the input space and + the evaluation locations can be taken into account early, allowing for better abstraction. Args: shapes: Operator shapes. @@ -80,7 +85,7 @@ def __init__( assert ( 0.0 < branch_cat_ratio < 1.0 - ), f"Ratio has to be in [0, 1], but found {branch_cat_ratio}" + ), f"Ratio has to be in (0, 1), but found {branch_cat_ratio}" branch_out_width = ceil(cat_net_width * branch_cat_ratio) assert ( branch_out_width != cat_net_width diff --git a/tests/operators/test_deep_cat_operator.py b/tests/operators/test_dco.py similarity index 100% rename from tests/operators/test_deep_cat_operator.py rename to tests/operators/test_dco.py