Skip to content

Commit

Permalink
Add networks submodule.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Apr 23, 2024
1 parent 186e069 commit 4b84ff8
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 43 deletions.
10 changes: 10 additions & 0 deletions src/continuiti/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
`continuiti.networks`
Networks in continuiti.
"""

from .fully_connected import FullyConnected
from .deep_residual_network import DeepResidualNetwork

__all__ = ["FullyConnected", "DeepResidualNetwork"]
Original file line number Diff line number Diff line change
@@ -1,47 +1,13 @@
"""
`continuiti.operators.common`
`continuiti.networks.deep_residual_network`
Common functionality for operators in continuiti.
Deep residual network in continuiti.
"""

import torch
from typing import Optional


class FullyConnected(torch.nn.Module):
    """Fully connected network.

    A single-hidden-layer network: linear -> activation -> layer norm -> linear.

    Args:
        input_size: Input dimension.
        output_size: Output dimension.
        width: Width of the hidden layer.
        act: Activation function.
        device: Device.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        width: int,
        act: Optional[torch.nn.Module] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # One hidden layer of size `width` between input and output.
        self.inner_layer = torch.nn.Linear(input_size, width, device=device)
        self.outer_layer = torch.nn.Linear(width, output_size, device=device)
        # Default to GELU when no activation is supplied.
        self.act = act or torch.nn.GELU()
        # Normalize the hidden representation before the output projection.
        self.norm = torch.nn.LayerNorm(width, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.inner_layer(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.outer_layer(x)
        return x


class ResidualLayer(torch.nn.Module):
"""Residual layer.
Expand Down
42 changes: 42 additions & 0 deletions src/continuiti/networks/fully_connected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
`continuiti.networks.fully_connected`
Fully connected neural network in continuiti.
"""

import torch
from typing import Optional


class FullyConnected(torch.nn.Module):
    """Fully connected network.

    Maps the input through a single hidden layer of size `width`
    (linear -> activation -> layer normalization) and projects the
    result to the output dimension.

    Args:
        input_size: Input dimension.
        output_size: Output dimension.
        width: Width of the hidden layer.
        act: Activation function. Defaults to GELU when not given.
        device: Device.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        width: int,
        act: Optional[torch.nn.Module] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # Submodules are created in the same order as before so that
        # parameter initialization consumes the RNG identically.
        self.inner_layer = torch.nn.Linear(input_size, width, device=device)
        self.outer_layer = torch.nn.Linear(width, output_size, device=device)
        if act is None:
            act = torch.nn.GELU()
        self.act = act
        self.norm = torch.nn.LayerNorm(width, device=device)

    def forward(self, x: torch.Tensor):
        """Forward pass."""
        hidden = self.norm(self.act(self.inner_layer(x)))
        return self.outer_layer(hidden)
2 changes: 1 addition & 1 deletion src/continuiti/operators/belnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from typing import Optional
from continuiti.operators import Operator
from continuiti.networks import DeepResidualNetwork
from continuiti.operators.shape import OperatorShapes


Expand Down
2 changes: 1 addition & 1 deletion src/continuiti/operators/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from typing import Optional
from continuiti.operators import Operator
from continuiti.networks import DeepResidualNetwork
from continuiti.operators.shape import OperatorShapes


Expand Down
2 changes: 1 addition & 1 deletion src/continuiti/operators/dno.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from typing import Optional
from continuiti.operators import Operator
from continuiti.networks import DeepResidualNetwork
from continuiti.operators.shape import OperatorShapes


Expand Down
2 changes: 1 addition & 1 deletion src/continuiti/operators/integralkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from continuiti.operators import Operator
from continuiti.networks import DeepResidualNetwork
from continuiti.operators.shape import OperatorShapes


Expand Down
Empty file added tests/networks/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions tests/networks/test_deep_residual_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch.nn as nn
import torch
import pytest
from continuiti.networks import DeepResidualNetwork


@pytest.fixture(scope="session")
def trivial_deep_residual_network():
return DeepResidualNetwork(input_size=3, output_size=5, width=15, depth=3)


@pytest.fixture(scope="session")
def random_vector():
return torch.rand(
3,
)


class TestDeepResidualNetwork:
    """Smoke and convergence tests for DeepResidualNetwork."""

    def test_can_initialize(self, trivial_deep_residual_network):
        assert isinstance(trivial_deep_residual_network, DeepResidualNetwork)

    def test_can_forward(self, trivial_deep_residual_network, random_vector):
        # The forward pass must complete without raising.
        trivial_deep_residual_network(random_vector)
        assert True

    def test_shape_correct(self, trivial_deep_residual_network, random_vector):
        output = trivial_deep_residual_network(random_vector)
        assert output.shape == torch.Size([5])

    def test_can_backward(self, trivial_deep_residual_network, random_vector):
        output = trivial_deep_residual_network(random_vector)
        target = torch.rand(5)
        loss = nn.L1Loss()(output, target)
        loss.backward()
        assert True

    def test_can_overfit(self, trivial_deep_residual_network, random_vector):
        target = torch.rand(5)
        criterion = nn.L1Loss()
        optim = torch.optim.Adam(trivial_deep_residual_network.parameters(), lr=1e-4)

        loss = torch.inf
        for _ in range(1000):
            optim.zero_grad()
            output = trivial_deep_residual_network(random_vector)
            loss = criterion(output, target)
            loss.backward()
            # Stop before the next step once the fit is good enough.
            if loss.item() <= 1e-3:
                break
            optim.step()

        assert loss.item() <= 1e-3
59 changes: 59 additions & 0 deletions tests/networks/test_fully_connected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch.nn as nn
import torch
import pytest
from continuiti.networks import FullyConnected


@pytest.fixture(scope="session")
def trivial_fully_connected():
return FullyConnected(input_size=3, output_size=5, width=7)


@pytest.fixture(scope="session")
def random_vector():
return torch.rand(
3,
)


class TestFullyConnected:
    """Smoke and convergence tests for FullyConnected."""

    def test_can_initialize(self, trivial_fully_connected):
        assert isinstance(trivial_fully_connected, FullyConnected)

    def test_can_forward(self, trivial_fully_connected, random_vector):
        # The forward pass must complete without raising.
        trivial_fully_connected(random_vector)
        assert True

    def test_shape_correct(self, trivial_fully_connected, random_vector):
        output = trivial_fully_connected(random_vector)
        assert output.shape == torch.Size([5])

    def test_can_backward(self, trivial_fully_connected, random_vector):
        output = trivial_fully_connected(random_vector)
        target = torch.rand(5)
        loss = nn.L1Loss()(output, target)
        loss.backward()
        assert True

    def test_can_overfit(self, trivial_fully_connected, random_vector):
        target = torch.rand(5)
        criterion = nn.L1Loss()
        optim = torch.optim.Adam(trivial_fully_connected.parameters())

        loss = torch.inf
        for _ in range(1000):
            optim.zero_grad()
            output = trivial_fully_connected(random_vector)
            loss = criterion(output, target)
            loss.backward()
            # Stop before the next step once the fit is good enough.
            if loss.item() <= 1e-3:
                break
            optim.step()

        assert loss.item() <= 1e-3
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest
import torch
from continuiti.operators import DeepONet
from continuiti.networks import DeepResidualNetwork
from continuiti.benchmarks.sine import SineBenchmark
from continuiti.trainer import Trainer


def train():
    """Train a DeepONet on the sine benchmark until tolerance is reached.

    The scraped diff fused the removed and added lines, leaving duplicate
    `dataset`/`operator` assignments; only the post-commit values
    (n_train=8, trunk_depth=8) are kept, which also keeps this test fast.
    """
    dataset = SineBenchmark(n_train=8).train_dataset
    operator = DeepONet(dataset.shapes, trunk_depth=8)

    Trainer(operator).fit(dataset, tol=1e-2)

Expand Down

0 comments on commit 4b84ff8

Please sign in to comment.