Skip to content

Commit

Permalink
refactor: ➖ Remove funlib dependency.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 7, 2024
1 parent 561eb5c commit e92fd3c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 20 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
name = LeibNetz
version = 0.2.0
author = Jeff Rhoades, Larissa Heinrich
author_email = [email protected], [email protected]
author_email = [email protected]
url = https://github.com/janelia-cellmap/LeibNetz
description = A lightweight and modular library for rapidly developing and constructing PyTorch models deep learning.
description = A lightweight and modular library for rapidly developing and constructing PyTorch models for deep learning.
long_description = file: README.md
long_description_content_type = text/markdown
keywords = image-segmentation, convolutional-neural-networks, deep-learning, pytorch
Expand Down
20 changes: 9 additions & 11 deletions src/leibnetz/leibnet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, Sequence, Tuple
import networkx as nx
from torch import device
import torch
from torch.nn import Module
import numpy as np
from leibnetz.nodes import Node
from funlib.learn.torch.models.conv4d import Conv4d

# from funlib.learn.torch.models.conv4d import Conv4d

# from model_opt.apis import optimize

Expand Down Expand Up @@ -43,29 +44,26 @@ def __init__(
self.apply(
lambda m: (
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out")
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d)
# or isinstance(m, Conv4d)
else None
)
)
elif initialization == "xavier":
self.apply(
lambda m: (
torch.nn.init.xavier_normal_(m.weight)
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d)
# or isinstance(m, Conv4d)
else None
)
)
elif initialization == "orthogonal":
self.apply(
lambda m: (
torch.nn.init.orthogonal_(m.weight)
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d)
# or isinstance(m, Conv4d)
else None
)
)
Expand Down
10 changes: 7 additions & 3 deletions src/leibnetz/nodes/node_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch import nn
import numpy as np
from funlib.learn.torch.models.conv4d import Conv4d

# from funlib.learn.torch.models.conv4d import Conv4d


class ConvPass(nn.Module):
Expand Down Expand Up @@ -95,10 +96,13 @@ def __init__(
self.dims = len(kernel_size)

try:
conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
# TODO: Implement Conv4d
# conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
conv = {2: nn.Conv2d, 3: nn.Conv3d}[self.dims]
except KeyError:
raise ValueError(
f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D"
# f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D"
f"Only 2D and 3D convolutions are supported, not {self.dims}D"
)

layers.append(
Expand Down
15 changes: 11 additions & 4 deletions src/leibnetz/nodes/resample_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import math
from torch import nn
import torch
from funlib.learn.torch.models.conv4d import Conv4d

# from funlib.learn.torch.models.conv4d import Conv4d

from logging import getLogger

Expand Down Expand Up @@ -52,7 +51,15 @@ def __init__(

self.dims = len(kernel_sizes)

conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
try:
# TODO: Implement Conv4d
# conv = {2: nn.Conv2d, 3: nn.Conv3d, 4: Conv4d}[self.dims]
conv = {2: nn.Conv2d, 3: nn.Conv3d}[self.dims]
except KeyError:
raise ValueError(
# f"Only 2D, 3D and 4D convolutions are supported, not {self.dims}D"
f"Only 2D and 3D convolutions are supported, not {self.dims}D"
)

try:
layers.append(
Expand Down

0 comments on commit e92fd3c

Please sign in to comment.