口罩检测数据集来源:https://www.bilibili.com/video/av88696686?t=115
口罩检测模型:https://github.com/zzh8829/yolov3-tf2
模型使用:
- 数据集更改
将所有图片放到一个文件夹,重命名文件,修改xml文件中filename
修改./data/voc2012.names中的名称为have_mask和no_mask - 模型修改
由于磁盘空间有限,在train.py中的modelcheckpoint中添加参数,save_best_only=True - 数据集生成
分别生成train和val1
2
3
4
5
6
7
8
9python tools/voc2012.py \
--data_dir './data/voc2012_raw/VOCdevkit/VOC2012' \
--split train \
--output_file ./data/voc2012_train.tfrecord
python tools/voc2012.py \
--data_dir './data/voc2012_raw/VOCdevkit/VOC2012' \
--split val \
--output_file ./data/voc2012_val.tfrecord - 模型训练
迁移学习1
2
3
4
5
6
7
8
9
10
11
12
13
14wget https://pjreddie.com/media/files/yolov3.weights -O data/yolov3.weights
python convert.py
python detect.py --image ./data/meme.jpg # Sanity check
python train.py \
--dataset ./data/voc2012_train.tfrecord \
--val_dataset ./data/voc2012_val.tfrecord \
--classes ./data/voc2012.names \
--num_classes 2 \
--mode fit --transfer darknet \
--batch_size 16 \
--epochs 10 \
--weights ./checkpoints/yolov3.tf \
--weights_num_classes 80
随机初始化
1 | python train.py \ |
可以在迁移学习冻结darknet主题部分之后,不冻结再次训练
batch_size要尽量小,不然会爆显存
- 预测
1 | python detect.py --classes ./data/voc2012.names --num_classes 2 --weights ./checkpoints/yolov3_train_2.tf --image ./data/street.jpg |
- 实时检测