torch.utils.data.DataLoader 参数
例如,如果你的数据集返回的是图像和标签,默认的 collate_fn 会将一个 batch 的图像堆叠成一个 4D 张量(batch_size x channels x height x width),并将一个 batch 的标签堆叠成一个 1D 张量(batch_size)。这可以进一步提高数据加载速度,因为当一个 batch 的数据被送入模型训练时,下一个 batch 的数据已经在加载中了。这
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
-
num_workers (int, optional): 表示用于数据加载的子进程数量。 默认值为 0,表示在主进程中加载数据。
-
num_workers = 0:所有数据加载都在主进程中进行。 对于小型数据集或简单的变换,这通常足够了。 但是,对于大型数据集或复杂的变换,这可能会成为瓶颈,因为主进程需要等待数据加载完成才能开始训练。
-
num_workers > 0:使用多个子进程来加载数据。 这可以显著加快数据加载速度,因为数据加载和模型训练可以并行进行。 子进程将数据加载到内存中,主进程则负责将数据送入模型进行训练。 num_workers 的最佳值取决于你的硬件配置(CPU 核心数、内存大小、磁盘速度等)和数据集的复杂性。 过大的 num_workers 值可能会导致 CPU 负载过高,甚至降低性能。
-
-
pin_memory (bool, optional): 如果设置为 True,数据加载器会在返回张量之前将其复制到 CUDA 固定内存中。 这可以加快将张量从 CPU 复制到 GPU 的速度,但会消耗更多的内存。 通常建议在使用 GPU 训练时将 pin_memory 设置为 True。
-
persistent_workers (bool, optional): 如果设置为 True,数据加载器会在 epoch 之间保持 worker 进程的存活状态。 这可以减少在每个 epoch 开始时创建和销毁 worker 进程的开销,从而提高性能。 仅在 num_workers > 0 时有效。
-
prefetch_factor (int, optional): 指定在每个 worker 中预取的 batch 数量。 例如,如果 prefetch_factor = 2,则每个 worker 会预取两个 batch 的数据。 这可以进一步提高数据加载速度,因为当一个 batch 的数据被送入模型训练时,下一个 batch 的数据已经在加载中了。 仅在 num_workers > 0 时有效。
-
worker_init_fn (callable, optional): 一个在每个 worker 子进程启动时调用的函数。 这可以用于在 worker 进程中执行一些初始化操作,例如设置随机种子。 仅在 num_workers > 0 时有效。 这个例子中,如果num_workers > 0,则使用传入的 worker_init_fn,否则为 None。
- drop_last:当最后剩余的数据数量不足一个batchsize大小时,丢弃最后一个batchsize的数据
接下来介绍一下collate_fn
默认情况下,DataLoader 使用默认的 collate_fn,它会将一个 batch 的数据样本堆叠成一个张量。 例如,如果你的数据集返回的是图像和标签,默认的 collate_fn 会将一个 batch 的图像堆叠成一个 4D 张量(batch_size x channels x height x width),并将一个 batch 的标签堆叠成一个 1D 张量(batch_size)。
然而,在某些情况下,你可能需要自定义 collate_fn。 例如:
-
不同长度的序列: 如果你处理的是变长序列(例如文本数据),默认的 collate_fn 无法直接处理。 你需要自定义 collate_fn 来对序列进行填充,使其长度一致。 例如可以使用 torch.nn.utils.rnn.pad_sequence 函数。
-
自定义数据结构: 如果你使用的是自定义的数据结构,默认的 collate_fn 可能无法正确处理。 你需要自定义 collate_fn 来将你的数据结构转换为张量。
-
特殊的数据预处理: 你可能需要在将数据送入模型之前进行一些特殊的数据预处理操作。 你可以将这些操作放在自定义的 collate_fn 中。
collate_fn 的接口:
collate_fn 接受一个列表作为输入,列表中的每个元素都是一个数据样本。 collate_fn 需要返回一个 batch 的数据,通常是一个张量或一个元组/列表,其中包含多个张量。
示例:
import torch
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def my_collate_fn(batch):
# 对 batch 中的数据进行处理
# ...
return processed_batch
data = [
[1, 2, 3],
[4, 5],
[6, 7, 8, 9]
]
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=my_collate_fn)
for batch in dataloader:
print(batch)
在这个例子中,my_collate_fn 可以用来对不同长度的列表进行填充,或者进行其他自定义的预处理操作
更多推荐
所有评论(0)