Skip to content

Commit

Permalink
Offer config parameter to initialize model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
lmanan committed Oct 12, 2023
1 parent 255fa32 commit 4b17700
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions cellulus/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def train(experiment_config):
model = model.to(device)

# initialize model weights
# for _name, layer in model.named_modules():
# if isinstance(layer, torch.nn.modules.conv._ConvNd):
# torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
if model_config.initialize:
for _name, layer in model.named_modules():
if isinstance(layer, torch.nn.modules.conv._ConvNd):
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")

# set loss
criterion = get_loss(
Expand Down

0 comments on commit 4b17700

Please sign in to comment.