Skip to content

Commit

Permalink
Update ExtensionJob.py
Browse files: browse the repository at this point in the history
  • Loading branch information
NBSTpeterhill authored Nov 20, 2024
1 parent f19fc15 commit 11cad49
Showing 1 changed file with 5 additions and 17 deletions.
22 changes: 5 additions & 17 deletions jobs/ExtensionJob.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,33 +21,21 @@ def __init__(self, config: OrderedDict):
# 多 GPU 情况
self.distributed = True
self.local_rank = int(os.getenv('LOCAL_RANK', 0))
self.device = torch.device(f'cuda:{self.local_rank}')
if self.local_rank < len(devices):
self.device = torch.device(f'cuda:{devices[self.local_rank].strip()}') # 确保设备字符串正确
else:
raise RuntimeError(f"Invalid LOCAL_RANK {self.local_rank}, exceeding available devices {devices}")
self.world_size = len(devices)
dist.init_process_group(backend='nccl', rank=self.local_rank, world_size=self.world_size)
else:
# 单 GPU 情况
self.distributed = False
self.device = torch.device(self.device_config)
self.device = torch.device(self.device_config.strip()) # 单 GPU 情况,去除空格
else:
# 默认 CPU 情况
self.distributed = False
self.device = torch.device('cpu')

# 从配置文件加载模型路径或配置参数
model_path = self.get_conf('model_path', None) # 假设配置文件中定义了 model_path
if model_path:
# 加载预训练模型
self.model = torch.load(model_path, map_location=self.device)
else:
raise ValueError("模型路径未在配置文件中指定,请检查 'model_path' 配置项。")

# 如果是分布式训练模式,使用 DDP 包装模型
if self.distributed:
self.model = DDP(self.model, device_ids=[self.local_rank], output_device=self.local_rank)

# 将模型移动到相应设备
self.model = self.model.to(self.device)

# 加载扩展的进程
self.process_dict = get_all_extensions_process_dict()
self.load_processes(self.process_dict)
Expand Down

0 comments on commit 11cad49

Please sign in to comment.