在 PyTorch 中,Dataset
、DataLoader
和 Sampler
是用于数据加载和处理的核心组件。它们相互配合,使得数据的加载和批处理更加高效和灵活。
Dataset
是一个抽象类,用于表示数据集。
DataLoader
是一个迭代器,用于将数据集分成小批量。
Sampler
可以自定义更加复杂的采样策略。
1. Dataset#
将训练数据(如特征和标签)封装为一个可迭代的 PyTorch Dataset
类。有如下两种方式。
1.1 自定义类#
继承 Dataset
父类,并实现如下两个方法
__len__
: 返回数据集的大小。
__getitem__
: 支持索引操作,返回指定位置的数据和标签。这里可以根据数据类型特点,灵活定义。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
import torch
from torch.utils.data import Dataset
data = torch.arange(100).reshape(50, 2) # 50个样本,每个样本有2个特征
labels = torch.randint(0, 2, (50,)) # 50个二分类标签
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# 返回样本和其对应的标签,可以根据数据特点,灵活定义
return self.data[idx], self.labels[idx]
next(iter(dataset))
# (tensor([0, 1]), tensor(0))
|
1.2 TensorDataset#
对于简单的数组类型数据,可以直接使用TensorDataset(*tensors)
1
2
3
4
5
|
from torch.utils.data import TensorDataset
dataset_2 = TensorDataset(data, labels)
next(iter(dataset_2))
# (tensor([0, 1]), tensor(0))
|
2. Dataloader#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
from torch.utils.data import DataLoader
# dataset: 要加载的数据集,通常是 Dataset 的实例。
# batch_size: 每个批次的数据量。
# shuffle: 是否在每个 epoch 开始时打乱数据。
# num_workers: 使用多少子进程来加载数据,默认仅使用当前进程。
# drop_last: 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False
dataloader = DataLoader(dataset,
batch_size=4,
shuffle=True,
num_workers=2)
next(iter(dataloader))
# [tensor([[14, 15],
# [22, 23],
# [70, 71],
# [40, 41]]),
# tensor([0, 1, 0, 0])]
for batch_data, batch_labels in dataloader:
# 处理每个批次的数据
pass
|
2.1 collate_fn
参数#
- 默认情况下,
DataLoader
会将多个样本的数据直接拼接成一个批次。
collate_fn
是一个函数参数,用于在批处理时进行数据转换或预处理,以更好地适应训练要求。
1
2
3
4
5
6
7
8
9
10
11
12
|
def custom_collate_fn(batch):
# 假设 batch 是一个列表,包含多个 (features, labels) 元组
features = [item[0] for item in batch]
labels = [item[1] for item in batch]
# 在这里可以进行序列填充或其他预处理
# 例如,将序列填充到相同长度
padded_features = pad_sequence(features, batch_first=True)
return padded_features, torch.tensor(labels)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate_fn)
|
2.2 sampler
参数#
使用 sampler 参数可以实现复杂的采样逻辑,适用于需要特定采样策略的场景,如不均衡数据处理、分布式训练等。
SequentialSampler
: 顺序采样,按照数据集的顺序返回样本索引。
RandomSampler
: 随机采样,打乱数据集顺序返回样本索引。
SubsetRandomSampler
: 从数据集的一个子集随机采样。
WeightedRandomSampler
: 根据指定的权重随机采样,适用于不平衡数据集。
DistributedSampler
,用于在分布式训练中对数据进行采样。
1
2
3
4
5
6
7
8
9
10
11
12
|
## 简单示例
# 随机采样
from torch.utils.data import RandomSampler
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler)
# 子集采样
from torch.utils.data import SubsetRandomSampler
import numpy as np
indices = np.random.permutation(len(dataset))[:subset_size]
sampler = SubsetRandomSampler(indices) #参数为子集的index
dataloader = DataLoader(dataset, sampler=sampler)
|