Skip to content

Commit

Permalink
Merge pull request #80 from ArneNx/move_to_device
Browse files Browse the repository at this point in the history
Move to device
  • Loading branch information
KonstantinWilleke authored Jun 24, 2020
2 parents 83b9181 + f575a81 commit d6be40d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ mei
nnvision
.ipynb_checkpoints/
Untitled*.ipynb
.history.*
.vscode/*
31 changes: 28 additions & 3 deletions nnfabrik/utility/nn_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# helper functions concerning the ANN architecture

import torch
from torch import nn
from torch.backends import cudnn

from mlutils.training import eval_state
import numpy as np
import random
Expand Down Expand Up @@ -46,8 +49,8 @@ def get_module_output(model, input_shape):
:return: output dimensions of the core
"""
initial_device = 'cuda' if next(iter(model.parameters())).is_cuda else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
initial_device = "cuda" if next(iter(model.parameters())).is_cuda else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
with eval_state(model):
with torch.no_grad():
input = torch.zeros(1, *input_shape[1:]).to(device)
Expand All @@ -56,12 +59,34 @@ def get_module_output(model, input_shape):
return output.shape


def set_random_seed(seed):
def set_random_seed(seed: int, deterministic: bool = True):
"""
Sets all random seeds
:param seed: (int) seed to be set
:param deterministic: (bool) activates cudnn.deterministic, which might slow down things
"""
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
if deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.cuda.manual_seed(seed)


def move_to_device(model, gpu=True, multi_gpu=True):
"""
Moves given model to GPU(s) if they are available
:param model: (torch.nn.Module) model to move
:param gpu: (bool) if True attempt to move to GPU
:param multi_gpu: (bool) if True attempt to use multi-GPU
:return: torch.nn.Module, str
"""
device = "cuda" if torch.cuda.is_available() and gpu else "cpu"
if multi_gpu and torch.cuda.device_count() > 1:
print("Using ", torch.cuda.device_count(), "GPUs")
model = nn.DataParallel(model)
model = model.to(device)
return model, device

0 comments on commit d6be40d

Please sign in to comment.