3.5 图像分类数据集(Fashion-MNIST)

获取数据集

通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。通过参数train来指定获取训练数据集或测试数据集。参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

1
2
3
4
5
6
7
8
9
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

print(len(mnist_train), len(mnist_test)) #获取该数据集的大小
#output: 60000 10000

feature, label = mnist_train[0] # 可以通过下标来访问任意一个样本
print(feature.shape, label) # Channel x Height x Width
#output: torch.Size([1, 28, 28]) tensor(9)

读取小批量

数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。

1
2
3
4
5
6
7
batch_size = 256
if sys.platform.startswith('win'):
num_workers = 0 # 0表示不用额外的进程来加速读取数据
else:
num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)