在PyTorch中,数据集(Data Set)和数据加载器(Data Loader)是实现深度学习模型和测试的基本组件。下面将首先介绍数据集(Data Set)和数据加载器(Data Loader)的概念,然后介绍如何创建和使用PyTorch中的数据加载器的一些步骤和示例。
数据集类(Data Set)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。
数据加载器(Data Loader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。
PyTorch中的 torch.utils.data.Dataset和torch.utils.data.DataLoader 是数据加载和处理的核心组件。它们将数据读取与模型训练解耦,提供高效、灵活的数据迭代方式。下面从基础概念、自定义加载器参数、多进程机制等方面进行详细介绍。
Data Set 是一个抽象类,表示一个数据集。任何自定义数据集都必须继承它,自定义DataSet类,必须实现它构造函数和两个方法:
__init__: 在 实例化DataSet 对象运行一次。我们初始化包含图像的目录、注释文件和transform与 target_transform.__len__:返回数据集的总样本数。len(dataset)会调用它。__getitem__(self, idx):根据整数索引idx会返回一个样本(通常为特征和标签)。dataset[idx] 会调用它。其作用就是实现通过索引访问对应的数据以及标签。
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
使用自定义数据集时,可以用将其与torch.utils.data.DataLoader 结合使用,以便进行数据的批量加载和处理和训练。
在PyTorch中,自定义数据集有两个核心设计模式:映射式(Map-Style)和 可迭代式(Iterable-style) 。它们的差异不仅是实现接口不同,更反映了“随机访问”与“流式读取”两种数据消费范式的根本区别。下面从设计理念、实现细节、多进程交互、适用场景等方面深入解析。
__getitem__ 和 __len__ 的数据集,它通过索引映射到数据样本。适用于所有数据能一次性放入索引结构(如列表、文件路径列表)的场景。IterableDataset,实现 __iter__ 方法返回一个迭代器。这种数据集不能使用 len(),也无法使用随机采样(shuffle)的 loader,需使用 Sampler 的特定变体。在后续笔记我们将详细介绍。
PyTorch提供了一些常用数据集类,主要在torchvision.datasets、torchtext.datasets、torchaudio.datasets中。例如:
torchvision.datasets.MNIST、CIFAR10、ImageFolder(从文件夹结构加载图片,子文件夹为类别)torchtext.datasets.IMDB 等torchaudio.datasets.LIBRISPEECH 等这些内置类都继承自 Dataset,使用时可自动下载数据,并提供标准化访问方式。
现在我们来展示一个如何从TorchVision加载了Fahion-MINIST由60000个训练样本和10000个测试样本组成。每个样本包含一个28×2828×28 灰度图像和一个来自10个类别之一的关联标签。下面使用以下参数加载FashionMINIST数据集:
root:是存储路径、测试数据的路径。train:指定训练集或测试数据集。download=True:如果root路径下没有数据,则从网上下载数据。transform和target_transform是指定特征和标签转换。import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="./data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
我们可以用索引来访问数据集中的样本,用 matplotlib 可视化图形样本。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
其运行结果如下:

数据加载器(Data Loader)将 DataSet 封装为可迭代对象,负责批量加载、打乱数据、多进程并行加载等功能。其功能如下:
数据加载器的API形式与核心参数:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None,
generator=None, prefetch_factor=2, persistent_workers=False)
Dataset 对象(映射式或可迭代式)。RandomSampler。torch.utils.data.Sampler。定义数据索引的抽取策略。如果指定,shuffle 必须为 False。sampler,但每次返回一批索引,与 batch_size、shuffle、sampler 互斥。collate_fn 会将所有样本沿第0维堆叠成张量,通常对于同型数据有效。如果样本结构不一致(如不同长度序列),需要自定义。True,数据加载器在返回张量前将其复制到 CUDA 固定内存,加速数据传输到 GPU。仅适用于 CUDA。True,丢弃最后一个不完整批次(当总样本数不能被 batch_size 整除时)。在训练时如果要求严格固定批次大小(如 BatchNorm)应设为 TrueTrue,在数据集被消费一次后不会关闭 worker 进程,可保持 worker 存活以加速后续 epoch。数据调用案例Demo:
import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 自定义数据加载器类
class MyDataLoader(DataLoader):
def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):
super().__init__(dataset, batch_size, shuffle, num_workers=num_workers)
def collate_fn(self, batch):
# 自定义的数据预处理、合并等操作
# 这里只是简单地将样本转换为Tensor,并进行堆叠
return torch.stack(batch)
# 自定义数据集类
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 创建数据加载器实例
dataloader = MyDataLoader(dataset, batch_size=2, shuffle=True)
# 遍历数据加载器
for batch in dataloader:
# batch是一个包含多个样本的张量(或列表)
# 这里可以对批次数据进行处理
print(batch)
import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):
x = torch.tensor(load_iris().data)
y = torch.tensor(load_iris().target)
# 数据归一化
x_min = torch.min(x, dim=0).values
x_max = torch.max(x, dim=0).values
x = (x - x_min) / (x_max - x_min)
if shuffle:
idx = torch.randperm(x.shape[0])
x = x[idx]
y = y[idx]
return x, y
# 自定义鸢尾花数据类
class IrisDataset(Dataset):
def __init__(self, mode='train', num_train=120, num_dev=15):
super(IrisDataset, self).__init__()
x, y = load_data(shuffle=True)
if mode == 'train':
self.x, self.y = x[:num_train], y[:num_train]
elif mode == 'dev':
self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
else:
self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def __len__(self):
return len(self.x)
batch_size = 16
# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
ataset 定义数据源及其访问方式,映射式最常用,流式数据用 IterableDataset。DataLoader 封装采样、批处理、多进程加载和内存固定等功能,参数丰富。sampler、collate_fn 可以灵活处理各种数据形式和不平衡问题。掌握 Dataset 和 DataLoader 的用法与内部机制,能够让你根据实际需求搭建高效的数据管道,将 I/O 瓶颈降到最低,从而充分释放 GPU 计算能力。