Skip to content

Commit

Permalink
Makes deterministic optional
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneNx committed Jun 24, 2020
1 parent 63496c7 commit f575a81
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions nnfabrik/utility/nn_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,20 @@ def get_module_output(model, input_shape):
return output.shape


def set_random_seed(seed):
def set_random_seed(seed: int, deterministic: bool = True):
"""
Sets all random seeds
:param seed: (int) seed to be set
:param deterministic: (bool) activates cudnn.deterministic, which might slow down things
"""
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
cudnn.benchmark = False
cudnn.deterministic = True
if deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.cuda.manual_seed(seed)


Expand All @@ -86,4 +90,3 @@ def move_to_device(model, gpu=True, multi_gpu=True):
model = nn.DataParallel(model)
model = model.to(device)
return model, device

0 comments on commit f575a81

Please sign in to comment.