guodong's blog

master@zhejiang university
   

目标检测yolo3-keras训练自己的数据集

1、所需环境

源码地址 https://github.com/qqwweee/keras-yolo3

python3

keras 2.1.5

tensorflow 1.6.0

2、下载源码以及预训练的权重

源码:

git clone https://github.com/qqwweee/keras-yolo3.git

权重:

wget https://pjreddie.com/media/files/yolov3.weights

源码文件结构:

—YOLO3-keras
│  coco_annotation.py                            #coco数据集格式转换
│  convert.py                                          #权重格式转换
│  darknet53.cfg                                     #darknet53网络配置文件
│  kmeans.py                     #k-means算法,用于生成预设的anchor box
│  train.py                                               #训练
│  voc_annotation.py                              #voc数据集格式转换
│  yolo.py                                               #预测
│  yolov3-tiny.cfg                                   #yolo3-tiny配置文件
│  yolov3.cfg                                          #yolo3的配置文件
│  yolo_video.py                                    #yolo3预测视频

├─font                                                     #字体库,用于预测边框的文字
│      FiraMono-Medium.otf
│      SIL Open Font License.txt

├─model_data                                         #模型数据
│      coco_classes.txt                          #coco类别
│      tiny_yolo_anchors.txt                #anchor
│      voc_classes.txt                           #voc 类别
│      yolo_anchors.txt                        #anchor

└─yolo3
│       model.py                                  #yolo模型
│       utils.py                                     #yolo模型里用到的函数
│       __init__.py

3、准备自己的数据集及权重

  • 明确自己的数据集格式,应该和voc的数据集格式相同,可以使用labelimg。
  • 使用python voc_annotation.py转化数据集,格式如下
  • 转换权重
python convert.py  yolov3.cfg  yolov3.weights  model_data/yolo.h5

其中 yolov3.cfg指yolo3的配置文件,yolov3.weights为下载好的预训练的权重,最后一个参数为输出路径,命名为yolo.h5

  • 根据自己的数据集准备anchor (可选)
python kmeans.py

其中kmeans.py中需要根据自己需要修改的参数有(已红色标记)

if __name__ == "__main__":
    cluster_number = 9   #生成的anchor
    filename = "2012_train.txt"  #数据集文件
    kmeans = YOLO_Kmeans(cluster_number, filename)
    kmeans.txt2clusters()

4、训练

python train.py

其中train.py的内容需要修改或注意的有(第16行)

def _main():
    annotation_path = 'train.txt'
    log_dir = 'logs/000/'
    classes_path = 'model_data/voc_classes.txt'
    anchors_path = 'model_data/yolo_anchors.txt'</span>
    class_names = get_classes(classes_path)
    num_classes = len(class_names)
    anchors = get_anchors(anchors_path)

第30以及33行,明确权重路径。
第52行,如果Ture,则先冻结darknet53只训练yolo层。可以根据己需求更改为false或者迭代周期epoch次数。
第70行,如果True,则始终训练全部层。可以根据己需求更改为false或者迭代周期epoch次数。

5、预测

预测图片时,

python yolo.py

预测视频时,

python yolo_vedio.py  [video_path] [output_path (optional)]



上一篇:
下一篇:

头像

guodong

说点什么

avatar
  Subscribe  
提醒