3.5 图像分类数据集(Fashion-MNIST)
获取数据集
通过torchvision的torchvision.datasets
来下载这个数据集。第一次调用时会自动从网上获取数据。通过参数train
来指定获取训练数据集或测试数据集。参数transform = transforms.ToTensor()
使所有数据转换为Tensor
,如果不进行转换则返回的是PIL图片。transforms.ToTensor()
将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8
的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32
且位于[0.0, 1.0]的Tensor
。
1 | mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor()) |
读取小批量
数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader
中一个很方便的功能是允许使用多进程来加速数据读取。
1 | batch_size = 256 |