PyTorch常用训练框架
现在模型建的比较多了,因此也形成了一套成熟的流程,这里简单的记述一下常用的模型构建的方法,为了后续改进。
文件夹架构
1 | \- model |
config.py 保存 train 以及预处理中的超参,但是不建议使用该文件保存模型的超参(除非整个调整结束)。 utils.py 保留操作函数,用来辅助预处理以及数据分析等等功能。data中存储原始数据以及处理后的数据,部分时候有中间生成数据。images保存为了报告生成的图片。
生成上述结构代码。
1 | import os |
Import
常用的 import 库文件。
1 | from sklearn.metrics import classification_report |
Model
自己的模型
常用的 Model 架构:
1 | import torch |
预训的模型
使用一些预训的模型使用。有两种魔改方法,其一是替代原模型中的部分层,另一部分是取出模型的某些部分和自己的其他网络组合。
替换层方法
以 vgg16 的替换方法为例。其中可以通过model.features._modules[]拿到对应的层,其中输入为 print(model)产生的输出。
1 | class VGG(nn.Module): |
上面可以通过pretrain = True拿到预训参数,但是下载很慢,可以复制链接自行离线下载然后通过上述方法导入。
重新组合方法
1 |
|
Train
下面是简化的框架。
1 | from sklearn.metrics import precision_score |
模型保存及预加载
保存
1 |
|
加载
1 | model.load_state_dict(torch.load(PATH)) |
这样就是整个模型的最基础框架搭建。但事实上一个任务真正困难的是在数据预处理策略和最后的调参上,这些就放在别的地方补充了吧。