简单记录一下如何使用 PyTorch 的 DataSet 及 DataLoader 功能。

DataSet 的使用通过继承 DataSet 类完成,并在此基础上需要构造三个特殊函数。下例为使用 DataSet,通过访问 json 文件获取数据内容,然后在 gititem 函数中获取数据并返回的例子。

DataSet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch.utils.data import DataLoader
import json
from PIL import Image
import numpy as np

class DataSet(torch.utils.data.Dataset):
def __init__(self, train_or_valid, transform, path):
super().__init__()
file = open('data/dataset.json', 'r')
data = json.load(file)
self.datalist = data[train_or_valid]
self.path = path

def __getitem__(self, index):
name = self.datalist[index][0]
img = Image.open(self.path + name)
label = self.datalist[index][1]
return {"img": img, "label": label}

def __len__(self):
return len(self.datalist)

上述返还的是一个数据字典。

调用

依旧使用上述的例子。

1
trainset = DataSet('train', transform_train, normal_path)

DataLoader

DataLoader 是 PyTorch 用来调取 DataSet 的一个类,其声明和使用如下:

1
2
3
from torch.utils import data

trainloader = data.DataLoader(trainset, batch_size = batch_sz, shuffle = True)

第一个参数是上述生成的 DataSet,后面如同表述。

但往往上述的结构由于数据不规整不能满足要求,需要自己定义 Batch 函数。如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.utils import data
def padding(data):
src_len = []
for p in data:
src_len.append(p['wav'].shape[1])
src_pad = torch.zeros(len(data), data[0]['wav'].shape[0], max(src_len))
tgt = torch.zeros(len(data))
for i in range(len(data)):
p = data[i]
end = src_len[i]
src_pad[i, :, -end:] = p['wav']
tgt[i] = p['label']

return {'wav': src_pad, 'label': tgt}

validloader = data.DataLoader(validset, batch_size = batch_sz, shuffle = False, collate_fn = padding)

返回的也是字典,并会使用 padding 函数。

调用

上面完成了预先代码的构建,最后是调用的步骤:

1
2
3
for idx, samples in enumerate(trainloader):
wavs, labels = samples['wav'], samples['label']
pass

返回的 samples 就是前面 padding 的结果,可以对此进行修改,例如保留原有的长度信息等等。