昇思学习打卡营第32天|基于ResNet50的中药炮制饮片质量判断模型

news2024/10/5 10:17:53
背景介绍

        中药炮制是根据中医药理论,依照临床用药需求,通过调剂和制剂要求,将中药材制备成中药饮片的过程。老百姓日常使用的中药饮片,是中药炮制技术的成果。中药炮制过程中,尤其是涉及到水火处理时,必须注重“程度适中”。如果炮制火候不足,则无法发挥最好的药效;而火候过度则会使药效丧失。因此,判断炮制程度的准确性直接影响中药的质量和疗效。

        传统上,中药炮制程度主要依赖于经验丰富的老药工判断。然而,随着老药工的减少,经验传承面临挑战。人工智能的兴起为这一问题提供了解决方案,通过图像分类技术,尤其是使用深度学习中的ResNet50模型,我们能够有效判断饮片的炮制状态,智能化再现药工的经验。

ResNet50网络简介

        ResNet50网络由何恺明等人在2015年提出,是ILSVRC 2015年图像分类竞赛的冠军模型。传统的卷积神经网络随着层数加深会出现退化问题,而ResNet网络通过引入残差结构,成功训练了数百甚至上千层的深度神经网络。ResNet50则是基于Bottleneck残差块的50层深度网络,在多种图像分类任务中展现了优异的性能。

准备阶段
配置实验环境

        本实验基于MindSpore框架和华为Ascend平台进行。以下是环境配置的必要步骤:

!pip install mindspore==2.3.0
数据集介绍

        我们使用的中药炮制饮片数据集由成都中医药大学提供,包含三类药材(蒲黄、山楂、王不留行)的不同炮制程度图片:生品、不及、适中、太过。每种状态下包含500张图片,总共12类5000张图像。

数据预处理

        我们将原始4K的图片缩放到1000x1000像素,以适应ResNet50的输入需求。

from PIL import Image
import os

def resize_images(data_dir, target_size=(1000, 1000)):
    for root, dirs, files in os.walk(data_dir):
        for file in files:
            if file.endswith('.jpg'):
                img = Image.open(os.path.join(root, file))
                img = img.resize(target_size)
                img.save(os.path.join(root, file))

resize_images('dataset/zhongyiyao/')
数据加载与划分

        为了训练和验证模型,我们将数据集分为训练集、验证集和测试集。通过使用sklearn中的train_test_split函数,我们将数据按比例分配,并保证每类样本均匀分布。

from sklearn.model_selection import train_test_split

def split_dataset(data_dir):
    classes = os.listdir(data_dir)
    for class_name in classes:
        images = os.listdir(os.path.join(data_dir, class_name))
        train, test = train_test_split(images, test_size=0.2, random_state=42)
        # 进一步划分验证集
        train, val = train_test_split(train, test_size=0.2, random_state=42)
        # 保存划分后的数据
        # ...

split_dataset('dataset1/zhongyiyao/')
ResNet50模型构建

        在处理完数据后,我们选择了ResNet50作为基础网络,并对其进行微调。我们将最后的全连接层调整为输出12个类别,以适应中药饮片的分类任务。

from mindspore import nn
from mindspore import Model

def build_resnet50(num_classes=12):
    network = resnet50(pretrained=True)
    in_channels = network.fc.in_channels
    network.fc = nn.Dense(in_channels, num_classes)
    return network

network = build_resnet50()
数据加载函数定义

        为了训练模型,我们需要定义一个数据加载器。此函数加载图片并执行图像增强等预处理步骤。

from mindspore.dataset import GeneratorDataset
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype

class Iterable:
    def __init__(self,data_path):
        self._data = []
        self._label = []
        if data_path.endswith(('JPG','jpg','png','PNG')):
            # 用作推理,所以没有label
            image = Image.open(data_path)
            self._data.append(image)
            self._label.append(0)
        else:
            classes = os.listdir(data_path)
            if '.ipynb_checkpoints' in classes:
                classes.remove('.ipynb_checkpoints')
            for (i,class_name) in enumerate(classes):
                new_path =  data_path+"/"+class_name
                for image_name in os.listdir(new_path):
                    try:
                        image = Image.open(new_path + "/" + image_name)
                        self._data.append(image)
                        self._label.append(i)
                    except:
                        pass
                
    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)

def create_dataset_zhongyao(dataset_dir,usage,resize,batch_size,workers):
    data = Iterable(dataset_dir)
    data_set = GeneratorDataset(data,column_names=['image','label'])
    trans = []
    if usage == "train":
        trans += [
            vision.RandomCrop(700, (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        vision.Resize((resize,resize)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    target_trans = transforms.TypeCast(mstype.int32)
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    data_set = data_set.map(
        operations=target_trans,
        input_columns='label',
        num_parallel_workers=workers)

    data_set = data_set.batch(batch_size,drop_remainder=True)
    return data_set
模型训练

      我们采用交叉熵作为损失函数,Momentum优化器进行模型参数优化。通过MindSpore的Model接口进行训练。

from mindspore import Model
from mindspore import nn

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)
model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={"accuracy"})

dataset_train = create_dataset_zhongyao('dataset1/zhongyiyao/train', 'train', 224, 32, 4)
dataset_val = create_dataset_zhongyao('dataset1/zhongyiyao/valid', 'valid', 224, 32, 4)

model.train(epochs=50, train_dataset=dataset_train, valid_dataset=dataset_val)
模型评估与推理

        在训练过程中,我们可以实时评估模型的性能,并保存训练效果最好的模型。

from mindspore import save_checkpoint, load_checkpoint, load_param_into_net

def evaluate_and_save_best_model(model, dataset_val, best_ckpt_path):
    best_acc = 0
    for epoch in range(50):
        acc = model.eval(dataset_val, dataset_sink_mode=False)['accuracy']
        print(f"Epoch {epoch}, Accuracy: {acc}")
        if acc > best_acc:
            best_acc = acc
            save_checkpoint(network, best_ckpt_path)
            print("Best model saved.")

evaluate_and_save_best_model(model, dataset_val, 'best_model.ckpt')

        推理部分代码如下,加载训练好的最佳模型,并对新的图片进行分类:

best_ckpt_path = 'best_model.ckpt'
net = resnet50(num_classes=12)
param_dict = load_checkpoint(best_ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

def predict_one(input_img):
    dataset_one = create_dataset_zhongyao(input_img, 'test', 224, 1, 1)
    data = next(dataset_one.create_tuple_iterator())
    output = model.predict(ms.Tensor(data[0]))
    pred = output.asnumpy().argmax(axis=1)
    return pred

print(predict_one('dataset1/zhongyiyao/test/sz_tg/IMG_0001.JPG'))
结果可视化

        我们可以通过可视化训练过程中准确率和损失的变化,直观展示模型的训练效果。

import matplotlib.pyplot as plt

def plot_training_results(acc_list, loss_list):
    epochs = range(1, len(acc_list) + 1)
    plt.subplot(1, 2, 1)
    plt.plot(epochs, acc_list, label="Accuracy")
    plt.title("Accuracy over Epochs")
    plt.subplot(1, 2, 2)
    plt.plot(epochs, loss_list, label="Loss")
    plt.title("Loss over Epochs")
    plt.show()

plot_training_results(acc_list, loss_list)
结语

        通过本次实验,我们成功构建并应用了ResNet50模型,对中药炮制饮片的质量进行了精准的智能化分类判断。中药炮制作为中医药的重要组成部分,其炮制火候的判断历来依赖老药工的丰富经验。然而,随着人工智能技术的迅速发展,我们借助深度学习模型有效地实现了这一经验的传承和智能化,解决了传统经验判断可能失传的问题。通过数据集的准备、网络的构建、模型的训练与验证,我们发现ResNet50在中药饮片分类任务中展现了出色的表现,准确率极高,进一步验证了其在图像分类领域的优势。

        未来,我们将继续探索更多深度学习模型的应用与改进,优化网络结构和算法,进一步提升模型在多样化炮制饮片中的判断能力。同时,也将尝试引入其他先进的算法,例如Transformer等新兴模型,探索它们在中药智能化领域的应用潜力。希望在这个过程中,能够与大家一起不断学习和进步,共同推动中医药智能化发展的新前景,助力中药现代化与人工智能的深度融合,实现更广泛的创新应用。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CNN模型对CIFAR-10中的图像进行分类

代码功能 这段代码展示了如何使用 Keras 和 TensorFlow 构建一个卷积神经网络(CNN)模型,用于对 CIFAR-10 数据集中的图像进行分类。主要功能包括: 加载数据:从 CIFAR-10 数据集加载训练和测试图像。 数据预处理&#…

HTTP【网络】

文章目录 HTTPURL(Uniform Resource Lacator) HTTP协议格式HTTP的方法HTTP的状态码HTTP常见的Header HTTP 超文本传输协议,是一个简单的请求-响应协议,HTTP通常运行在TCP之上 URL(Uniform Resource Lacator) 一资源定位符,也就是通常所说的…

NIM简单实践-图像分割

项目背景 我正在学习一个图像分割的 Demo,使用 NVIDIA 提供的预训练大模型进行光学字符检测 (OCDNet) 和光学字符识别 (OCRNet)。这些模型专门为光学字符检测和识别设计,能够自动将图像中的字符进行分割和识别。 预训练模型介绍 OCDNet (Optical Char…

Windows NTLM中继攻击(PortBender二进制可执行文件)

Windows NTLM中继攻击(PortBender二进制可执行文件) 前言 最近在完善自己的一套TTPs(战术、技术和程序)以应对未来的网络作战、项目和攻防演练需求,翻到了PortBender,我觉得不依赖C2和影响主机本身实现这一切非常有趣…

如何使用ssm实现民族大学创新学分管理系统分析与设计+vue

TOC ssm763民族大学创新学分管理系统分析与设计vue 第1章 绪论 1.1 课题背景 二十一世纪互联网的出现,改变了几千年以来人们的生活,不仅仅是生活物资的丰富,还有精神层次的丰富。在互联网诞生之前,地域位置往往是人们思想上不…

Linux 生产者消费者模型

前言 生产者消费者模型(CP模型)是一种十分经典的设计,常常用于多执行流的并发问题中!很多书上都说他很高效,但高效体现在哪里并没有说明!本博客将详解! 目录 前言 一、生产者消费者模型 1.…

绝美的登录界面!滑动切换效果

绝美登录界面&#xff01;添加了管理员账号和测试账号 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><scri…

RC正弦波振荡电路

0、判断电路能否产生正弦波震荡的条件 如上图所示&#xff0c; Xo:输出量&#xff1b; A:放大器的增益&#xff1b; F:反馈系数。 上式分别为RC正弦波震荡器的幅值条件和相位条件&#xff0c;为了使输出量在合闸后能够有一个从小到大直至平衡在一定幅值的过程&#xff0c;电…

《Linux服务与安全管理》| 配置YUM源并验证

《Linux服务与安全管理》配置YUM源并验证 目录 《Linux服务与安全管理》配置YUM源并验证 任务一&#xff1a;配置本地YUM源 任务二&#xff1a;配置网络YUM源 学生姓名 **** 学号 **** 专业 **** 任务名称 配置YUM源并验证 完成日期 **** 任务目标 知识 了解配…

docker安装kafka-manager

kafkamanager docker安装_mob64ca12d80f3a的技术博客_51CTO博客 # 1、拉取镜像及创建容器 docker pull hlebalbau/kafka-manager docker run -d --name kafka-manager -p 9000:9000 --networkhost hlebalbau/kafka-manager# 2、增设端口 腾讯云# 3、修改防火墙 sudo firewall-…

Salesforce AI 推全新大语言模型评估家族SFR-Judge 基于Llama3构建

在自然语言处理领域&#xff0c;大型语言模型&#xff08;LLMs&#xff09;的发展迅速&#xff0c;已经在多个领域取得了显著的进展。不过&#xff0c;随着模型的复杂性增加&#xff0c;如何准确评估它们的输出就变得至关重要。传统上&#xff0c;我们依赖人类来进行评估&#…

【目标检测】yolo的三种数据集格式

目标检测中数据集格式之间的相互转换--coco、voc、yolohttps://zhuanlan.zhihu.com/p/461488682?utm_mediumsocial&utm_psn1825483604463071232&utm_sourcewechat_session【目标检测】yolo的三种数据集格式https://zhuanlan.zhihu.com/p/525950939?utm_mediumsocial&…

Python小示例——质地不均匀的硬币概率统计

在概率论和统计学中&#xff0c;随机事件的行为可以通过大量实验来研究。在日常生活中&#xff0c;我们经常用硬币进行抽样&#xff0c;比如抛硬币来决定某个结果。然而&#xff0c;当我们处理的是“质地不均匀”的硬币时&#xff0c;事情就变得复杂了。质地不均匀的硬币意味着…

【宽搜】4. leetcode 103 二叉树的锯齿形层序遍历

1 题目描述 题目链接&#xff1a;二叉树的锯齿形层序遍历 2 题目解析 根据题目描述&#xff0c;第一行是从左往右遍历&#xff0c;第二行是从右往左遍历。和层序遍历的区别就是&#xff1a; 在偶数行需要从右往左遍历。 因此&#xff0c;只需要在层序遍历的基础上增加一个变…

网络基础:TCP/IP五层模型、数据在局域网传输和跨网络传输的基本流程、IP地址与MAC地址的简单解析

目录 背景介绍 网络协议 OSI七层模型 TCP/IP五层模型 TCP/IP协议与OS的关系 网络协议的本质 数据在局域网传输的基本流程 MAC地址 报文的封装和解包 补充内容 数据的跨网络传输基本流程 IP地址 IP地址和MAC地址的区别 ​​​ 背景介绍 网络的发展经理了四个阶段…

dijstra算法——单元最短路径算法

Dijkstra算法 用来计算从一个点到其他所有点的最短路径的算法&#xff0c;是一种单源最短路径算法。也就是说&#xff0c;只能计算起点只有一个的情况。Dijkstra的时间复杂度是O(n^2)&#xff0c;它不能处理存在负边权的情况。 算法描述&#xff1a; 设起点为s&#xff0c;d…

云原生(四十六) | MySQL软件安装部署

文章目录 MySQL软件安装部署 一、MySQL软件部署步骤 二、安装MySQL MySQL软件安装部署 一、MySQL软件部署步骤 第一步&#xff1a;删除系统自带的mariadb 第二步&#xff1a;下载MySQL源&#xff0c;安装MySQL软件 第三步&#xff1a;启动MySQL&#xff0c;获取默认密码…

【无标题】提升快递管理效率的必备技能:教你批量查询与导出物流信息

在当今快节奏的商业环境中&#xff0c;快递与物流行业的效率直接关系到企业的运营成本和客户满意度。随着订单量的不断增加&#xff0c;如何高效地管理和追踪大量的物流信息成为了企业面临的一大挑战。批量查询与导出物流信息作为一种高效的数据处理手段&#xff0c;正逐渐成为…

信息安全工程师(33)访问控制概述

前言 访问控制是信息安全领域中至关重要的一个环节&#xff0c;它提供了一套方法&#xff0c;旨在限制用户对某些信息项或资源的访问权限&#xff0c;从而保护系统和数据的安全。 一、定义与目的 定义&#xff1a;访问控制是给出一套方法&#xff0c;将系统中的所有功能和数据…

ElliQ 老年身边的陪伴

前记 国庆回家发现爸爸之前干活脚崴了&#xff0c;找个临时拐杖撑住&#xff0c;我心里很不是滋味。虽然总和爸妈说&#xff0c;不要干重活&#xff0c;但老人总是担心成为儿女的负担&#xff0c;所以只要能动&#xff0c;就找活干。 给爸妈一点零花钱&#xff0c;老妈只收了…