- torch==1.7.1
- torchvision==0.8.2
- opencv-python==4.4.0.46
- dlib==19.21.1
- CUDA Version==10.2.89
注:conda最近安装torch特别慢,如果有遇到相同的情况请切换下载源;dlib安装有时会报错,提示需要CMake之类的,但在Mac或linux上似乎不用手动安装CMake,可尝试conda install和pip install连着用,有时自动就装好了。
训练、测试数据都按照readme中给出的链接下载即可,建议上传至服务器后再解压,零散文件上传太慢了。
原始数据集解压出来有300多万张图片,但是通过阅读VGGFace2数据集的说明,发小存在大量“小图”,因此使用Data_preprocessing/Kill_img.py的数据清洗,文件35行为最小图片尺寸,默认为250*250(我也用的这个),文件46行为数据路径,需手动指定到你解压好的训练数据文件夹,然后运行即可(我运行之后显示还有1,018,790张图片,即大约删除了一大半的图片)。
shape_predictor_68_face_landmarks.dat文件下载好之后放在Data_preprocessing文件夹下就行,这样做Image_processing.py数据生成的时候就不用改路径了。
train_maskV9.py模型训练无法直接跑通,遇到以下几个问题:
-
问题一:
Data_preprocessing/Image_processing.py文件 第134行的训练数据路径需提前检查,解压好的文件夹不一定叫这个名字。第22行os.environ["CUDA_VISIBLE_DEVICES"] = "3"可能也需要修改,比如我只有2块卡就改成了0。 -
问题二:
train_maskV9.py文件 的第21行:from Data_loader.Data_loader_train_notmask import TrainDataset可#掉,原因是没用到这个Datasets,实际用的是V9_train_dataloader。 -
问题三:
train_maskV9.py文件 的第41-55行:这里是用来载入预训练模型的,而如果你没下预训练模型,就会报model_pathi不存在的问题,这里的解决方案也是直接#掉(如果你真没用预训练模型)。 -
问题四:
train_maskV9.py文件 的第5行可以看到,限定了只用一块GPU,第57-66行也有GPU相关设定,要注意,这里只能用1块,用多卡就报错,我尚且没有搞清为什么多卡就报错。 -
问题五:
train_maskV9.py文件 的第68-82行是模型学习率的设定,我在第一次跑的时候把这个学习率调低了,原因是原始FaceNet的github上一上来就用的0.001,但是训练效果不好,后来又提升到了这个文件中所写的0.125的水平,就复现出来了。我认为能长时间使用这么大的学习率还不波动的原因在于:随着模型的训练“困难样本”实际上是在快速减少的,通过阅读训练日志可以看到,刚开始训练的时候,有五万个困难样本,而训练了48个epoch之后就只剩3500个了,相当于一个epoch的30万张图片中只有3500个计算了损失,这就造成训练的变慢,可能这时候再调低学习率,就真训练不动了。
注:这套代码逻辑中,config文件里写的学习率没用,全靠在这块调整。 -
问题六:
train_maskV9.py文件 的第401行是往log文件中写入学习率用的,这里写的还是config文件的学习率,造成我复查日志文件的时候发现学习率没变化,应该改成optimizer_model.param_groups[0]["lr"],这就是真学习率了。 -
问题七:
Data_loader/Data_loader_facenet_mask.py文件以及Data_loader/Data_loader_facenet_notmask.py文件最后的“NOTLFWestMask_dataloader”都要#掉,因为这套代码文件实际上没有提供“非LFW的验证数据”,所以这个dataloader会报找不到文件的错误,#掉就好了,不影响正常训练。 -
问题八:
config_mask.py文件 的第30行config['num_train_triplets']这个值应为10万。作者现在也在这行后面做了批注;第27行config['embedding_dim']我改成了256(这个无关紧要)。
注:config_notmask.py文件 也有同样的问题。 -
问题九:
validate_on_LFW.py文件 在训练过程中会用这个文件来测试模型效果,以及绘制ROC曲线图。在文件第33行# plt.show()后面应该加一个plt.close(),作用是关闭matplotlib的图,要不然每次图都存着叠加多了就会报错。
我在第一还是第二次训练时,大约在第60个epoch出现显存溢出OOM,后来调小batch_size就好了,这个因服务器显存而异,我显存实际可用量是11G,batch_size开的是20。