简单记录一下如何使用 PyTorch 的 DataSet 及 DataLoader 功能。
DataSet 的使用通过继承 DataSet 类完成,并在此基础上需要构造三个特殊函数。下例为使用 DataSet,通过访问 json 文件获取数据内容,然后在 gititem 函数中获取数据并返回的例子。
DataSet 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import torchfrom torch.utils.data import DataLoaderimport jsonfrom PIL import Imageimport numpy as npclass DataSet (torch.utils.data.Dataset) : def __init__ (self, train_or_valid, transform, path) : super().__init__() file = open('data/dataset.json' , 'r' ) data = json.load(file) self.datalist = data[train_or_valid] self.path = path def __getitem__ (self, index) : name = self.datalist[index][0 ] img = Image.open(self.path + name) label = self.datalist[index][1 ] return {"img" : img, "label" : label} def __len__ (self) : return len(self.datalist)
上述返还的是一个数据字典。
调用 依旧使用上述的例子。
1 trainset = DataSet('train' , transform_train, normal_path)
DataLoader DataLoader 是 PyTorch 用来调取 DataSet 的一个类,其声明和使用如下:
1 2 3 from torch.utils import datatrainloader = data.DataLoader(trainset, batch_size = batch_sz, shuffle = True )
第一个参数是上述生成的 DataSet,后面如同表述。
但往往上述的结构由于数据不规整不能满足要求,需要自己定义 Batch 函数。如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 from torch.utils import datadef padding (data) : src_len = [] for p in data: src_len.append(p['wav' ].shape[1 ]) src_pad = torch.zeros(len(data), data[0 ]['wav' ].shape[0 ], max(src_len)) tgt = torch.zeros(len(data)) for i in range(len(data)): p = data[i] end = src_len[i] src_pad[i, :, -end:] = p['wav' ] tgt[i] = p['label' ] return {'wav' : src_pad, 'label' : tgt} validloader = data.DataLoader(validset, batch_size = batch_sz, shuffle = False , collate_fn = padding)
返回的也是字典,并会使用 padding 函数。
调用 上面完成了预先代码的构建,最后是调用的步骤:
1 2 3 for idx, samples in enumerate(trainloader): wavs, labels = samples['wav' ], samples['label' ] pass
返回的 samples 就是前面 padding 的结果,可以对此进行修改,例如保留原有的长度信息等等。