基于EifficientNet的视网膜病变识别

news2025/1/17 0:17:36

分析一下代码

model.py

①下面这个方法的作用是:将传入的ch(channel)的个数调整到离它最近的8的整数倍,这样做的目的是对硬件更加友好。

def _make_divisible(ch, divisor=8, min_ch=None):
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

②定义卷积BAN和激活函数的模块类,groups是用来控制卷积结构使用普通卷积结构还是使用Depwise卷积结构。

class ConvBNActivation(nn.Sequential):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.SiLU

        super(ConvBNActivation, self).__init__(nn.Conv2d(in_channels=in_planes,
                                                         out_channels=out_planes,
                                                         kernel_size=kernel_size,
                                                         stride=stride,
                                                         padding=padding,
                                                         groups=groups,
                                                         bias=False),
                                               norm_layer(out_planes),
                                               activation_layer())

③定义InvertedResidualConfig模块,对应每一个MBConv模块的配置参数。

class InvertedResidualConfig:
    def __init__(self,
                 kernel: int,          # 3 or 5
                 input_c: int,
                 out_c: int,
                 expanded_ratio: int,  # 1 or 6
                 stride: int,          # 1 or 2
                 use_se: bool,         # True
                 drop_rate: float,
                 index: str,           # 1a, 2a, 2b, ...
                 width_coefficient: float):
        self.input_c = self.adjust_channels(input_c, width_coefficient)
        self.kernel = kernel
        self.expanded_c = self.input_c * expanded_ratio
        self.out_c = self.adjust_channels(out_c, width_coefficient)
        self.use_se = use_se
        self.stride = stride
        self.drop_rate = drop_rate
        self.index = index

    @staticmethod
    def adjust_channels(channels: int, width_coefficient: float):
        return _make_divisible(channels * width_coefficient, 8)

④定义InvertedResidual模块,即MBConv模块。

class InvertedResidual(nn.Module):
    def __init__(self,
                 cnf: InvertedResidualConfig,
                 norm_layer: Callable[..., nn.Module]):
        super(InvertedResidual, self).__init__()

        if cnf.stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.use_res_connect = (cnf.stride == 1 and cnf.input_c == cnf.out_c)

        layers = OrderedDict()
        activation_layer = nn.SiLU

        if cnf.expanded_c != cnf.input_c:
            layers.update({"expand_conv": ConvBNActivation(cnf.input_c,
                                                           cnf.expanded_c,
                                                           kernel_size=1,
                                                           norm_layer=norm_layer,
                                                           activation_layer=activation_layer)})

        layers.update({"dwconv": ConvBNActivation(cnf.expanded_c,
                                                  cnf.expanded_c,
                                                  kernel_size=cnf.kernel,
                                                  stride=cnf.stride,
                                                  groups=cnf.expanded_c,
                                                  norm_layer=norm_layer,
                                                  activation_layer=activation_layer)})

        if cnf.use_se:
            layers.update({"se": SqueezeExcitation(cnf.input_c,
                                                   cnf.expanded_c)})

        layers.update({"project_conv": ConvBNActivation(cnf.expanded_c,
                                                        cnf.out_c,
                                                        kernel_size=1,
                                                        norm_layer=norm_layer,
                                                        activation_layer=nn.Identity)})

        self.block = nn.Sequential(layers)
        self.out_channels = cnf.out_c
        self.is_strided = cnf.stride > 1

        # 只有在使用shortcut连接时才使用dropout层
        if self.use_res_connect and cnf.drop_rate > 0:
            self.dropout = DropPath(cnf.drop_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        result = self.block(x)
        result = self.dropout(result)
        if self.use_res_connect:
            result += x

        return result

⑤接下来看看EfficientNet是如何实现的,网络参数如下:

代码如下:

class EfficientNet(nn.Module):
    def __init__(self,
                 width_coefficient: float,
                 depth_coefficient: float,
                 num_classes: int = 1000,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2,
                 block: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None
                 ):
        super(EfficientNet, self).__init__()

        default_cnf = [[3, 32, 16, 1, 1, True, drop_connect_rate, 1],
                       [3, 16, 24, 6, 2, True, drop_connect_rate, 2],
                       [5, 24, 40, 6, 2, True, drop_connect_rate, 2],
                       [3, 40, 80, 6, 2, True, drop_connect_rate, 3],
                       [5, 80, 112, 6, 1, True, drop_connect_rate, 3],
                       [5, 112, 192, 6, 2, True, drop_connect_rate, 4],
                       [3, 192, 320, 6, 1, True, drop_connect_rate, 1]]

        def round_repeats(repeats):
            return int(math.ceil(depth_coefficient * repeats))

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        adjust_channels = partial(InvertedResidualConfig.adjust_channels,
                                  width_coefficient=width_coefficient)

        bneck_conf = partial(InvertedResidualConfig,
                             width_coefficient=width_coefficient)

        b = 0
        num_blocks = float(sum(round_repeats(i[-1]) for i in default_cnf))
        inverted_residual_setting = []
        for stage, args in enumerate(default_cnf):
            cnf = copy.copy(args)
            for i in range(round_repeats(cnf.pop(-1))):
                if i > 0:
                    cnf[-3] = 1
                    cnf[1] = cnf[2]

                cnf[-1] = args[-2] * b / num_blocks
                index = str(stage + 1) + chr(i + 97)
                inverted_residual_setting.append(bneck_conf(*cnf, index))
                b += 1

        layers = OrderedDict()

        layers.update({"stem_conv": ConvBNActivation(in_planes=3,
                                                     out_planes=adjust_channels(32),
                                                     kernel_size=3,
                                                     stride=2,
                                                     norm_layer=norm_layer)})

        for cnf in inverted_residual_setting:
            layers.update({cnf.index: block(cnf, norm_layer)})

        last_conv_input_c = inverted_residual_setting[-1].out_c
        last_conv_output_c = adjust_channels(1280)
        layers.update({"top": ConvBNActivation(in_planes=last_conv_input_c,
                                               out_planes=last_conv_output_c,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})

        self.features = nn.Sequential(layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        classifier = []
        if dropout_rate > 0:
            classifier.append(nn.Dropout(p=dropout_rate, inplace=True))
        classifier.append(nn.Linear(last_conv_output_c, num_classes))
        self.classifier = nn.Sequential(*classifier)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
train.py

①导入模型

from model import efficientnet_b0 as create_model

②不同版本的EfficientNet网络模型,对应不同的输入图像大小。

img_size = {"B0": 224,
                "B1": 240,
                "B2": 260,
                "B3": 300,
                "B4": 380,
                "B5": 456,
                "B6": 528,
                "B7": 600}
    num_model = "B0"

③实例化模型部分,传入类别个数,然后添加设备。

model = create_model(num_classes=args.num_classes).to(device)

④还有一个参数为是否冻结权重,如果为true,只会微调最后一层的1*1的卷积以及FC全连接层结构,如果为false,就会训练全部的网络结构。

    if args.freeze_layers:
        for name, para in model.named_parameters():
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))
predict.py

①创建模型,num_classes的大小需要根据自己的数据集类别个数进行改变,不仅这里需要改变,训练时也需要改变。

model = create_model(num_classes=5).to(device)

②导入之前训练的模型权重即可。

model_weight_path = "./weights/model-29.pth"

开始训练


1. 在train.py脚本中将--data-path设置成解压后的视网膜病变数据集文件夹的绝对路径。


2.下载预训练权重,根据自己使用的模型下载对应预训练权重。


3. 在train.py脚本中将--weights参数设成下载好的预训练权重路径。


4. 设置好数据集的路径--data-path以及预训练权重的路径--weights就能使用train.py脚本开始训练了(训练过程中会自动生成class_indices.json文件)。


5. 在predict.py脚本中导入和训练脚本中同样的模型,并将model_weight_path设置成训练好的模型权重路径(默认保存在weights文件夹下)。


6. 在predict.py脚本中将img_path设置成你自己需要预测的图片绝对路径。


7. 设置好权重路径model_weight_path和预测的图片路径img_path就能使用predict.py脚本进行预测了。


8. 数据集必须按照视网膜病变数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的num_classes设置成正确的数据类别数。

基于efficientnetb0.pth预训练权重的训练结果如下:

基于efficientnetb7.pth预训练权重的训练结果如下(粉色部分):

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1685484.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习Uni-app开发小程序Day23

今天学习了将上一章的所有核算的js&#xff0c;抽离出去&#xff0c;让在其他地方可以直接调用&#xff0c;然后和适配抖音的办法&#xff0c;封装网络请求&#xff1b; 抽离公共方法 如何将公共方法抽离&#xff1f; 1、在根目录创建一个目录&#xff0c;一般起名是:utils 2…

谷歌快速收录怎么做?

快速收录顾名思义&#xff0c;就是让新的的网页内容能够迅速被谷歌搜索引擎抓取、索引和显示在搜索结果中&#xff0c;这对于做seo来说非常重要&#xff0c;因为它有助于新发布的内容尽快出现在谷歌的搜索结果中&#xff0c;从而增加网站的流量 想做谷歌快速收录谷歌推荐了几种…

告别硬编码:Spring条件注解优雅应对多类场景

一、背景 在当今的软件开发中&#xff0c;服务接口通常需要对应多个实现类&#xff0c;以满足不同的需求和场景。举例来说&#xff0c;假设我们是一家2B公司&#xff0c;公司的产品具备对象存储服务的能力。然而&#xff0c;在不同的合作机构部署时&#xff0c;发现每家公司底…

每天五分钟深度学习框架PyTorch:创建具有特殊值的tensor张量

本文重点 tensor张量是一个多维数组,本节课程我们将学习一些pytorch中已经封装好的方法,使用这些方法我们可以快速创建出具有特殊意义的tensor张量。 创建一个值为空的张量 import torch import numpy as np a=torch.empty(1) print(a) print(a.dim()) print(s.shape) 如图…

python3如何查看是32位还是64位

在安装一些python的软件包时&#xff0c;经常安装错误&#xff0c;可能是跟python的位数有关系。 下面告诉大家如何查看python的位数。 第一种方法&#xff1a;通过在cmd中输入“python”即可。 第二种方法&#xff1a;通过platform包查看&#xff0c;首先导入platform包&…

Nginx企业级负载均衡:技术详解系列(11)—— 实战一机多站部署技巧

你好&#xff0c;我是赵兴晨&#xff0c;97年文科程序员。 工作中你是否遇到过这种情况&#xff1a;公司业务拓展&#xff0c;新增一个域名&#xff0c;但服务器资源有限&#xff0c;只能跟原有的网站共用同一台Nginx服务器。 也就是说两个网站的域名都指向同一台Nginx服务器…

金融信贷风控基础知识

一、所谓风控(What && Why) 所谓风控&#xff0c;可以拆解从2个方面看&#xff0c;即 风险和控制 风险(what) 风险 这里狭隘的特指互联网产品中存在的风险点&#xff0c;例如 账户风险 垃圾注册账号账号被泄露盗用 交易支付风险 刷单&#xff1a;为提升卖家店铺人气…

小程序-滚动触底-页面列表数据无限加载

// index/index.vue <template> <!-- 自定义导航栏 --> <CustomNavbar /> <scroll-view scrolltolower"onScrolltolower" scroll-y class"scroll-view"> <!-- 猜你喜欢 --> <Guess ref"guessRef" /> </s…

利用Python队列生产者消费者模式构建高效爬虫

目录 一、引言 二、生产者消费者模式概述 三、Python中的队列实现 四、生产者消费者模式在爬虫中的应用 五、实例分析 生产者类&#xff08;Producer&#xff09; 消费者类&#xff08;Consumer&#xff09; 主程序 六、总结 一、引言 随着互联网的发展&#xff0c;信…

利用Anaconda+Pycharm配置PyTorch完整过程

说在前面&#xff1a;这篇是记录贴&#xff0c;因为被配置环境折磨了大半天&#xff0c;所以记录下来下次方便配置&#xff0c;有点像流水账&#xff0c;有不懂的地方可以评论问。 参考文章&#xff1a; https://blog.csdn.net/m0_48609250/article/details/129402319 环境&…

Android:使用Kotlin搭建MVC架构模式

一、简介Android MVC架构模式 M 层 model &#xff0c;负责处理数据&#xff0c;例如网络请求、数据变化 V 层 对应的是布局 C 层 Controller&#xff0c; 对应的是Activity&#xff0c;处理业务逻辑&#xff0c;包含V层的事情&#xff0c;还会做其他的事情&#xff0c;导致 ac…

ChineseOcr Lite Ncnn:高效轻量级中文OCR工具

目录结构 前言opencv编译编译命令编译结果 ncnn设置OcrLiteNcnn编译OcrLiteNcnn1.8.0源码下载OcrLiteNcnn1.8.0源码编译 OCR图片文本识别测试编译文件测试命令编译文件测试输出 模型下载相关链接 前言 ChineseOcr Lite Ncnn&#xff0c;超轻量级中文OCR PC Demo&#xff0c;支…

AI率怎么降低?有哪些论文降重降AI率的工具和方法?

关于aigc降重怎么降重&#xff1f;论文降重有哪些方法&#xff1f;有没有好用的降重软件&#xff1f;网上很多大神都有回答&#xff0c;但是最近还是会有很多学弟学妹会问这些问题&#xff01; 有没有发现论文降重像玄学一样复杂&#xff1f;最近刚完成一篇论文&#xff0c;使…

Python数据可视化(五)

实现GUI效果 借助 matplotlib&#xff0c;除可以绘制动画内容外&#xff0c;还可以实现用户图形界面的效果&#xff0c;也就是 GUI 效果。 GUI是用户使用界面的英文单词首字母的缩写。接下来&#xff0c;我们就以模块widgets中的类RadioButtons、 Cursor 和 CheckButtons 的使用…

说说什么是AOP,以及AOP的具体实现场景(外卖中应用)

推荐B站&#xff1a;【Spring AOP】实际开发中到底有什么用&#xff1f;_哔哩哔哩_bilibili 一、AOP的原理 AOP即Aspect Oriented Program&#xff0c;面向切面编程&#xff0c;是面向对象编程(OOP)的一种增强模式&#xff0c;可以将项目中与业务无关的&#xff0c;却为业务模…

Spark-广播变量详解

Spark概述 Spark-RDD概述 1.为什么会需要广播变量&#xff1f; 广播变量是为了在分布式计算环境中有效地向集群中的所有节点广播大型只读数据集而设计的。 在分布式环境中&#xff0c;通常会遇到需要在所有节点上使用相同的数据集的情况&#xff0c;但是将这些数据集复制到每个…

以及Spring中为什么会出现IOC容器?@Autowired和@Resource注解?

以及Spring中为什么会出现IOC容器&#xff1f;Autowired和Resource注解&#xff1f; IOC容器发展史 没有IOC容器之前 首先说一下在Spring之前&#xff0c;我们的程序里面是没有IOC容器的&#xff0c;这个时候我们如果想要得到一个事先已经定义的对象该怎么得到呢&#xff1f;…

数据结构(树)

1.树的概念和结构 树&#xff0c;顾名思义&#xff0c;它看起来像一棵树&#xff0c;是由n个结点组成的非线性的数据结构。 下面就是一颗树&#xff1a; 树的一些基本概念&#xff1a; 结点的度&#xff1a;一个结点含有的子树的个数称为该结点的度&#xff1b; 如上图&#…

Python | Leetcode Python题解之第107题二叉树的层序遍历II

题目&#xff1a; 题解&#xff1a; class Solution:def levelOrderBottom(self, root: TreeNode) -> List[List[int]]:levelOrder list()if not root:return levelOrderq collections.deque([root])while q:level list()size len(q)for _ in range(size):node q.popl…

夏天晚上热,早上凉怎么办?

温差太大容易引起感冒 1.定个大概3点的闹钟&#xff0c;起来盖被子。有些土豪可以开空调&#xff0c;我这个咸鱼没有空调。 2.空调调到合适的温度&#xff0c;比如20几度。