这是一个用于蝴蝶分类的卷积神经网络设计。经验证,准确率可达到89%.
- src文件夹下为核心代码文件:
- preprocess.py为数据预处理
- model.py为CNN网络
- train.py为训练脚本
- eval.py为评估脚本
- info.py用于打印模型信息
- cam.py用于生成热力图
- utils.py为工具函数(与模型本身无关)
-
checkpoints文件夹下为训练好的参数权重
-
logs为训练日志
-
outputs为输出:
- 包括混淆矩阵、报告和训练曲线
- 子文件夹cam为热力图
-
pre_data为预处理后的数据保存
-
原始数据保存在./ButterflyClassificationDataset/ButterflyClassificationDataset
-
首先确保你的虚拟环境支持,包括一些常用的包,我已经打包在requirements.txt文件中,你可以pip安装,特别注意CUDA支持的pytorch
-
将原始数据导入指定位置./ButterflyClassificationDataset/ButterflyClassificationDataset或者在配置文件config.yaml中修改“ data:root: ”为你指定的数据集路径。
-
通过命令行进入项目根目录
-
进行数据预处理,在命令行输入: python.exe main.py --config config.yaml --mode preprocess 会在根目录下生成pre_data文件夹,分为三个子文件夹train,val,test
-
模型训练,在命令行输入: python.exe main.py --config config.yaml --mode train
-
评估 python.exe main.py --config config.yaml --mode eval
-
热力图 python.exe main.py --config config.yaml --mode cam
-
模型信息 python.exe main.py --config config.yaml --mode info
-
主要配置在config.yaml中