English | 简体中文
最新版本 v0.10.4 在 2024.4.23 发布。
版本亮点:
- 支持在 MLFlowVisBackend 中自定义
artifact_location
#1505 - 支持在
DeepSpeedEngine._zero3_consolidated_16bit_state_dict
使用exclude_frozen_parameters
#1517
如果想了解更多版本更新细节和历史信息,请阅读更新日志。
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它作为 OpenMMLab 所有代码库的训练引擎,其在不同研究领域支持了上百个算法。此外,MMEngine 也可以用于非 OpenMMLab 项目中。它的亮点如下:
集成主流的大模型训练框架
支持丰富的训练策略
提供易用的配置系统
覆盖主流的训练监测平台
兼容主流的训练芯片
- 英伟达 CUDA | 苹果 MPS
- 华为 Ascend | 寒武纪 MLU | 摩尔线程 MUSA
支持的 PyTorch 版本
MMEngine | PyTorch | Python |
---|---|---|
main | >=1.6 <=2.1 | >=3.8, <=3.11 |
>=0.9.0, <=0.10.4 | >=1.6 <=2.1 | >=3.8, <=3.11 |
在安装 MMEngine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 PyTorch 官方安装文档。
安装 MMEngine
pip install -U openmim
mim install mmengine
验证是否安装成功
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
更多安装方式请阅读安装文档。
以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、可配置的训练和验证流程。
构建模型
首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel
,并且其 forward
方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode
。
- 对于训练,我们需要
mode
接受字符串 "loss",并返回一个包含 "loss" 字段的字典。 - 对于验证,我们需要
mode
接受字符串 "predict",并返回同时包含预测信息和真实信息的结果。
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels
构建数据集
其次,我们需要构建训练和验证所需要的数据集(Dataset)和数据加载器(DataLoader)。在该示例中,我们使用 TorchVision 支持的方式构建数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
构建评测指标
为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric
,并实现 process
和 compute_metrics
方法。
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
# 将一个批次的中间结果保存至 `self.results`
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
# 返回保存有评测指标结果的字典,其中键为指标名称
return dict(accuracy=100 * total_correct / total_size)
构建执行器
最后,我们利用构建好的模型
,数据加载器
,评测指标
构建一个执行器(Runner),并伴随其他的配置信息,如下所示。
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
# 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
# 训练配置,例如 epoch 等
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
)
开始训练
runner.train()
示例
迁移指南
我们感谢所有的贡献者为改进和提升 MMEngine 所作出的努力。请参考贡献指南来了解参与项目贡献的相关指引。
如果您觉得 MMEngine 对您的研究有所帮助,请考虑引用它:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
该项目采用 Apache 2.0 license 开源许可证。
- APES: Attention-based Point Cloud Edge Sampling
- DiffEngine: diffusers training toolbox with mmengine
- MIM: MIM 是 OpenMMLab 项目、算法、模型的统一入口
- MMCV: OpenMMLab 计算机视觉基础库
- MMEval: 统一开放的跨框架算法评测库
- MMPreTrain: OpenMMLab 深度学习预训练工具箱
- MMagic: OpenMMLab 新一代人工智能内容生成(AIGC)工具箱
- MMDetection: OpenMMLab 目标检测工具箱
- MMYOLO: OpenMMLab YOLO 系列工具箱与测试基准
- MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台
- MMRotate: OpenMMLab 旋转框检测工具箱与测试基准
- MMTracking: OpenMMLab 一体化视频目标感知平台
- MMOCR: OpenMMLab 全流程文字检测识别理解工具包
- MMSegmentation: OpenMMLab 语义分割工具箱
- MMPose: OpenMMLab 姿态估计工具箱
- MMHuman3D: OpenMMLab 人体参数化模型工具箱与测试基准
- MMSelfSup: OpenMMLab 自监督学习工具箱与测试基准
- MMFewShot: OpenMMLab 少样本学习工具箱与测试基准
- MMAction2: OpenMMLab 新一代视频理解工具箱
- MMFlow: OpenMMLab 光流估计工具箱与测试基准
- MMDeploy: OpenMMLab 模型部署框架
- MMRazor: OpenMMLab 模型压缩工具箱与测试基准
- Playground: 收集和展示 OpenMMLab 相关的前沿、有趣的社区项目
扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,扫描下方微信二维码添加喵喵好友,进入 MMEngine 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
我们会在 OpenMMLab 社区为大家
- 📢 分享 AI 框架的前沿核心技术
- 💻 解读 PyTorch 常用模块源码
- 📰 发布 OpenMMLab 的相关新闻
- 🚀 介绍 OpenMMLab 开发的前沿算法
- 🏃 获取更高效的问题答疑和意见反馈
- 🔥 提供与各行各业开发者充分交流的平台
干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬