Biu懂AI:Pytorch架构的CV训练数据输入

       Bui~ 新系列博文将专注AI相关领域,想要学习高通蓝牙相关知识请查看之前的系列或关注大博主声波电波就看今朝

       Pytorch是我们常用的训练框架之一,也是当下主流框架之一。框架提供许多有用的工具,也提供了获取Coco、Cifar、MNIST等知名数据集的快捷获取方法。但如果是自己的数据集,方法就不一样了。下面我们来了解一下,设置自己的数据集需要用到的模块。

       Pytorch提供的两个模块(或者说是python包)torch.utils.data.DataLoader 和torch.utils.data.Dataset。Dataset的作用是将源数据整理打包,方便使用。在这个class中有三个必要函数:

  1. 实例化时传入必要参数的__init__ :这里面会传入数据及标签存放的文件夹路径,如果模型对输入数据有要求,这里面会传入转换数据用的transform,包括标签的target_transform(都可在transforms 包中找到方法),后续会在DataLoader 调用 __getitem__函数里面使用。此外还会传入一些别的设置参数,但都是先保存在本地,供后续使用。这函数里面主要是为了将数据整理好,数据和标签一一对应形成列表,方便获取,也方便__len__函数统计数目。

  2. 例如数据和标签分别放在多个不同的文件夹下面保存,我们可以用他们的路径组队放在列表。__getitem__的时候就根据路径去提取内容;再例如标签只保存放在一个文件里面,我们就需要读取里面的内容,并映射到对应的数据文件;或者说数据文件名和标签文件名不一样,需要靠第三个索引表去找对应关系,我们也要读出去将他们匹配好。

  3. 总的来说就是将训练数据整理好成一个个训练样本,以供__getitem__快速提取,加快训练时间



  4. 获取数据集总数的__len__:获取可用的样本个数



  5. 和获取一个指定索引样本的__getitem__:根据索引去找对应的样本,如果有transform就需要先将数据或标签转换成指定的数据格式。例如resize成同样大小的图片;或者normalized数据;或者转tensor格式。标签可能会转换一下数据格式、坐标格式、独热编码(one-hot)等等。这些转换都是为了满足模型输入要求或者提供模型性能。



       除了上面说的这些基本功能之外,有些衍生框架会加入很多调节功能,例如不够训练样本的就多添加一些增强数据;将数据或标签放在ram中,提供训练数据;数据和标签没成对的就抛弃;只取其中个别类型做样本数据的等等这些个性化的调节功能。

       DataLoader 的作用是管理数据的加载,并提供模型需要的数据格式输入。为了让训练速度加快,DataLoader 会将多个数据样本打包成一个batch(相当于把数据、标签合并成对应的大张量),多个数据样本一次送入模型中训练。(这里打包是将多个张量堆叠在一起,获取时是使用迭代器一一获取。但如果样本的张量不一致,例如object detect的label有多个输出,这时直接堆叠就会出错。因此DataLoader 提供了collate_fn函数,用来自定义打包功能的,但是要注意堆叠出来的张量要符合模型输入要求,要能在DataLoader 的迭代器中每次都能返回所需的样本数据和标签)



       打包的数量不能设置太小,不然模型很难收敛。但是也不能设太大,每次打包的数量要取决于内存容量,太大的话,有可能训练中途内存撑不住。(tips:GPU对2的幂次数的batch可以发挥更佳的性能,因此设置成16、32、64、128等数字性能更好,但是大部分数据集没办法被这些数字整除,所以常常有剩下的数据凑不够一个batch,因此你可以选择不使用剩下这些数据)

       另外,如果数据排序是按分类排的话,直接按序列取也会影响模型收敛速度,所以DataLoader 会提供打乱数据的功能;除了这些基本功能外,还有别的方法加快训练速度,例如将数据放在页存储中,加快读取速度,多进程运行DataLoader 等等,这些都能直接配置DataLoader 去实现,前提是需要有足够的资源。

 

        附件上传了小编自定义dataset和dataloader的一种使用场景例程,有需要可以下载参考下。以上是本期博文的全部内容,如有疑问请在博文下方评论留言,我会尽快解答(o´ω`o)و。谢谢大家浏览,我们下期再见。

 

 

简单是长期努力的结果,而不是起点
                                                 —— 不是我说的

 

 

FAQ 1:Keras能用这种方法吗?

A1:不行的,Keras是基于TensorFlow的框架

 

FAQ 2:label格式有哪些?

A2:看前一篇博文-Biu懂AI:Object Detection训练数据的Label格式

 

FAQ 3:可以获取dataloader的数据出来看吗?

A3:可以的,但是如果有transform的话,会转换成对应的张量,这时数据就不能直接显示出来了,需要进行逆转换

 

FAQ 4:train和val要用一样的dataloader吗?

A4: 不能说完全一样,因为train和val的任务不一样,val不需要考虑模型性能等问题,所以在数据transform时,可以不用考虑数据增强问题。

 

FAQ 5:图像数据用别的格式可以吗?

A5: 可以的,但是需要transform,变成RBG格式就行了

 

技术文档

类型标题档案
硬件ipynb

★博文内容均由个人提供,与平台无关,如有违法或侵权,请与网站管理员联系。

★文明上网,请理性发言。内容一周内被举报5次,发文人进小黑屋喔~

评论