feat: ✨ Add activation specification for UNets/ScaleNets
rhoadesScholar committed May 2, 2024
1 parent 1449ca3 commit 1f22595
Showing 4 changed files with 17 additions and 2 deletions.
src/leibnetz/leibnet.py (2 additions & 2 deletions)
@@ -19,7 +19,7 @@ def __init__(
         self,
         nodes: Iterable,
         outputs: dict[str, Sequence[Tuple]],
-        retain_buffer=True,
+        # retain_buffer=True,
         initialization="kaiming",
         name="LeibNet",
     ):
@@ -73,7 +73,7 @@ def __init__(
             pass
         else:
             raise ValueError(f"Unknown initialization {initialization}")
-        self.retain_buffer = retain_buffer
+        # self.retain_buffer = retain_buffer
         self.retain_buffer = True
         if torch.cuda.is_available():
             self.cuda()
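With retain_buffer commented out of the signature and hardcoded inside __init__, callers can no longer pass it. A minimal construction sketch after this change; the import path is inferred from src/leibnetz/leibnet.py, and my_nodes / my_outputs are hypothetical placeholders, since building real node objects is outside this diff:

    from leibnetz.leibnet import LeibNet  # import path inferred from the file path

    # my_nodes / my_outputs are hypothetical placeholders for real node
    # objects and output shape specs, which this diff does not show.
    net = LeibNet(
        nodes=my_nodes,
        outputs=my_outputs,
        initialization="kaiming",  # unrecognized strings raise ValueError
        name="LeibNet",
    )
    assert net.retain_buffer is True  # hardcoded to True by this commit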
src/leibnetz/nets/attentive_scalenet.py (5 additions & 0 deletions)
@@ -22,6 +22,7 @@ def build_subnet(
     norm_layer=None,
     residual=False,
     dropout_prob=None,
+    activation="ReLU",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -42,6 +43,7 @@
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             ),
         )
         c += 1
@@ -76,6 +78,7 @@ def build_subnet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -121,6 +124,7 @@ def build_subnet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -138,6 +142,7 @@
             norm_layer=norm_layer,
             residual=residual,
             dropout_prob=dropout_prob,
+            activation=activation,
         )
     )
     outputs = {
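Each call site above simply forwards the activation name into the node constructors. The diff does not show how the nodes consume it; a plausible sketch (an assumption, not the repository's actual code) is to resolve the string against torch.nn:

    import torch.nn as nn

    def resolve_activation(activation="ReLU"):
        # Resolve a class name like "ReLU" or "GELU" to an instantiated
        # torch.nn module; pass an already-built module through unchanged.
        if isinstance(activation, str):
            return getattr(nn, activation)()
        return activation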
src/leibnetz/nets/scalenet.py (5 additions & 0 deletions)
@@ -18,6 +18,7 @@ def build_subnet(
     norm_layer=None,
     residual=False,
     dropout_prob=None,
+    activation="ReLU",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -37,6 +38,7 @@
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             ),
         )
         c += 1
@@ -71,6 +73,7 @@ def build_subnet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -101,6 +104,7 @@ def build_subnet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -118,6 +122,7 @@
             norm_layer=norm_layer,
             residual=residual,
             dropout_prob=dropout_prob,
+            activation=activation,
         )
     )
     outputs = {
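Because the new keyword defaults to "ReLU", existing callers of build_subnet are unaffected; only callers wanting a different nonlinearity pass it explicitly. A hypothetical call, where subnet_kwargs stands in for the required arguments this diff does not show:

    from leibnetz.nets.scalenet import build_subnet  # module path per this diff

    subnet = build_subnet(
        **subnet_kwargs,         # placeholder for required args not visible here
        norm_layer=None,
        residual=False,
        dropout_prob=None,
        activation="LeakyReLU",  # new in this commit; defaults to "ReLU"
    )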
src/leibnetz/nets/unet.py (5 additions & 0 deletions)
@@ -15,6 +15,7 @@ def build_unet(
     norm_layer=None,
     residual=False,
     dropout_prob=None,
+    activation="ReLU",
 ):
     # define downsample nodes
     downsample_factors = np.array(downsample_factors)
@@ -34,6 +35,7 @@
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             ),
         )
         c += 1
@@ -62,6 +64,7 @@ def build_unet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -92,6 +95,7 @@ def build_unet(
                 norm_layer=norm_layer,
                 residual=residual,
                 dropout_prob=dropout_prob,
+                activation=activation,
             )
         )
         input_key = output_key
@@ -109,6 +113,7 @@
             norm_layer=norm_layer,  # TODO: remove?
             residual=residual,
             dropout_prob=dropout_prob,
+            activation=activation,
         )
     )
 
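build_unet gains the same keyword, so selecting a different nonlinearity for a UNet is a one-argument change. A sketch, with unet_kwargs as a placeholder for the arguments this diff elides (e.g. downsample_factors, which the function body clearly requires):

    from leibnetz.nets.unet import build_unet

    unet = build_unet(
        **unet_kwargs,      # placeholder for required args, e.g. downsample_factors
        activation="GELU",  # forwarded to every node built above
    )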
