Skip to content

Commit

Permalink
Update BaseSDTrainProcess.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NBSTpeterhill authored Nov 20, 2024
1 parent 11cad49 commit e19625d
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No
# if true, then we do not do an optimizer step. We are accumulating gradients
self.is_grad_accumulation_step = False
# Resolve the device for this process from the 'device' config entry,
# falling back to the job's device when the entry is absent.
# NOTE(review): every branch below either reassigns self.device or raises,
# so this first assignment is always overwritten.
self.device = self.get_conf('device', self.job.device)
device_config = self.get_conf('device', self.job.device)
# NOTE(review): any value that is not a str starting with 'cuda' (for example
# 'mps', or a non-string device object) falls through to the final else
# branch and is replaced with 'cpu' -- confirm this is intended.
if isinstance(device_config, str) and device_config.startswith('cuda'):
# A comma-separated value is treated as a list of devices.
# NOTE(review): entries are used exactly as split (no whitespace stripping),
# and only the whole string is checked for the 'cuda' prefix, so a value
# like 'cuda:0,1' would yield '1' for the second entry.
devices = device_config.split(',')
if len(devices) > 1:
# Multi-GPU case: select the GPU that belongs to the current process.
# LOCAL_RANK is read from the environment and defaults to 0 when unset
# (presumably exported by a distributed launcher -- TODO confirm).
# NOTE(review): relies on `os` being imported at file level, which is not
# visible in this hunk; int() raises ValueError on a non-numeric value.
local_rank = int(os.getenv('LOCAL_RANK', 0))
# NOTE(review): only the upper bound is checked; a negative LOCAL_RANK
# would index from the end of the list.
if local_rank < len(devices):
self.device = devices[local_rank]
else:
raise RuntimeError(f"Invalid LOCAL_RANK {local_rank}, exceeding available devices {devices}")
else:
# Single-GPU case: use the configured device string as-is.
self.device = device_config
else:
# Fall back to CPU.
self.device = 'cpu'
# Build the torch.device object from the resolved device string.
self.device_torch = torch.device(self.device)
network_config = self.get_conf('network', None)
if network_config is not None:
Expand Down

0 comments on commit e19625d

Please sign in to comment.