Skip to content

torch.utils.data

原文: https://pytorch.org/docs/1.4.0/data.html

PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。 它表示可在数据集上迭代的 Python,并支持

这些选项由 DataLoader 的构造函数参数配置,构造函数的签名如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

以下各节详细介绍了这些参数选项的作用和用法。

数据集类型

DataLoader 构造函数最重要的参数是dataset,它指定了用于加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:

映射式数据集

映射式数据集是一种实现__getitem__()__len__()两个协议函数的数据集,它表示从(可能是非整数)索引/键到数据样本的映射。

例如,当使用dataset[idx]访问此类数据集时,可以从磁盘上的文件夹中读取第idx张图像及其对应的标签。

有关更多详细信息,请参见 Dataset

迭代式数据集

迭代式的数据集是 IterableDataset 子类的实例,该子类实现了__iter__()协议函数,并表示可在数据样本上进行迭代。当随机读取代价较高或不可能实现时,以及批处理大小取决于所获取数据的情况时,使用这种数据集是相当合适的。

例如,当这种数据集被iter(dataset)调用时,它可以从数据库、远程服务器甚至日志中实时读取返回数据流。

有关更多详细信息,请参见 IterableDataset

注意

IterableDataset多进程数据加载一起使用时。 每个工作进程都会复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复的数据。 有关如何实现此功能的信息,请参见 IterableDataset 文档。

数据加载顺序和 Sampler

对于迭代式数据集,数据加载顺序完全由用户定义的迭代器控制。这样可以轻松实现块读取和动态批次大小(例如,每次生成一个批次的样本)。

本节的其余部分主要关注映射式数据集torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的顺序。 它们代表数据集索引上的可迭代对象。 以一个常见情况举例,使用随机梯度下降(SGD)时,一个Sampler 每次可以随机生成一列索引,或者为小批量SGD生成少量索引。

基于 DataLoadershuffle参数,将自动构建顺序采样或打乱的采样器。 或者,用户可以使用sampler参数指定一个自定义的 Sampler 对象,该对象每次都会生成下一个要提取的索引/关键字。

一次生成一个批次索引的自定义 Sampler 可以赋值给batch_sampler参数,传递到Dataloader的构造函数中。 也可以通过batch_sizedrop_last两个参数启用自动批处理。 有关更多详细信息,请参见下一部分

Note

samplerbatch_sampler与迭代式数据集均不兼容,因为此类数据集没有键或索引的概念。

加载批处理和非批处理数据

DataLoader 支持通过参数batch_sizedrop_lastbatch_sampler将各个提取的数据样本自动整理为批次数据。

自动批处理(默认)

这是最常见的情况,对应于获取一小批数据并将其整理为批处理的样本,即数据加载后生成一个张量,张量中的某一维度代表批处理数据量的数目(通常是第一维)。

batch_size(默认1)不是None时,数据加载器将生成批处理的样本,而不是单个样本。 batch_sizedrop_last参数用于指定数据加载器如何获取一个批次的数据集的键。 对于映射式数据集,用户可以选择指定batch_sampler,它一次生成一个键列表。

Note

batch_sizedrop_last参数本质上是用sampler构造batch_sampler。 对于映射式数据集,sampler由用户提供,或基于shuffle参数构造。 对于迭代式数据集,sampler是一个虚拟且无限的采样器。 有关采样器的更多详细信息,请参见本节

Note

当用多进程数据加载迭代式数据集获取数据时,drop_last参数的应用,会丢弃每个进程中,数据集副本最后一个非完整批次。

使用采样器生成索引,并获取索引对应样本列表后,该样本列表将由collate_fn参数传递的函数整理为批次。

在这种情况下,从映射式数据集加载数据大致等效于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

从迭代式数据集加载数据大致等效于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义collate_fn可用于自定义排序规则,例如,将顺序数据填充至批处理的最大长度。 有关collate_fn的更多信息,请参见本部分

禁用自动批处理

在某些情况下,用户可能希望用数据集代码手动处理批处理,或仅加载单个样本。 例如,直接加载批处理数据所需代价更低(例如,从数据库中批量读取或读取连续的内存块),或者批处理大小取决于数据,或者该程序设计为只处理单个样本。 在这些情况下,最好不要使用自动批处理(其中collate_fn用于整理样本),而应让数据加载器直接返回dataset对象的每个成员。

batch_sizebatch_sampler均为None时(其中batch_sampler的默认值为None),自动批处理被禁用。 collate_fn参数传递的函数对从dataset获得的每个样本进行处理。

禁用自动批处理时,collate_fn的默认函数仅将 NumPy 数组转换为 PyTorch 张量,而其他所有内容均保持不变。

在这种情况下,从映射式数据集中读取样本基本相当于:

for index in sampler:
    yield collate_fn(dataset[index])

从迭代式数据集中读取样本基本相当于:

for data in iter(dataset):
    yield collate_fn(data)

有关collate_fn的更多信息,请参见本部分

使用collate_fn

启用或禁用自动批处理时,collate_fn的用法略有不同。

禁用自动批处理时,对每个数据样本分别调用collate_fn,并且从数据加载器迭代器产生输出。 在这种情况下,默认的collate_fn仅转换 PyTorch 张量中的 NumPy 数组。

启用自动批处理时,对数据样本列表每次调用collate_fn,这是为了将输入样本整理为一个批次,或从数据加在迭代器中生成一个批次的数据。本节的其余部分描述了这种情况下默认collate_fn的行为。

例如,如果每个数据样本都包含一个3通道图像和一个整体类标签,即数据集的每个元素返回一个元组(image, class_index),则默认的collate_fn将此类元组的列表整理为一个单独的元组 该元组包括批处理化的图像张量和批处理化的类标签张量。特别地,默认collate_fn具有以下属性:

  • 它始终将新维度前置为批次维度。

  • 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

  • 它保留了数据结构,例如,如果每个样本都是一个字典,它将输出一个字典,该字典具有相同的键集合,但值将被批处理化为张量(或者,当值无法被转换为张量时,将被转换为列表)。listtuplenamedtuple等相同,数据结构不变。

用户可以使用自定义的collate_fn来实现自定义批处理,例如,沿除第一个维度之外的其他维度进行数据整理,将序列填充至不同长度,或添加对自定义数据类型的支持。

单进程和多进程数据加载

默认情况下, DataLoader 使用单进程数据加载。

在 Python 进程中,全局解释器锁定(GIL)阻止了跨线程时Python代码真正的完全并行化。为了避免在加载数据时系统阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数num_workers设置为正整数,即可执行多进程数据加载。

单进程数据加载(默认)

在此模式下,获取数据的进程与 DataLoader初始化的进程一致。因此,数据加载时计算可能会被阻止。然而,以下情况可优先选择本模式:如进程之间用于共享数据的资源(例如共享内存,文件描述符等)有限,或者整个数据集很小,可以被完全加载到内存中。此外,单进程加载通常能够显示更多可读的错误跟踪,因此有利于代码调试。

多进程数据加载

将参数num_workers设置为正整数,将启动多进程数据加载,工作进程数目为参数指定值。

在此模式下,每次创建 DataLoader 的迭代器时(例如,当您调用enumerate(dataloader)时),都会创建num_workers工作进程。 此时,datasetcollate_fnworker_init_fn被传递给每个工作进程,在这些工作进程中,它们被用来初始化和获取数据。 这意味着数据集访问及其内部IO、数据变换(包括collate_fn)在工作进程中运行。

torch.utils.data.get_worker_info() 返回工作进程的各种有用信息(包括工作进程的ID,数据集副本,初始化种子等),而主进程返回None。 用户可以在数据集代码和/或worker_init_fn中使用此函数来分别配置每个数据集副本,并确定代码是否正在工作进程中运行。例如,这个函数在对数据集切片时特别有用。

对于映射式数据集,主过程使用sampler生成索引并将其发送给工作进程。 因此,任何随机打乱都是在主进程中完成的,该进程通过为加载分配索引来引导加载过程。

对于迭代式数据集,由于每个工作进程都获得dataset对象的副本,因此简易的多进程加载通常会导致数据重复。 用户可以使用 torch.utils.data.get_worker_info() 和/或worker_init_fn独立配置每个副本。(有关如何实现此操作的信息,请参见 IterableDataset 文档。)出于类似的原因,在多进程加载中,drop_last参数会删除每个工作进程中迭代式数据集副本的最后一个非完整批次。

一旦迭代结束或迭代器被垃圾回收,工作进程将关闭。

警告

通常不建议在多进程加载中返回 CUDA 张量,因为在使用 CUDA 和并行处理共享 CUDA 张量时,存在很多微妙之处(请参见在并行处理中的 CUDA)。相反,我们建议使用自动内存固定(即,设置pin_memory=True),该功能可以将数据快速传输到支持CUDA的GPU中。

平台特定的行为

由于工作进程依赖于 Python multiprocessing,因此与 Unix 相比,Windows 上的工作进程启动行为有所不同。

  • 在 Unix 上,fork()multiprocessing默认的启动方法。 使用fork(),子进程通常可以通过克隆的地址空间,直接访问dataset和 Python 参数函数。

  • 在 Windows 上,spawn()multiprocessing的默认启动方法。 使用spawn()启动另一个解释器,该解释器将运行您的主脚本,然后运行内部的工作程序函数,该函数通过pickle序列化获取datasetcollate_fn和其他参数。

这种独立的序列化意味着,应采取两个步骤确保多进程数据与 Windows 兼容:

  • 大部分主脚本代码应在if __name__ == '__main__':块中,以确保在启动每个工作进程时,该脚本不会再次运行(很可能会产生错误)。 可以在此处放置数据集和 DataLoader 实例创建逻辑,因为实例创建不需要在工作进程中重新执行。

  • 确保在__main__检查之外将任意自定义collate_fnworker_init_fndataset代码声明为顶级定义,以确保它们在工作进程中可用。 (这是必需的,因为这些函数被序列化为引用,而不是bytecode。)

多进程数据加载中的随机性

默认情况下,每个工作进程的 PyTorch 种子将设置为base_seed + worker_id,其中base_seed是主进程使用其 RNG 生成的长整数(因此,强制使用 RNG 状态)。 但是,初始化工作进程时,可能会复制其他库的种子(例如 NumPy),导致每个工作进程返回相同的随机数。 (请参阅 FAQ 中的本部分。)

worker_init_fn中,可以用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 访问每个工作进程的 PyTorch 种子集合,并在加载数据之前使用它为其他库提供种子。

内存固定

当GPU副本来自固定(页锁定)内存时,主机访问GPU副本速度会快得多。有关通常何时以及如何使用固定内存的详细信息,请参见使用固定内存缓冲区

对于数据加载,将pin_memory=True传递到 DataLoader 后,获取的数据张量将自动放置在固定内存中,此时数据可以更快传输到启用 CUDA 的 GPU中。

默认的内存固定逻辑,仅识别张量以及包含张量的映射和可迭代对象。 默认情况下,如果固定逻辑看到一个自定义类型的批处理(例如您有一个collate_fn,并返回自定义批处理类型),或者该批处理的每个元素都是自定义类型,则固定逻辑将无法识别它们,它将返回该批处理(或那些元素)而不固定内存。 要为自定义批处理或数据类型启用内存固定,请在自定义类型上定义pin_memory()方法。

请参见下面的示例。

例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())


class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)¶

数据加载器。 合并了数据集和采样器,并为给定数据集上提供迭代器。

DataLoader 支持映射式和迭代式数据集,可进行单进程或多进程数据加载,可自定义加载顺序,自动批处理(整理)和内存固定为可选参数。

有关更多详细信息,请参见 torch.utils.data 文档页面。

参数

  • dataset (Dataset)–要从中加载数据的数据集。

  • batch_size (python:int 可选)–每批次要加载多少个样本。(默认值:1

  • shuffle (bool 可选)–设置为True时,数据每轮会被重新打乱。(默认值:False )

  • sampler (Sampler 可选)–采样器定义了从数据集中抽取样本的策略。 如果指定了采样器,则shuffle必须为False

  • batch_sampler (Sampler 可选)–类似sampler,但每次只返回一个批次的索引。 与batch_sizeshufflesamplerdrop_last互斥。

  • num_workers (python:int 可选)–数据加载需要的子进程数目。 0表示将在主进程中加载​​数据。 (默认:0

  • collat​​e_fn (可调用的_, _可选)–合并样本列表以形成一个小批次的张量。 从映射式数据集中使用批量加载时使用。

  • pin_memory (bool 可选)–如果值为True,则数据加载器将把张量复制到 CUDA 固定的内存中,然后返回。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是自定义类型的批次,请参见下面的示例。

  • drop_last (bool 可选)–当数据集大小无法被批次大小整除时,若该参数设置为True,则最后一个不完整的批次将被删除;设置为False,则最后一个批次的大小将比设定的批次大小要小。(默认:False

  • timeout(数字 可选)–如果为正,则表示从工作进程中收集批次的超时值。 应始终为非负数。 (默认:0

  • worker_init_fn (可调用 可选)–如果不是None,则该函数将在生成种子之后和数据加载之前,在每个工作进程中被调用,其中工作进程的id([0, num_workers - 1]范围内的整数值)是它的输入。(默认:None

Warning

如果使用spawn启动方法,则worker_init_fn不能是不可序列化对象,例如 lambda 函数。 有关 PyTorch 中与并行处理有关的更多详细信息,请参见并行处理最佳实践

Note

len(dataloader)启发式方法基于所用采样器的长度。 当datasetIterableDataset 时,无论多进程加载配置如何,都将返回len(dataset)(如果实现),因为 PyTorch 信任用户的dataset代码可以正确处理多进程加载,避免重复数据。 有关这两种类型的数据集以及 IterableDataset 如何与多进程数据加载交互的更多详细信息,请参见数据集类型


class torch.utils.data.Dataset¶

表示 Dataset 的抽象类。

所有从键映射到数据样本的数据集都应该是它的子类。所有子类都应该重写__getitem__(),实现键值给定,返回对应数据样本的功能。 子类还可以选择重写__len__(),它的预计功能为,通过许多 Sampler 实现以及 DataLoader 的默认选项,返回数据集的大小。

Note

默认情况下, DataLoader 构造一个索引采样器,该采样器产生整数索引。 要使DataLoader与具有非整数索引/键的映射式数据集一起使用,必须提供自定义采样器。


class torch.utils.data.IterableDataset¶

迭代式数据集。

所有代表可迭代数据样本的数据集都应该是它的子类。当数据来自流时,这种形式的数据集特别有用。

所有子类都应重写__iter__(),该函数将返回此数据集中的样本迭代器。

当子类与 DataLoader 一起使用时,数据集中的每条数据都将由 DataLoader 迭代器产生。 当num_workers > 0时,每个工作进程将具有数据集对象的不同副本,因此通常需要独立配置每个副本,以避免从工作进程返回重复的数据。在工作程序进程中调用get_worker_info()时,该函数将返回与工作进程有关的信息。可以在数据集的__iter__()方法或 DataLoaderworker_init_fn选项中应用get_worker_info()来修改每个副本的行为。

示例 1:在__iter__()中将工作负载分配给所有工作进程:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

示例 2:使用worker_init_fn将工作负载分配给所有工作进程:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]


class torch.utils.data.TensorDataset(*tensors)¶

包装张量的数据集。

每个样本将沿第一维度索引张量,以完成检索。

Parameters

*tensors (Tensor)–张量,它的尺寸的与第一维大小相同。


class torch.utils.data.ConcatDataset(datasets)¶

该数据集是多个数据集的串联。

此类对于合并不同的现有数据集很有用。

Parameters

datasets(序列)–要串联的数据集列表


class torch.utils.data.ChainDataset(datasets)¶

用于链接多个 IterableDataset 的数据集。

此类对于合并不同的现有数据集流很有用。 链接操作是即时完成的,因此将大型数据集与此类连接起来使用将非常有效。

Parameters

datasets(IterableDataset 的可迭代对象)– 要链接在一起的数据集


class torch.utils.data.Subset(dataset, indices)¶

指定索引处的数据集子集。

Parameters

  • dataset (Dataset)–整个数据集

  • indices(序列)–在整个数据集中为子集选择的索引序列


torch.utils.data.get_worker_info()¶

返回当前 DataLoader 迭代器工作进程的相关信息。

在工作进程中调用时,此方法返回一个对象,该对象保证具有以下属性:

  • id:当前工作进程 ID。

  • num_workers:工作进程总数。

  • seed:当前工作进程的随机种子集。 该值由主进程 RNG 和工作进程 ID 确定。 有关更多详细信息,请参见 DataLoader 的文档。

  • dataset:在当前进程中的数据集对象副本。请注意,在不同的进程中,该数据集对象与主进程中的对象将不同。

在主进程中调用时,将返回None

Note

当此方法用作worker_init_fn传递给 DataLoader 使用时,它可用于对每个工作进程进行不同设置,例如,使用worker_id配置dataset对象,只读共享数据集的某一特定部分,或在数据集代码中用seed为其他库(例如 NumPy)做种。


torch.utils.data.random_split(dataset, lengths)¶

将数据集随机切片为给定长度的不重叠的新数据集。

Parameters

  • dataset (Dataset)–要切片的数据集

  • lengths(序列)–要产生的切片的长度


class torch.utils.data.Sampler(data_source)¶

所有采样器的基类。

每个 Sampler 子类都必须提供__iter__()方法和__len__()方法,前者提供一种对数据集元素的索引进行迭代的方法,后者返回已返回的迭代器的长度。

Note

DataLoader 并未严格要求__len__()方法,但涉及 DataLoader 长度的任意计算时,该方法都可能被调用。


class torch.utils.data.SequentialSampler(data_source)¶

始终以相同顺序序列化采样元素。

Parameters

data_source (Dataset)–要从中采样的数据集


class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)¶

随机采样元素。 如果参数replacement等于False,则从乱序数据集中采样。 如果replacement等于True,则用户可以指定num_samples抽取数据。

Parameters

  • data_source (Dataset) – 用于采样的数据集

  • replacement (bool )–当值为True时,样本抽取有替换,True为默认值

  • num_samples (python:int )–要抽取的样本数,默认值为len(dataset)。 仅当replacementTrue时才可指定此参数。


class torch.utils.data.SubsetRandomSampler(indices)¶

从给定的索引列表中随机抽样元素,无替换。

Parameters

indices(序列)–索引序列


class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)¶

以给定的概率(权重)从[0,..,len(weights)-1]中采样元素。

Parameters

  • weights(序列)–权重序列,权重序列和不必为1

  • num_samples (python:int )–要抽取的样本数

  • replacement (bool )–如果值为True,则抽取样本时可以有替代。 否则,抽取样本时无替换,这意味着当为一行抽取样本索引时,无法为该行再次抽取样本。

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]


class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)¶

包装另一个采样器以生成一个小批次的索引。

Parameters

  • sampler (Sampler)–基本采样器。

  • batch_size (python:int )–小批次的大小。

  • drop_last (bool )–如果值为True,则采样器将丢弃最后一批大小小于batch_size的数据

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]


class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)¶

将数据加载限制为数据集子集的采样器。

torch.nn.parallel.DistributedDataParallel 结合使用时特别有用。 在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递,并加载原始数据集的专有子集。

Note

假定数据集大小恒定。

Parameters

  • dataset –用于采样的数据集。

  • num_replicas (可选)–参与分布式训练的进程数。

  • rank(可选)–当前进程在 num_replicas 中的等级。

  • shuffle(可选)–如果值为 true(默认值),采样器将随机打乱索引



回到顶部