https://huggingface.co/docs/datasets/index

A Dataset provides fast random access to the rows, and memory-mapping so that loading even large datasets only uses a relatively small amount of device memory.

But for really, really big datasets ( > 100G) that won’t even fit on disk or in memory, an IterableDataset allows you to access and use the dataset without waiting for it to download completely!

1. 读取

1.1 以Json文件读取为例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
## Download and decompress into the local ./datasets folder beforehand
# wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-train.json.gz
# wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-test.json.gz
# gzip -dkv SQuAD_it-*.json.gz

from datasets import load_dataset

squad_it_dataset = load_dataset("json", data_files="./datasets/SQuAD_it-train.json", field="data")
# `field` is specific to the JSON loader: it names the key in the JSON file that holds the actual records
squad_it_dataset # a single file is loaded as the "train" split by default
squad_it_dataset.keys()
# dict_keys(['train'])
squad_it_dataset
# DatasetDict({
#     train: Dataset({
#         features: ['title', 'paragraphs'],
#         num_rows: 442
#     })
# })

1.2 JSON的其它形式读取

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# (1) Load both splits at once; the paths point at the ./datasets folder the
#     files were downloaded and decompressed into
data_files = {
    "train": "./datasets/SQuAD_it-train.json",
    "test": "./datasets/SQuAD_it-test.json",
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")

# (2) Read the gzip-compressed files
data_files = {
    "train": "./datasets/SQuAD_it-train.json.gz",
    "test": "./datasets/SQuAD_it-test.json.gz",
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")

# (3) Read from remote URLs
url = "https://github.com/crux82/squad-it/raw/master/"
data_files = {
    "train": url + "SQuAD_it-train.json.gz",
    "test": url + "SQuAD_it-test.json.gz",
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")

data_files可以为每个split指定多个文件(传入文件路径列表),详见 https://huggingface.co/docs/datasets/loading

1.3 表格基本操作

1
2
3
4
5
6
7
8
squad_it_dataset["train"].column_names  # column names of the train split
# A DatasetDict has no `features` attribute of its own, so collect the
# features of every split instead
{split: ds.features for split, ds in squad_it_dataset.items()}

squad_it_dataset["train"].features      # detailed type information for each column

squad_it_dataset["train"][0]            # first row

squad_it_dataset["train"]["title"]      # the "title" column

此外还支持,CSV/Parquet/Arrow/SQL等文件格式

  • CSV泛指表格类文本文件,可以设置分隔符,例如tsv
  • Parquet: Large datasets may be stored in a Parquet file because it is more efficient and faster at returning your query.
  • Arrow: Datasets库内部使用的存储格式

也支持字典Dict,Pandas表格,Generator等

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
from datasets import Dataset
## From a dict (column name -> list of values)
my_dict = {"a": [1, 2, 3]}
Dataset.from_dict(my_dict)
# Equivalent: a list of row dicts
my_list = [{"a": 1}, {"a": 2}, {"a": 3}]
dataset = Dataset.from_list(my_list)

## From a pandas DataFrame
import pandas as pd
df = pd.DataFrame({"a": [1, 2, 3]})
dataset = Dataset.from_pandas(df)

2. 处理

以一个CSV文件为例

1
2
3
4
5
6
7
8
## Download and unzip into the local ./datasets folder beforehand
# wget "https://archive.ics.uci.edu/ml/machine-learning-databases/00462/drugsCom_raw.zip"
# unzip drugsCom_raw.zip

from datasets import load_dataset

# The .tsv files are read through the "csv" loader with a tab delimiter
data_files = {"train": "datasets/drugsComTrain_raw.tsv", "test": "datasets/drugsComTest_raw.tsv"}
drug_dataset = load_dataset("csv", data_files=data_files, delimiter="\t")

2.1 行操作

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
## Basic inspection
drug_dataset["train"][:3]
drug_dataset["train"].select(range(1000))  # select rows by index

## Shuffle
drug_dataset["train"].shuffle(seed=42)

## Filter rows
drug_dataset.filter(lambda x: x["condition"] is not None)
# The word count is computed inline so this works before a review_length
# column has been added
drug_dataset.filter(lambda x: len(x["review"].split()) > 30)
drug_dataset.filter(lambda x: x["drugName"].startswith("Ar"))
# with_indices=True passes the row index as a second argument
drug_dataset["train"].filter(lambda example, idx: idx % 2 == 0, with_indices=True)

## Sort
drug_dataset["train"].sort("rating")[:3]
# Once the review_length column from section 2.2 exists, it can be used the
# same way: drug_dataset["train"].sort("review_length")[:3]

2.2 列操作

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
drug_dataset['train']

# Slice the rows first so only five rows are read, then pick the columns
first_rows = drug_dataset['train'][:5]
for col1, col2 in zip(first_rows['drugName'], first_rows['rating']):
    print(f"drugName: {col1}, rating: {col2}")

## Remove columns
drug_dataset.remove_columns(["drugName", "rating"])
## Select columns (names must exist in this dataset)
drug_dataset.select_columns(["drugName", "condition", "review"])

## Rename a column
drug_dataset.rename_column(
    original_column_name="Unnamed: 0", new_column_name="patient_id"
)

## Modify column values (.map)
def lowercase_condition(example):
    """Return the row's ``condition`` value lower-cased.

    Args:
        example: A single row, as a mapping with a ``"condition"`` key.

    Returns:
        A dict ``{"condition": ...}`` with the lower-cased text. A missing
        condition (``None``) is passed through unchanged instead of raising
        ``AttributeError``.
    """
    condition = example["condition"]
    if condition is None:
        return {"condition": None}
    return {"condition": condition.lower()}
drug_dataset.map(lowercase_condition)

## Add a new column
def compute_review_length(example):
    """Return a ``review_length`` entry: the number of whitespace-separated words in ``example["review"]``."""
    words = example["review"].split()
    word_count = len(words)
    return {"review_length": word_count}
drug_dataset = drug_dataset.map(compute_review_length)

默认情况下,map() 对每个样本单独调用该函数。如果设置 batched=True,则函数会接收一批样本作为输入,这样可以进行批量加速处理。

1
2
3
4
5
6
7
8
9
from transformers import AutoTokenizer

# NOTE(review): the path looks like a local copy of bert-base-cased — confirm it exists
tokenizer = AutoTokenizer.from_pretrained("./datasets/bert-base-cased")

def tokenize_function(examples):
    """Run the tokenizer over the ``review`` field with truncation enabled.

    Used with ``map(..., batched=True)``, so ``examples`` is a batch of rows
    rather than a single row.
    """
    return tokenizer(examples["review"], truncation=True)

tokenized_dataset = drug_dataset.map(tokenize_function, batched=True)
# The default batch_size is 1000

2.3 设置输出格式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# only changes the *output format* of the dataset
drug_dataset.set_format("pandas")
drug_dataset["train"][:3]

# With the pandas format active, slicing returns a pandas.DataFrame to work on
train_df = drug_dataset["train"][:]
# rename_axis + reset_index(name=...) yields the columns "condition" and
# "frequency" regardless of how the installed pandas version names the
# value_counts() result
frequencies = (
    train_df["condition"]
    .value_counts()
    .rename_axis("condition")
    .reset_index(name="frequency")
)
frequencies.head()

# from datasets import Dataset
# freq_dataset = Dataset.from_pandas(frequencies)

# reset from 'pandas' to 'arrow'
drug_dataset.reset_format()

# Switch the output format to torch tensors
drug_dataset.set_format("torch")
drug_dataset["train"][:1]

2.4 Split拆分

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Split the original train split 80/20 (seeded so the split is reproducible)
drug_dataset_clean = drug_dataset["train"].train_test_split(train_size=0.8, seed=42)
drug_dataset_clean.keys()
# dict_keys(['train', 'test'])

# Rename the held-out "test" part to "validation"
drug_dataset_clean["validation"] = drug_dataset_clean.pop("test")
drug_dataset_clean.keys()
# dict_keys(['train', 'validation'])

# Attach the original test split under the "test" key
drug_dataset_clean["test"] = drug_dataset["test"]
drug_dataset_clean.keys()
# dict_keys(['train', 'validation', 'test'])

3. 保存

3.1 保存为arrow格式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Save all splits of the DatasetDict in one call; resulting directory layout:
drug_dataset_clean.save_to_disk("datasets/demo-drug-reviews")
# demo-drug-reviews/
# ├── dataset_dict.json
# ├── test
# │   ├── dataset.arrow
# │   ├── dataset_info.json
# │   └── state.json
# ├── train
# │   ├── dataset.arrow
# │   ├── dataset_info.json
# │   ├── indices.arrow
# │   └── state.json
# └── validation
#     ├── dataset.arrow
#     ├── dataset_info.json
#     ├── indices.arrow
#     └── state.json

# Load it back
from datasets import load_from_disk
drug_dataset_reloaded = load_from_disk("datasets/demo-drug-reviews")
drug_dataset_reloaded

3.2 其它格式

  • 对于json等其它格式(csv, parquet, sql),需要分别保存每个split
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Save: one file per split
# (the loop variable is deliberately not named `dataset`, so the `dataset`
# object created earlier in these notes is not overwritten)
for split, split_dataset in drug_dataset_clean.items():
    split_dataset.to_json(f"drug-reviews-{split}.jsonl")

# Load
data_files = {
    "train": "drug-reviews-train.jsonl",
    "validation": "drug-reviews-validation.jsonl",
    "test": "drug-reviews-test.jsonl",
}
drug_dataset_reloaded = load_dataset("json", data_files=data_files)