医学大模型微调·数据处理全流程:炼丹,是自我超越的方法
- 数据清洗脚本
- 数据标注
- 数据核验
- 转为微调格式
- 随机化
数据清洗脚本
-
HTML标签移除
- 去除文本中的所有HTML标签,保留纯文本内容。
-
特殊字符处理
- 替换特殊数字符号(如①②③)为标准阿拉伯数字。
- 移除或替换其他特殊字符,只保留字母、数字、中文和基本标点。
-
表情符号(Emoji)处理
- 移除所有emoji表情,保持文本的纯粹性。
-
空白字符处理
- 去除多余的空格,包括首尾空格和文本中的连续空格。
-
违禁词过滤
- 识别并替换政治、色情等敏感词汇,通常用占位符(如***)替代。
-
文本长度控制
- 过滤掉长度不符合要求的文本,如过短或过长的内容。
- 代码里设为 5 - 1000
-
格式标准化
- 确保文本遵循一致的格式,如统一大小写、标点符号等。
-
语言检测和处理(未在脚本中实现,但常见)
- 识别文本语言,应用相应的处理规则。
-
拼写和语法检查(未在脚本中实现,但常见)
- 纠正常见的拼写错误和语法问题。
-
数据去重(未在脚本中实现,但常见)
- 移除重复的文本内容。
- MinHash-LSH 哈希模糊去重:如何解决医学大模型的大规模数据去重?【独篇】
-
文本规范化(未在脚本中实现,但常见)
- 将文本转换为标准格式,如将所有数字转为文字描述。
-
词形还原或词干提取(未在脚本中实现,但在NLP中常见)
- 将单词转换为其基本形式,便于后续分析。
import re
import emoji
def clean_data(text):
"""主数据清洗函数,按顺序调用所有清洗步骤"""
text = remove_html_tags(text)
text = replace_special_digits(text)
text = remove_emojis(text)
text = remove_special_chars(text)
text = remove_extra_spaces(text)
text = remove_banned_words(text)
text = filter_invalid_length(text)
return text
def remove_html_tags(text):
"""去除HTML标签"""
pattern = re.compile(r'<[^>]+>')
return pattern.sub('', text)
def replace_special_digits(text):
"""替换特殊数字符号为阿拉伯数字"""
digit_map = {
'①': '1', '②': '2', '③': '3', '④': '4', '⑤': '5',
'⑥': '6', '⑦': '7', '⑧': '8', '⑨': '9', '⑩': '10'
}
for special, normal in digit_map.items():
text = text.replace(special, normal)
return text
def remove_emojis(text):
"""去除所有emoji表情"""
return emoji.replace_emoji(text, '')
def remove_special_chars(text):
"""去除特殊字符,只保留字母、数字、中文和一些基本标点"""
pattern = re.compile(r'[^a-zA-Z0-9\u4e00-\u9fa5.,!?:; ]')
return pattern.sub('', text)
def remove_extra_spaces(text):
"""去除多余的空格,包括首尾空格"""
return ' '.join(text.split())
def remove_banned_words(text):
"""去除违禁词(政治、色情等)"""
banned_words = [
"政治敏感词1", "政治敏感词2",
"色情词1", "色情词2",
# 在实际使用时,这里需要一个更完整的违禁词列表
]
for word in banned_words:
text = re.sub(word, '***', text, flags=re.IGNORECASE)
return text
def filter_invalid_length(text, min_length=5, max_length=1000):
"""过滤长度不符合要求的文本"""
if min_length <= len(text) <= max_length:
return text
else:
return ''
def main():
"""主函数,用于测试数据清洗流程"""
sample_texts = [
"<p>这是一个示例文本②,包含emoji😊和一些特殊字符@#$%。还有一些敏感词如政治敏感词1和色情词1。</p>",
"这是一个正常的短文本。",
"这是一个超长的文本" + "很长很长" * 100,
"Html<br>标签和特殊字符%$#@!"
]
print("数据清洗测试:")
for i, text in enumerate(sample_texts, 1):
print(f"\n样本 {i}:")
print("原文本:", text)
cleaned_text = clean_data(text)
print("清洗后:", cleaned_text)
if __name__ == "__main__":
main()
医学还要加几个:
- 标准化医学术语和编码(如ICD编码)
- 处理缺失值
数据标注
标注与注释:
- 聘请专业医生进行标注
- 采用双盲或多人交叉验证的方式以提高准确性
- 标注内容可能包括:
- 疾病诊断
- 症状描述
- 影像特征
- 治疗方案
- 预后信息等
import requests
import json
from typing import List, Dict
import random
# 模拟Qwen API客户端
class QwenAPIClient:
def __init__(self, api_key: str):
self.api_key = api_key
self.api_url = "https://api.qwen.ai/v1/chat/completions" # 假设的API端点
def get_completion(self, prompt: str) -> str:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
data = {
"model": "qwen-turbo",
"messages": [{"role": "user", "content": prompt}]
}
response = requests.post(self.api_url, headers=headers, json=data)
return response.json()["choices"][0]["message"]["content"]
# 医学数据标注系统
class MedicalDataAnnotationSystem:
def __init__(self, api_client: QwenAPIClient):
self.api_client = api_client
def annotate_medical_data(self, data: Dict[str, str]) -> Dict[str, str]:
prompt = f"""
请对以下医学数据进行标注和注释。请提供以下信息:
1. 疾病诊断
2. 症状描述
3. 影像特征(如果适用)
4. 建议的治疗方案
5. 预后信息
患者数据:
年龄:{data['age']}
性别:{data['gender']}
主诉:{data['chief_complaint']}
病史:{data['medical_history']}
检查结果:{data['examination_results']}
请以JSON格式返回结果,包含上述5个方面的信息。
"""
response = self.api_client.get_completion(prompt)
try:
annotation = json.loads(response)
except json.JSONDecodeError:
annotation = {"error": "无法解析API响应"}
return annotation
def cross_validate(self, data: Dict[str, str], num_validators: int = 3) -> Dict[str, str]:
annotations = [self.annotate_medical_data(data) for _ in range(num_validators)]
# 简单的多数投票机制
final_annotation = {}
for key in annotations[0].keys():
values = [ann[key] for ann in annotations]
final_annotation[key] = max(set(values), key=values.count)
return final_annotation
# 主程序
def main():
api_key = "your_qwen_api_key_here"
client = QwenAPIClient(api_key)
annotation_system = MedicalDataAnnotationSystem(client)
# 模拟医学数据
medical_data = {
"age": "45",
"gender": "男",
"chief_complaint": "持续性胸痛,伴有呼吸困难",
"medical_history": "高血压病史5年,吸烟20年",
"examination_results": "心电图显示ST段抬高,心肌酶谱升高"
}
# 进行标注
annotation = annotation_system.cross_validate(medical_data)
print("医学数据标注结果:")
print(json.dumps(annotation, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()
输入:
- 一个包含医学数据的字典
medical_data
,包括:- 年龄 (age)
- 性别 (gender)
- 主诉 (chief_complaint)
- 病史 (medical_history)
- 检查结果 (examination_results)
输出:
- 一个包含标注结果的JSON格式字典
annotation
,包括:- 疾病诊断
- 症状描述
- 影像特征(如果适用)
- 建议的治疗方案
- 预后信息
全流程:
-
初始化:
- 创建 QwenAPIClient 实例
- 创建 MedicalDataAnnotationSystem 实例
-
数据处理:
- 调用 cross_validate 方法,对输入的医学数据进行多次标注
- 对每次标注:
a. 构造包含医学数据的提示词
b. 通过 Qwen API 发送请求
c. 接收并解析 API 响应
-
结果整合:
- 使用简单的多数投票机制整合多次标注结果
- 对每个标注项,选择出现次数最多的结果
-
输出结果:
- 打印最终的标注结果(JSON 格式)
将原始的医学记录转化为结构化的诊断信息、症状描述、治疗建议和预后信息,为医疗决策提供支持。
数据核验
输入:
- 包含医学记录的制表符分隔文件(TSV)
- 每条记录包含:患者ID、记录类型、诊断、症状、检查结果
输出:
- 包含原始数据和核验结果的新TSV文件
- 每条记录新增:准确性评分、一致性评分、完整性评分、发现的问题、改进建议
全流程:
- 读取输入文件
- 对每条记录:
a. 构造查询发送给Qwen API
b. API返回评估结果
c. 解析API响应,提取关键信息 - 将原始数据和评估结果写入输出文件
这个流程自动化了医学数据的质量检查,从原始医学记录到详细的质量评估报告,提高了数据处理的效率和一致性。
import os
import pandas as pd
import requests
import json
import re
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
class QwenAPI:
def __init__(self, api_key: str):
self.api_key = api_key
self.api_url = "https://api.qwen.ai/v1/chat/completions"
def chat(self, message: str, system: str, config: Dict) -> str:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
data = {
"model": "qwen-turbo",
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": message}
],
**config
}
response = requests.post(self.api_url, headers=headers, json=data)
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
else:
raise Exception(f"API request failed with status code {response.status_code}")
def medical_data_verification(input_file_path: str, output_file_path: str, api_key: str):
if not os.path.exists(input_file_path):
raise Exception(f"{input_file_path} -> 文件不存在")
if os.path.exists(output_file_path):
os.remove(output_file_path)
model = QwenAPI(api_key)
system = '''你是一位经验丰富的医学专家,负责核验医学数据的准确性、一致性和完整性。
你的任务是分析提供的医学记录,并根据给定的标准进行评估。请保持专业、客观,并提供详细的分析。'''
df = pd.read_csv(input_file_path, sep='\t', header=None, names=['patient_id', 'record_type', 'diagnosis', 'symptoms', 'test_results'])
def process_row(row):
message = f'''
请对以下医学记录进行核验,并评估其准确性、一致性和完整性。请考虑以下几个方面:
1. 诊断与症状的一致性
2. 诊断与检查结果的相符性
3. 使用的医学术语是否标准
4. 数据的完整性(是否有关键信息缺失)
5. 可能存在的数据录入错误或异常值
**患者ID:** {row['patient_id']}
**记录类型:** {row['record_type']}
**诊断:** {row['diagnosis']}
**症状:** {row['symptoms']}
**检查结果:** {row['test_results']}
**输出格式:**
请严格按照以下格式提供信息:
准确性评分:[1-5的整数]
一致性评分:[1-5的整数]
完整性评分:[1-5的整数]
发现的具体问题或疑点:[详细描述]
改进建议:[具体建议]
'''
config = {'temperature': 0.3, 'top_p': 0.9, 'max_tokens': 1000}
response = model.chat(message=message, system=system, config=config)
accuracy_score = extract_score(response, "准确性评分")
consistency_score = extract_score(response, "一致性评分")
completeness_score = extract_score(response, "完整性评分")
issues = extract_text(response, "发现的具体问题或疑点")
suggestions = extract_text(response, "改进建议")
return f"{row['patient_id']}\t{row['record_type']}\t{row['diagnosis']}\t{row['symptoms']}\t{row['test_results']}\t" \
f"{accuracy_score}\t{consistency_score}\t{completeness_score}\t{issues}\t{suggestions}"
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(process_row, row) for _, row in df.iterrows()]
with open(output_file_path, 'a', encoding='utf-8') as out_file:
for future in as_completed(futures):
out_file.write(future.result() + '\n')
out_file.flush()
def extract_score(text: str, score_type: str) -> int:
match = re.search(f"{score_type}:(\d)", text)
return int(match.group(1)) if match else 0
def extract_text(text: str, section: str) -> str:
pattern = f"{section}:(.*?)(?:\n|$)"
match = re.search(pattern, text, re.DOTALL)
return match.group(1).strip() if match else ""
if __name__ == "__main__":
input_file = "input_medical_data.tsv"
output_file = "verified_medical_data.tsv"
api_key = "your_qwen_api_key_here" # 替换为您的实际API密钥
medical_data_verification(input_file, output_file, api_key)
print("医学数据核验完成,结果已保存到", output_file)
转为微调格式
微调工具是:LLaMA-Factory
- autodl 上 使用 LLaMA-Factory 微调 中文版 llama3
import os
import json
import csv
def medical_data_to_llama_format(input_file_path, output_file_path):
if not os.path.exists(input_file_path):
raise Exception(f"{input_file_path} -> 文件不存在")
if os.path.exists(output_file_path):
os.remove(output_file_path)
total_records = 0
json_array = []
with open(input_file_path, 'r', encoding='utf-8') as in_file:
reader = csv.reader(in_file, delimiter='\t')
for row in reader:
if len(row) != 10: # 假设输入文件有10列
continue
patient_id, record_type, diagnosis, symptoms, test_results, accuracy_score, consistency_score, completeness_score, issues, suggestions = row
# 构建医学记录摘要
medical_summary = f"患者ID: {patient_id}\n记录类型: {record_type}\n诊断: {diagnosis}\n症状: {symptoms}\n检查结果: {test_results}"
# 构建评估结果摘要
evaluation_summary = f"准确性评分: {accuracy_score}\n一致性评分: {consistency_score}\n完整性评分: {completeness_score}\n发现的问题: {issues}\n改进建议: {suggestions}"
# 构建 LLaMA-Factory 格式的对话
conversation = {
"system": "你是一位经验丰富的医学专家,负责审核和改进医疗记录。",
"conversations": [
{
"from": "human",
"value": f"请审查以下医疗记录并提供改进建议:\n\n{medical_summary}"
},
{
"from": "assistant",
"value": f"根据您提供的医疗记录,我的评估和建议如下:\n\n{evaluation_summary}"
}
]
}
json_array.append(conversation)
total_records += 1
print(f'处理的记录总数:{total_records}')
if json_array:
with open(output_file_path, "w", encoding='utf-8') as out_file:
json.dump(json_array, out_file, indent=4, ensure_ascii=False)
print(f"数据已成功转换并保存到 {output_file_path}")
# 使用示例
input_file = "verified_medical_data.tsv"
output_file = "medical_data_llama_format.json"
medical_data_to_llama_format(input_file, output_file)
输入:
- 一个包含核验后医学数据的TSV(制表符分隔值)文件,名为 “verified_medical_data.tsv”
- 每行包含10列数据:患者ID、记录类型、诊断、症状、检查结果、准确性评分、一致性评分、完整性评分、发现的问题、改进建议
输出:
- 一个符合LLaMA-Factory格式的JSON文件,名为 “medical_data_llama_format.json”
- 包含一系列对话结构,每个结构包括系统提示、人类输入(医疗记录摘要)和AI助手回复(评估结果和建议)
全流程:
-
检查文件:
- 验证输入文件是否存在
- 如果输出文件已存在,则删除它
-
读取并处理数据:
- 打开输入TSV文件
- 逐行读取数据
-
对每条记录进行处理:
- 提取10个字段的数据
- 构建医学记录摘要(包括患者ID、记录类型、诊断、症状、检查结果)
- 构建评估结果摘要(包括各项评分、发现的问题和改进建议)
-
创建LLaMA-Factory格式的对话结构:
- 设置系统提示为医学专家角色
- 人类输入设置为请求审查医疗记录(包含医学记录摘要)
- AI助手回复设置为评估结果和建议(包含评估结果摘要)
-
累积处理结果:
- 将每个处理后的对话结构添加到JSON数组中
-
输出结果:
- 将整个JSON数组写入输出文件
- 使用格式化输出,确保可读性
-
提供处理统计:
- 打印处理的记录总数
实现了从原始医疗记录到AI训练数据的转换,为医疗AI模型的开发提供了基础数据支持。
原始数据:
medical_data = {
"patient_id": "P12345",
"age": 65,
"gender": "男",
"chief_complaint": "胸痛和呼吸困难",
"medical_history": "高血压10年,2型糖尿病5年",
"examination_results": "血压160/95mmHg,心电图显示ST段抬高",
"diagnosis": "急性心肌梗塞",
"treatment_plan": "立即进行经皮冠状动脉介入治疗(PCI)",
"prognosis": "及时治疗后预后良好,需要长期服用抗血小板药物和他汀类药物"
}
alpaca 格式:
[
{
"instruction": "请根据以下医学数据提供诊断和治疗建议。",
"input": "患者ID:P12345\n年龄:65岁\n性别:男\n主诉:胸痛和呼吸困难\n病史:高血压10年,2型糖尿病5年\n检查结果:血压160/95mmHg,心电图显示ST段抬高",
"output": "根据提供的信息,患者很可能患有急性心肌梗塞。建议立即进行以下处理:\n1. 诊断:急性心肌梗塞\n2. 治疗计划:立即进行经皮冠状动脉介入治疗(PCI)\n3. 用药建议:给予抗血小板药物(如阿司匹林、氯吡格雷)和他汀类药物\n4. 预后:及时治疗后预后良好,但需要长期药物治疗和定期随访\n5. 生活方式建议:戒烟、控制饮食、规律运动、严格控制血压和血糖",
"system": "你是一位经验丰富的心脏病专家,擅长处理急性心血管疾病。",
"history": [
["患者的血压读数是多少?", "患者的血压读数是160/95mmHg,属于高血压范围。"],
["心电图结果显示什么?", "心电图显示ST段抬高,这是急性心肌梗塞的典型表现。"]
]
}
]
sharegpt 格式:
[
{
"conversations": [
{
"from": "human",
"value": "我有一位65岁男性患者,主诉胸痛和呼吸困难。他有10年高血压和5年2型糖尿病病史。检查显示血压160/95mmHg,心电图有ST段抬高。请给出诊断和治疗建议。"
},
{
"from": "gpt",
"value": "根据您提供的信息,我的诊断和建议如下:\n\n1. 诊断:急性心肌梗塞\n\n2. 治疗建议:\n a) 立即进行经皮冠状动脉介入治疗(PCI)\n b) 给予抗血小板药物(如阿司匹林、氯吡格雷)\n c) 开始他汀类药物治疗\n d) 控制高血压和糖尿病\n\n3. 预后:及时治疗后预后通常良好,但需要长期药物治疗和生活方式调整\n\n4. 后续管理:\n a) 定期心脏功能检查\n b) 严格控制血压和血糖\n c) 建议戒烟(如适用)\n d) 调整饮食,增加适度运动\n\n请确保患者理解疾病的严重性和及时治疗的重要性。同时,需要全面评估其他潜在的心血管风险因素。"
}
],
"system": "你是一位经验丰富的心脏病专家,擅长处理急性心血管疾病。请基于提供的患者信息给出专业的诊断和治疗建议。",
"tools": "心电图分析工具、冠状动脉造影设备"
}
]
随机化
import os
import pandas as pd
def shuffle_medical_data(input_file_path, output_file_path):
if not os.path.exists(input_file_path):
raise Exception(f"{input_file_path} -> 文件不存在")
# 读取医疗数据文件
df = pd.read_csv(input_file_path, sep='\t', header=0)
# 对数据进行随机排序
df = df.sample(frac=1, random_state=42) # frac=1 表示对所有行进行排序
# 保存到新的 CSV 文件
df.to_csv(output_file_path, index=False, sep='\t', encoding='utf-8')
print(f"医疗数据已随机化并保存到 {output_file_path}")
# 使用示例
input_file = "原始医疗数据.csv"
output_file = "随机化医疗数据.csv"
shuffle_medical_data(input_file, output_file)
输入:
原始医疗数据文件(CSV格式)。这个文件可能包含如下信息:
- 患者ID
- 年龄
- 性别
- 诊断结果
- 实验室检查结果
- 用药记录
- 其他相关医疗信息
输出:
随机化后的医疗数据文件(CSV格式)。这个文件包含与输入文件相同的数据,但行的顺序被随机打乱。
全流程:
-
输入文件检查:
脚本首先检查指定的输入文件是否存在。如果文件不存在,会抛出异常。 -
数据读取:
使用pandas库的read_csv
函数读取输入的CSV文件。假设文件使用制表符(‘\t’)作为分隔符,并且有一个表头行。 -
数据随机化:
使用pandas的sample
函数对整个数据集进行随机排序。frac=1
参数确保所有行都被包含在随机化过程中。random_state=42
设置随机种子,确保结果可重复。 -
数据保存:
将随机化后的数据保存到新的CSV文件中。保留原始的列名(表头),不包含索引列,使用制表符作为分隔符,并以UTF-8编码保存。 -
确认输出:
打印一条消息,确认数据已被随机化并保存到指定的输出文件。
让我用一个具体的例子来说明这个过程:
import pandas as pd
# 模拟输入数据
input_data = pd.DataFrame({
'患者ID': ['P001', 'P002', 'P003', 'P004', 'P005'],
'年龄': [45, 62, 33, 58, 27],
'性别': ['男', '女', '女', '男', '男'],
'诊断': ['高血压', '糖尿病', '哮喘', '关节炎', '偏头痛'],
'血压': ['130/85', '145/90', '120/80', '135/88', '118/75']
})
print("原始数据:")
print(input_data)
# 随机化数据
shuffled_data = input_data.sample(frac=1, random_state=42)
print("\n随机化后的数据:")
print(shuffled_data)
# 保存到文件(这里仅作演示,不实际保存)
# shuffled_data.to_csv('随机化医疗数据.csv', index=False, sep='\t', encoding='utf-8')
这个例子展示了:
- 一个简单的医疗数据集,包含患者ID、年龄、性别、诊断和血压数据。
- 原始数据的样子。
- 使用
sample
函数随机化后的数据。
在实际应用中,数据集可能会大得多,包含更多的列和行。随机化过程会打乱所有行的顺序,但每行内的数据保持不变。这个过程有助于:
- 减少数据顺序可能带来的偏差,特别是在进行机器学习模型训练时。
- 在保护患者隐私的同时,为研究目的准备数据。
- 在进行交叉验证或划分训练/测试集时,确保数据的随机性。
只需要提供输入文件路径和期望的输出文件路径。
脚本会自动完成文件读取、随机化和保存的过程。