Skip to content

torchvision.datasets

译者:BXuan694

所有的数据集都是torch.utils.data.Dataset的子类, 即:它们实现了__getitem____len__方法。因此,它们都可以传递给torch.utils.data.DataLoader,进而通过torch.multiprocessing实现批数据的并行化加载。例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

目前为止,收录的数据集包括:

数据集

以上数据集的接口基本上很相近。它们至少包括两个公共的参数transformtarget_transform,以便分别对输入和和目标做变换。

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

MNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform (可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

Fashion-MNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
class torchvision.datasets.EMNIST(root, split, **kwargs)

EMNIST数据集。

参数:

  • root(string)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • split(string)– 该数据集分成6种:byclassbymergebalancedlettersdigitsmnist。这个参数指定了选择其中的哪一种。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选) – 一种函数或变换,输入目标,进行变换。

注意:

以下要求预先安装COCO API

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)

MS Coco Captions数据集。

参数:

  • root(string)– 下载数据的目标目录。
  • annFile(string)– json标注文件的路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

示例

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

输出:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']

__getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是列表类型,包含了对图片image的描述。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)

MS Coco Detection数据集。

参数:

  • root(string)– 下载数据的目标目录。
  • annFile(string)– json标注文件的路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.ToTensor
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是coco.loadAnns返回的对象。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)

LSUN数据集。

参数:

  • root(string)– 存放数据文件的根目录。
  • classes(string list)– {'train', 'val', 'test'}之一,或要加载类别的列表,如['bedroom_train', 'church_train']。
  • transform(可被调用 , 可选) – 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: index (int) – 索引
返回: 元组(image, target),其中target是目标类别的索引。
--- ---
Return type: tuple
--- ---
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

一种通用数据加载器,其图片应该按照如下的形式保存:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

参数:

  • root(string)– 根目录路径。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • loader – 一种函数,可以由给定的路径加载图片。
__getitem__(index)
参数: index (int) – 索引
返回: (sample, target),其中target是目标类的类索引。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)

一种通用数据加载器,其数据应该按照如下的形式保存:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext

参数:

  • root(string)– 根目录路径。
  • loader(可被调用)– 一种函数,可以由给定的路径加载数据。
  • extensions(list[__string__])– 列表,包含允许的扩展。
  • transform(可被调用 , 可选)– 一种函数或变换,输入数据,返回变换之后的数据。如:对于图片有transforms.RandomCrop
  • target_transform – 一种函数或变换,输入目标,进行变换。
__getitem__(index)
参数: index (int) – 索引
返回: (sample, target),其中target是目标类的类索引.
--- ---
返回类型: tuple
--- ---

这个类可以很容易地实现ImageFolder数据集。数据预处理见此处

示例

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

CIFAR10数据集。

参数:

  • root(string)– 数据集根目录,要么其中应存在cifar-10-batches-py文件夹,要么当download设置为True时cifar-10-batches-py文件夹保存在此处。
  • train(bool, 可选)– 如果设置为True, 从训练集中创建,否则从测试集中创建。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target是目标类的类索引。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

CIFAR100数据集。

这是CIFAR10数据集的一个子集。

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)

STL10数据集。

参数:

  • root(string)– 数据集根目录,应该包含stl10_binary文件夹。
  • split(string)– {'train', 'test', 'unlabeled', 'train+unlabeled'}之一,选择相应的数据集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, optional)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target应是目标类的类索引。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)

SVHN数据集。注意:SVHN数据集将10指定为数字0的标签。然而,这里我们将0指定为数字0的标签以兼容PyTorch的损失函数,因为损失函数要求类标签在[0, C-1]的范围内。

参数:

  • root(string)– 数据集根目录,应包含SVHN文件夹。
  • split(string)– {'train', 'test', 'extra'}之一,相应的数据集会被选择。'extra'是extra训练集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: index (int) – 索引
返回: (image, target),其中target是目标类的类索引。
--- ---
返回类型: tuple
--- ---
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)

Learning Local Image Descriptors Data数据集。

参数:

  • root(string)– 保存图片的根目录。
  • name(string)– 要加载的数据集。
  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。
  • download (bool, optional) – 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
__getitem__(index)
参数: index (int) – 索引
返回: (data1, data2, matches)
--- ---
返回类型: tuple
--- ---


回到顶部