Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed distributed Optuna running on multiple GPUs #495

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def train_network(self):
self.data.training_data_sets[0].shuffle()

if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
t0 = time.time()
batchid = 0
Expand Down Expand Up @@ -309,7 +309,7 @@ def train_network(self):
training_loss_sum += loss

if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / self.parameters.training_report_frequency
avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
Expand All @@ -319,14 +319,14 @@ def train_network(self):
min_verbosity=2)
tsample = time.time()
batchid += 1
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
t1 = time.time()
printout(f"training time: {t1 - t0}", min_verbosity=2)

training_loss = training_loss_sum.item() / batchid

# Calculate the validation loss. and output it.
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in self.training_data_loaders:
Expand Down Expand Up @@ -375,14 +375,14 @@ def train_network(self):
self.tensor_board.close()

if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])

# Mix the DataSets up (this function only does something
# in the lazy loading case).
if self.parameters.use_shuffling_for_samplers:
self.data.mix_datasets()
if self.parameters._configuration["gpu"]:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])

# If a scheduler is used, update it.
if self.scheduler is not None:
Expand Down Expand Up @@ -636,8 +636,8 @@ def __process_mini_batch(self, network, input_data, target_data):
if self.parameters._configuration["gpu"]:
if self.parameters.use_graphs and self.train_graph is None:
printout("Capturing CUDA graph for training.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
Expand All @@ -651,7 +651,7 @@ def __process_mini_batch(self, network, input_data, target_data):
self.gradscaler.scale(loss).backward()
else:
loss.backward()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_data = torch.empty_like(input_data)
Expand Down Expand Up @@ -742,7 +742,7 @@ def __validate_network(self, network, data_set_type, validation_type):
with torch.no_grad():
if self.parameters._configuration["gpu"]:
report_freq = self.parameters.training_report_frequency
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
batchid = 0
for loader in data_loaders:
Expand All @@ -754,15 +754,15 @@ def __validate_network(self, network, data_set_type, validation_type):

if self.parameters.use_graphs and self.validation_graph is None:
printout("Capturing CUDA graph for validation.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
prediction = network(x)
loss = network.calculate_loss(prediction, y)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_validation = torch.empty_like(x)
Expand All @@ -786,7 +786,7 @@ def __validate_network(self, network, data_set_type, validation_type):
loss = network.calculate_loss(prediction, y)
validation_loss_sum += loss
if batchid != 0 and (batchid + 1) % report_freq == 0:
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / report_freq
avg_sample_tput = report_freq * x.shape[0] / sample_time
Expand All @@ -796,7 +796,7 @@ def __validate_network(self, network, data_set_type, validation_type):
min_verbosity=2)
tsample = time.time()
batchid += 1
torch.cuda.synchronize()
torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in data_loaders:
Expand Down