Skip to content

Commit

Permalink
Also included the device for stream operations, for good measure
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Dec 22, 2023
1 parent cd1a696 commit 45f0749
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ def __process_mini_batch(self, network, input_data, target_data):
if self.parameters._configuration["gpu"]:
if self.parameters.use_graphs and self.train_graph is None:
printout("Capturing CUDA graph for training.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
Expand All @@ -651,7 +651,7 @@ def __process_mini_batch(self, network, input_data, target_data):
self.gradscaler.scale(loss).backward()
else:
loss.backward()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_data = torch.empty_like(input_data)
Expand Down Expand Up @@ -754,15 +754,15 @@ def __validate_network(self, network, data_set_type, validation_type):

if self.parameters.use_graphs and self.validation_graph is None:
printout("Capturing CUDA graph for validation.", min_verbosity=2)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
s = torch.cuda.Stream(self.parameters._configuration["device"])
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
# Warmup for graphs
with torch.cuda.stream(s):
for _ in range(20):
with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
prediction = network(x)
loss = network.calculate_loss(prediction, y)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

# Create static entry point tensors to graph
self.static_input_validation = torch.empty_like(x)
Expand Down

0 comments on commit 45f0749

Please sign in to comment.