diff --git a/jobs/ExtensionJob.py b/jobs/ExtensionJob.py index cad870cf..5ca4ff73 100644 --- a/jobs/ExtensionJob.py +++ b/jobs/ExtensionJob.py @@ -21,33 +21,21 @@ def __init__(self, config: OrderedDict): # 多 GPU 情况 self.distributed = True self.local_rank = int(os.getenv('LOCAL_RANK', 0)) - self.device = torch.device(f'cuda:{self.local_rank}') + if self.local_rank < len(devices): + self.device = torch.device(f'cuda:{devices[self.local_rank].strip()}') # 确保设备字符串正确 + else: + raise RuntimeError(f"Invalid LOCAL_RANK {self.local_rank}, exceeding available devices {devices}") self.world_size = len(devices) dist.init_process_group(backend='nccl', rank=self.local_rank, world_size=self.world_size) else: # 单 GPU 情况 self.distributed = False - self.device = torch.device(self.device_config) + self.device = torch.device(self.device_config.strip()) # 单 GPU 情况,去除空格 else: # 默认 CPU 情况 self.distributed = False self.device = torch.device('cpu') - # 从配置文件加载模型路径或配置参数 - model_path = self.get_conf('model_path', None) # 假设配置文件中定义了 model_path - if model_path: - # 加载预训练模型 - self.model = torch.load(model_path, map_location=self.device) - else: - raise ValueError("模型路径未在配置文件中指定,请检查 'model_path' 配置项。") - - # 如果是分布式训练模式,使用 DDP 包装模型 - if self.distributed: - self.model = DDP(self.model, device_ids=[self.local_rank], output_device=self.local_rank) - - # 将模型移动到相应设备 - self.model = self.model.to(self.device) - # 加载扩展的进程 self.process_dict = get_all_extensions_process_dict() self.load_processes(self.process_dict)