diff --git a/_modules/mala/network/trainer.html b/_modules/mala/network/trainer.html index 9f8a0ad6c..f07fefab3 100644 --- a/_modules/mala/network/trainer.html +++ b/_modules/mala/network/trainer.html @@ -355,7 +355,7 @@
self.data.training_data_sets[0].shuffle()
if self.parameters._configuration["gpu"]:
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
t0 = time.time()
batchid = 0
@@ -385,7 +385,7 @@ Source code for mala.network.trainer
training_loss_sum += loss
if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / self.parameters.training_report_frequency
avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
@@ -395,14 +395,14 @@ Source code for mala.network.trainer
min_verbosity=2)
tsample = time.time()
batchid += 1
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
t1 = time.time()
printout(f"training time: {t1 - t0}", min_verbosity=2)
training_loss = training_loss_sum.item() / batchid
# Calculate the validation loss. and output it.
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in self.training_data_loaders:
@@ -451,14 +451,14 @@ Source code for mala.network.trainer
self.tensor_board.close()
if self.parameters._configuration["gpu"]:
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
# Mix the DataSets up (this function only does something
# in the lazy loading case).
if self.parameters.use_shuffling_for_samplers:
self.data.mix_datasets()
if self.parameters._configuration["gpu"]:
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
# If a scheduler is used, update it.
if self.scheduler is not None:
@@ -712,8 +712,8 @@ Source code for mala.network.trainer
if self.parameters._configuration["gpu"]:
if self.parameters.use_graphs and self.train_graph is None:
printout("Capturing CUDA graph for training.", min_verbosity=2)
- s = torch.cuda.Stream()
- s.wait_stream(torch.cuda.current_stream())
+ s = torch.cuda.Stream(self.parameters._configuration["device"])
+ s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
@@ -727,7 +727,7 @@ Source code for mala.network.trainer
self.gradscaler.scale(loss).backward()
else:
loss.backward()
- torch.cuda.current_stream().wait_stream(s)
+ torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
# Create static entry point tensors to graph
self.static_input_data = torch.empty_like(input_data)
@@ -818,7 +818,7 @@ Source code for mala.network.trainer
with torch.no_grad():
if self.parameters._configuration["gpu"]:
report_freq = self.parameters.training_report_frequency
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
tsample = time.time()
batchid = 0
for loader in data_loaders:
@@ -830,15 +830,15 @@ Source code for mala.network.trainer
if self.parameters.use_graphs and self.validation_graph is None:
printout("Capturing CUDA graph for validation.", min_verbosity=2)
- s = torch.cuda.Stream()
- s.wait_stream(torch.cuda.current_stream())
+ s = torch.cuda.Stream(self.parameters._configuration["device"])
+ s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
prediction = network(x)
loss = network.calculate_loss(prediction, y)
- torch.cuda.current_stream().wait_stream(s)
+ torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
# Create static entry point tensors to graph
self.static_input_validation = torch.empty_like(x)
@@ -862,7 +862,7 @@ Source code for mala.network.trainer
loss = network.calculate_loss(prediction, y)
validation_loss_sum += loss
if batchid != 0 and (batchid + 1) % report_freq == 0:
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
sample_time = time.time() - tsample
avg_sample_time = sample_time / report_freq
avg_sample_tput = report_freq * x.shape[0] / sample_time
@@ -872,7 +872,7 @@ Source code for mala.network.trainer
min_verbosity=2)
tsample = time.time()
batchid += 1
- torch.cuda.synchronize()
+ torch.cuda.synchronize(self.parameters._configuration["device"])
else:
batchid = 0
for loader in data_loaders:
diff --git a/objects.inv b/objects.inv
index 1245cdc0e..0f7efe040 100644
Binary files a/objects.inv and b/objects.inv differ