在 PyTorch 中,DatasetDataLoaderSampler 是用于数据加载和处理的核心组件。它们相互配合,使得数据的加载和批处理更加高效和灵活。

  • Dataset 是一个抽象类,用于表示数据集。
  • DataLoader 是一个迭代器,用于将数据集分成小批量。
  • Sampler 可以自定义更加复杂的采样策略。

1. Dataset

将训练数据(如特征和标签)封装为一个可迭代的 PyTorch Dataset 类。有如下两种方式。

1.1 自定义类

继承 Dataset父类,并实现如下两个方法

  • __len__: 返回数据集的大小。
  • __getitem__: 支持索引操作,返回指定位置的数据和标签。这里可以根据数据类型特点,灵活定义。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch.utils.data import Dataset

data = torch.arange(100).reshape(50, 2)  # 50个样本,每个样本有2个特征
labels = torch.randint(0, 2, (50,))     # 50个二分类标签


class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 返回样本和其对应的标签,可以根据数据特点,灵活定义
        return self.data[idx], self.labels[idx]
    
next(iter(dataset))
# (tensor([0, 1]), tensor(0))

1.2 TensorDataset

对于简单的数组类型数据,可以直接使用TensorDataset(*tensors)

1
2
3
4
5
from torch.utils.data import TensorDataset
dataset_2 = TensorDataset(data, labels)

next(iter(dataset_2))
# (tensor([0, 1]), tensor(0))

2. Dataloader

  • 将数据集拆分为小批量,用于后续的迭代训练
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torch.utils.data import DataLoader
# dataset: 要加载的数据集,通常是 Dataset 的实例。
# batch_size: 每个批次的数据量。
# shuffle: 是否在每个 epoch 开始时打乱数据。
# num_workers: 使用多少子进程来加载数据,默认仅使用当前进程。
# drop_last: 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False

dataloader = DataLoader(dataset, 
                        batch_size=4, 
                        shuffle=True, 
                        num_workers=2)

next(iter(dataloader))
# [tensor([[14, 15],
#          [22, 23],
#          [70, 71],
#          [40, 41]]),
#  tensor([0, 1, 0, 0])]

for batch_data, batch_labels in dataloader:
    # 处理每个批次的数据
    pass

2.1 collate_fn 参数

  • 默认情况下,DataLoader 会将多个样本的数据直接拼接成一个批次。
  • collate_fn 是一个函数参数,用于在批处理时进行数据转换或预处理,以更好地适应训练要求。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def custom_collate_fn(batch):
    # 假设 batch 是一个列表,包含多个 (features, labels) 元组
    features = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    
    # 在这里可以进行序列填充或其他预处理
    # 例如,将序列填充到相同长度
    padded_features = pad_sequence(features, batch_first=True)
    
    return padded_features, torch.tensor(labels)

dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate_fn)

2.2 sampler 参数

使用 sampler 参数可以实现复杂的采样逻辑,适用于需要特定采样策略的场景,如不均衡数据处理、分布式训练等。

  • SequentialSampler: 顺序采样,按照数据集的顺序返回样本索引。
  • RandomSampler: 随机采样,打乱数据集顺序返回样本索引。
  • SubsetRandomSampler: 从数据集的一个子集随机采样。
  • WeightedRandomSampler: 根据指定的权重随机采样,适用于不平衡数据集。
  • DistributedSampler,用于在分布式训练中对数据进行采样。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
## 简单示例
# 随机采样
from torch.utils.data import RandomSampler
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler)

# 子集采样
from torch.utils.data import SubsetRandomSampler
import numpy as np
indices = np.random.permutation(len(dataset))[:subset_size]
sampler = SubsetRandomSampler(indices)  #参数为子集的index
dataloader = DataLoader(dataset, sampler=sampler)