基础组件——Datasets
datasets基本使用
导入包
from datasets import *
加载数据
datasets = load_dataset("madao33/new-title-chinese")
datasets
DatasetDict({
train: Dataset({
features: ['title', 'content'],
num_rows: 5850
})
validation: Dataset({
features: ['title', 'content'],
num_rows: 1679
})
})
加载数据集合集中的某一项子集
boolq_dataset = load_dataset("super_glue", "boolq")
boolq_dataset
DatasetDict({
train: Dataset({
features: ['question', 'passage', 'idx', 'label'],
num_rows: 9427
})
validation: Dataset({
features: ['question', 'passage', 'idx', 'label'],
num_rows: 3270
})
test: Dataset({
features: ['question', 'passage', 'idx', 'label'],
num_rows: 3245
})
})
按照数据集划分进行加载
dataset = load_dataset("madao33/new-title-chinese", split="train")
dataset
Dataset({
features: ['title', 'content'],
num_rows: 5850
})
dataset = load_dataset("madao33/new-title-chinese", split="train[10:100]")
dataset
Dataset({
features: ['title', 'content'],
num_rows: 90
})
dataset = load_dataset("madao33/new-title-chinese", split="train[:50%]")
dataset
Dataset({
features: ['title', 'content'],
num_rows: 2925
})
dataset = load_dataset("madao33/new-title-chinese", split=["train[:50%]", "train[50%:]"])
dataset
[Dataset({
features: ['title', 'content'],
num_rows: 2925
}),
Dataset({
features: ['title', 'content'],
num_rows: 2925
})]
查看数据集
datasets = load_dataset("madao33/new-title-chinese")
datasets
DatasetDict({
train: Dataset({
features: ['title', 'content'],
num_rows: 5850
})
validation: Dataset({
features: ['title', 'content'],
num_rows: 1679
})
})
查看某一个数据
datasets["train"][0]
{'title': '望海楼是危险的赌博',
'content': '近期妥善处理)'}
查看某一些数据
datasets["train"][:2]
{'title': ['望海楼是危险的赌博'],
'content': ['撒打发是',
'在推进“双一流”高校建设进程中']}
查看列名
datasets["train"].column_names
['title', 'content']
查看列属性
{'title': Value(dtype='string', id=None),
'content': Value(dtype='string', id=None)}
数据集划分
可使用train_test_split
这个函数
dataset = datasets["train"]
dataset.train_test_split(test_size=0.1) # 按测试集比例为10%划分
DatasetDict({
train: Dataset({
features: ['title', 'content'],
num_rows: 5265
})
test: Dataset({
features: ['title', 'content'],
num_rows: 585
})
})
对于分类任务,指定标签字段,然后让这个数据集均衡划分标签字段
dataset = boolq_dataset["train"]
dataset.train_test_split(test_size=0.1, stratify_by_column="label") # 分类数据集可以按照比例划分
DatasetDict({
train: Dataset({
features: ['question', 'passage', 'idx', 'label'],
num_rows: 8484
})
test: Dataset({
features: ['question', 'passage', 'idx', 'label'],
num_rows: 943
})
})
数据选取与过滤
# 选取
datasets["train"].select([0, 1])
Dataset({
features: ['title', 'content'],
num_rows: 2
})
# 过滤
## 传入一个lambda函数,让其只取含有中国的数据
filter_dataset = datasets["train"].filter(lambda example: "中国" in example["title"])
filter_dataset["title"][:5]
['世界探寻中国成功秘诀',
'信心来自哪里',
'世界减贫跑出加速度',
'和音瞩目历史交汇点',
'风采感染世界']
数据映射
def add_prefix(example):
example["title"] = 'Prefix: ' + example["title"]
return example
prefix_dataset = datasets.map(add_prefix) # 每个title数据前面添加了前缀
prefix_dataset["train"][:10]["title"]
['Prefix: 危险的',
'Prefix: 大力推进高校治理能力建设',
'Prefix: 坚持事业为上选贤任能',
'Prefix: “大朋友”的话儿记心头',
'Prefix: 用好可持续发展这把“金钥匙”',
'Prefix: 跨越雄关,我们走在大路上',
'Prefix: 脱贫奇迹彰显政治优势',
'Prefix: 拱卫亿万人共同的绿色梦想',
'Prefix: 育人育才',
'Prefix: 净化网络语言']
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def preprocess_function(example, tokenizer=tokenizer):
model_inputs = tokenizer(example["content"], max_length=512, truncation=True)
labels = tokenizer(example["title"], max_length=32, truncation=True)
# label就是title编码的结果
model_inputs["labels"] = labels["input_ids"]
return model_inputs
processed_datasets = datasets.map(preprocess_function) # 添加了分类标签
processed_datasets
DatasetDict({
train: Dataset({
features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 5850
})
validation: Dataset({
features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 1679
})
})
processed_datasets = datasets.map(preprocess_function, batched=True) # 使用批处理
processed_datasets
DatasetDict({
train: Dataset({
features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 5850
})
validation: Dataset({
features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 1679
})
})
去除某一字段
processed_datasets = datasets.map(preprocess_function, batched=True, remove_columns=datasets["train"].column_names)
processed_datasets
DatasetDict({
train: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 5850
})
validation: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 1679
})
})
保存与加载
# 保存
processed_datasets.save_to_disk("./processed_data")
# 加载
processed_datasets = load_from_disk("./processed_data")
加载本地数据集
# 加载本地csv文件
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset
Dataset({
features: ['label', 'review'],
num_rows: 7766
})
dataset = Dataset.from_csv("./ChnSentiCorp_htl_all.csv")
dataset
Dataset({
features: ['label', 'review'],
num_rows: 7766
})
加载文件夹内全部文件作为数据集
# 使用data_dir加载全部文件夹内文件
dataset = load_dataset("csv", data_dir="./all_data/", split='train')
dataset
Dataset({
features: ['label', 'review'],
num_rows: 23298
})
# 使用data_files加载文件夹内指定文件
dataset = load_dataset("csv", data_files=["./all_data/ChnSentiCorp_htl_all.csv", "./all_data/ChnSentiCorp_htl_all copy.csv"], split='train')
dataset
Dataset({
features: ['label', 'review'],
num_rows: 15532
})
通过其他方式读取数据,再将其转换成datasets
import pandas as pd
data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data.head()
dataset = Dataset.from_pandas(data)
dataset
Dataset({
features: ['label', 'review'],
num_rows: 7766
})
# List格式的数据需要内嵌{},明确数据字段
data = [{"text": "abc"}, {"text": "def"}]
# data = ["abc", "def"]
Dataset.from_list(data)
Dataset({
features: ['text'],
num_rows: 2
})
通过自定义加载脚本加载数据集
load_dataset("json", data_files="./cmrc2018_trial.json", field="data")
DatasetDict({
train: Dataset({
features: ['title', 'paragraphs', 'id'],
num_rows: 256
})
})
dataset = load_dataset("./load_script.py", split="train")
dataset
dataset[0]
{'id': 'TRIAL_800_QUERY_0',
'context': '基于《跑跑卡丁车》与《泡泡堂》上所开发的游戏,由韩国Nexon开发与发行。中国大陆由盛大游戏运营,这是Nexon时隔6年再次授予盛大网络其游戏运营权。台湾由游戏橘子运营。玩家以水枪、小枪、锤子或是水炸弹泡封敌人(玩家或NPC),即为一泡封,将水泡击破为一踢爆。若水泡未在时间内踢爆,则会从水泡中释放或被队友救援(即为一救援)。每次泡封会减少生命数,生命数耗完即算为踢爆。重生者在一定时间内为无敌状态,以踢爆数计分较多者获胜,规则因模式而有差异。以2V2、4V4随机配对的方式,玩家可依胜场数爬牌位(依序为原石、铜牌、银牌、金牌、白金、钻石、大师) ,可选择经典、热血、狙击等模式进行游戏。若游戏中离,则4分钟内不得进行配对(每次中离+4分钟)。开放时间为暑假或寒假期间内不定期开放,8人经典模式随机配对,采计分方式,活动时间内分数越多,终了时可依该名次获得奖励。',
'question': '生命数耗完即算为什么?',
'answers': {'text': ['踢爆'], 'answer_start': [127]}}