自定义神经网络入门-----Pytorch

news2024/12/23 23:34:50

文章目录

    • 目标检测的相关评价指标
      • IoU
      • mAP
        • 正例和负例
        • 准确率P
        • 召回率R
        • 准确率ACC
        • P-R曲线--AP
    • nn.Module类
      • 全连接层
      • 感知机类
      • 使用nn.Sequential进行构造
      • 使用randn函数进行简单测试
    • 损失函数nn.functional
    • nn.optim
    • 模型处理
      • 网络模型库torchvision.models
    • 模型Fine-tune和save
    • 参考


目标检测的相关评价指标

IoU

IoU(Intersection of Union)用来评估预测框和真实框的贴合程度来判断检测的质量。其计算方式如下图所示,使用两个边框的交集与并集的比值,就能得到IoU.

在这里插入图片描述

公式如下所示:
I o U A , B = S A ⋂ S B S A ⋃ S B IoU_{A,B}=\frac{S_A\bigcap S_B}{S_A\bigcup S_B} IoUA,B=SASBSASB
利用Python实现IoU计算的代码如下:

# boxA = (Ax1,Ay1,Ax2,Ay2)
# boxB = (Bx1,By1,Bx2,By2)
def iou(boxA,boxB):
    #计算重合部分的上、下、左、右4个边的值
    #因为在该处检测框的坐标位置是以左上角为原点
    #对于左边界,就是比较哪个框的x坐标大,同理对于上边界
    #对于右边界和下边界则相反,哪个小就是intersection的边界
    left_max=max(boxA[0],boxB[0])
    top_max=max(boxA[1],boxB[1])
    right_min=min(boxA[2],boxB[2])
    bottom_min=min(boxA[3],boxB[3])
    #计算重合部分面积
    #和0作比较就是判断是否有交集,没有交集则为0
    inter=max(0,(right_min-left_max)*max(0,(bottom_min-top_max)))
    Sa=(boxA[2]-boxA[0])*(boxA[3]-boxA[1])
    Sb=(boxB[2]-boxB[0])*(boxB[3]-boxB[1])
    #计算并集
    union=Sa+Sb-inter
    #计算iou
    iou=inter/union
    return iou

对于iou,显而易见它的取值范围为[0,1],iou值越大,表示两个框重合的越好,也就是预测的越好。所以我们平时实际使用过程中,通常取某个阈值,当iou大于阈值,就认为其是一个有效的预测,反之,就属于一个无效预测。

mAP

正例和负例

(1)True positives(TP): 被正确地划分为正例的个数,即实际为正例且被分类器划分为正例的实例数(样本数);

(2)False positives(FP): 被错误地划分为正例的个数,即实际为负例但被分类器划分为正例的实例数;

(3)False negatives(FN):被错误地划分为负例的个数,即实际为正例但被分类器划分为负例的实例数;

(4)True negatives(TN): 被正确地划分为负例的个数,即实际为负例且被分类器划分为负例的实例数。

准确率P

表示预测样本中实际正样本数所有正样本数的比例。
p r e c i s i o n = T P T P + F P precision=\frac{TP}{TP+FP} precision=TP+FPTP

召回率R

表示预测样本中实际正样本数占所有预测样本的比例。
r e c a l l = T P T P + F N recall=\frac{TP}{TP+FN} recall=TP+FNTP

准确率ACC

预测样本中预测正确数占所有样本的比例。
a c c = T P + T N T P + F P + T N + F N acc=\frac{TP+TN}{TP+FP+TN+FN} acc=TP+FP+TN+FNTP+TN

P-R曲线–AP

A P = ∫ 0 1 P d R AP=\int_{0}^{1}PdR AP=01PdR

AP代表曲线包围的面积,综合考虑了不同召回率下的准确率,对每个类别的AP进行平均取值就可以得到mAP
$$

$$
查看源图像

nn.Module类

nn.Module类是pytorch提供的一个神经网络类,并在类中实现了各层的定义及前向计算与反向传播机制。在我们搭建自己的神经网络,只需要继承nn.Module类,在初始化中定义模型结构与参数,在函数forward()中编写网络前向传播即可。

全连接层

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        # 调用nn.Module的构造函数
        super(Linear, self).__init__()
        # 使用nn.Parameter来构造需要学习的参数
        self.w = nn.Parameter(torch.randn(in_dim, out_dim))
        self.b = nn.Parameter(torch.randn(out_dim))

    # 在forward函数中实现前向传播过程
    def forward(self, x):
        x = x.matmul(self.w)
        y = x + self.b.expand_as(x)
        return y

感知机类

class Perception(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        super(Perception, self).__init__()
        self.layer1 = Linear(in_dim, hid_dim)
        self.layer2 = Linear(hid_dim, out_dim)
    def forward(self, x):
        x = self.layer1(x)
        y = torch.sigmoid(x)
        y = self.layer2(y)
        y = torch.sigmoid(y)
        return y
   

使用nn.Sequential进行构造

class Perception(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        super(Perception, self).__init__()
        self.layer = nn.Sequential(
          nn.Linear(in_dim, hid_dim),
          nn.Sigmoid(),
          nn.Linear(hid_dim, out_dim),
          nn.Sigmoid()
		)
    def forward(self, x):
        y = self.layer(x)
        return y

使用randn函数进行简单测试

mlp = MLP(2, 6, 2)
#用于输出参数名称,以及参数的数值
for name,param in mlp.named_parameters():
    print(name,param)

data=torch.randn(4,2)
print(data)
print(mlp(data))


#输出结果如下
# layer.0.weight Parameter containing:
# tensor([[-0.6859,  0.0836],
#         [ 0.1598, -0.3973],
#         [-0.4115, -0.1504],
#         [ 0.5661, -0.6860],
#         [ 0.6184,  0.5957],
#         [ 0.4745,  0.0665]], requires_grad=True)
# layer.0.bias Parameter containing:
# tensor([ 0.2333, -0.3779, -0.5412,  0.0109,  0.3617, -0.2331],
#        requires_grad=True)
# layer.2.weight Parameter containing:
# tensor([[-0.3937,  0.3270,  0.3897, -0.1779, -0.2968,  0.3092],
#         [ 0.2164, -0.0256, -0.1913,  0.1014,  0.0315,  0.2570]],
#        requires_grad=True)
# layer.2.bias Parameter containing:
# tensor([-0.1708,  0.2282], requires_grad=True)


# tensor([[ 0.0470, -0.4241],
#         [-1.4543,  1.1945],
#         [-0.0185,  0.4032],
#         [ 1.4869, -0.6973]])
# tensor([[0.4464, 0.6108],
#         [0.4221, 0.6021],
#         [0.4340, 0.6113],
#         [0.4562, 0.6196]], grad_fn=<SigmoidBackward0>)

损失函数nn.functional

from torch import nn
import torch
import torch.nn.functional as F
label=torch.Tensor([0,1,1,0])
pred=torch.Tensor([0.4464,0.6021,0.6113,0.4562])

#实例化nn.CrossEntropy
criterion=nn.CrossEntropyLoss()
loss_nn =criterion(pred,label)
loss_func=F.cross_entropy(pred,label)
print(loss_nn,loss_func)


#输出
#tensor(2.6232) tensor(2.6232)

nn.optim

该库包含了各种常见的优化算法,用来更新参数,包括随机梯度下降SGD、Adam、Adagrad、RMSProp

模型处理

网络模型库torchvision.models

from torchvision import models
vgg=models.vgg16()
print(vgg)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以看到vgg16分为了三个层次,分别为提取特征层,均值池化层,分类层.

模型Fine-tune和save

加载预训练模型

  1. 预训练模型

    from torch import nn
    from torchvision import models
    vgg=models.vgg16(pretrained=True)
    
  2. 加载本地模型

    from torch import nn
    from torchvision import models
    import torch
    
    vgg=models.vgg(16)
    state_dict=torch.load("local model path")
    
    vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})
    

冻结部分预训练层

for layer in range(10):
	for p in vgg[layer].parameters():
        p.require_grad=False

模型保存

torch.save({
	'model':model.state_dict(),
	'optimizer':optimizer.state_dict(),
	'model_path.path'
})

参考

[深度学习小常识]什么是mAP?

深度学习之Pytorch物体检测实战

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/126836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【STM32F4系列】【HAL库】【自制库】模拟IIC从机

介绍 本项目是利用GPIO模拟I2C的从机 网上常见的是模拟I2C主机 本项目是作为一个两个单片机之间低速通信的用法 协议介绍请看,传送门 模拟主机请看这里 从机 功能 实现I2C从机端读写寄存器 编程思路 I2C的从机实现比起主机来麻烦一些 因为SCL的时序是由主机发送,从机需…

【nowcoder】笔试强训Day12

目录 一、选择题 二、编程题 2.1二进制插入 2.2 查找组成一个偶数最接近的两个素数 一、选择题 1.以下方法&#xff0c;哪个不是对add方法的重载? public class Test {public void add( int x,int y,int z){} } A. public int add(int x,int y,float z){return 0;} B.…

Go语言设计与实现 -- WaitGroup, Once, Cond

WaitGroup 我们可以通过 sync.WaitGroup 将原本顺序执行的代码在多个 Goroutine 中并发执行&#xff0c;加快程序处理的速度。 我们来看一下sync.WaitGroup的结构体&#xff1a; type WaitGroup struct {//保证WaitGroup不会被开发者通过再赋值的方式复制noCopy noCopy// 64-…

重学redux之Redux-Thunk高级使用(三)

这是第三篇了,哥们,如果没看过前两篇,可以去看看之前的两篇,有基础的可以直接看,不多说,直接开讲 默认情况下,Redux 的动作是同步调度的,对于任何需要与外部 API 通信或执行副作用的应用程序来说都是一个问题。 Redux 允许中间件位于被分派的动作和到达 reducer 的动…

抖音本地生活的蓬勃发展,离不开服务商的推波助澜

抖音本地生活&#xff0c;已经势不可挡01 抖音公布本地生活成绩单&#xff0c;交易额增长30倍抖音经过6年时间的演变&#xff0c;产品功能日益丰富&#xff0c;已经从内容消费&#xff0c;延续到线上购物、线下团购等领域&#xff0c;从最初的记录美好生活&#xff0c;成为一种…

统计分析工具-FineReport配置SQL Server外接数据库(2)

1. 配置外接数据库 1.1 外接数据库配置入口 外接数据库的配置入口&#xff0c;有三种形式&#xff1a; 1&#xff09;超级管理员第一次登录数据决策系统时&#xff0c;即可为系统配置外接数据库。如下图所示&#xff1a; 2&#xff09;对于使用内置数据库的系统&#xff0c;管…

站点能源低碳目标网,助力网络碳中和 | 华为发布站点能源十大趋势

2022年12月29日&#xff0c;华为今天举办站点能源十大趋势发布会并重磅发布白皮书。发布会上&#xff0c;华为站点能源领域总裁尧权全面解读了能源数字化、低碳网络、站点供电绿色化等站点能源十大趋势。 尧权表示&#xff0c;2022年是不平凡的一年&#xff0c;全球能源危机背…

十、通过网络服务将esp8266引脚状态显示在网页中

ESP8266在服务器模式运行时&#xff0c;我们可以使用浏览器来显示它的引脚状态。 1、实现目标 学习如何通过esp8266建立基本网站&#xff0c;在该网站上实时显示esp8266的引脚值。 2、原理图 FLASH按键与D3引脚连接&#xff0c;可以通过FLASH按键改变D3引脚的电平。当没有按…

中型企业适合用什么样的CRM管理软件,求推荐?

中型企业适合用什么样的CRM管理软件&#xff0c;求推荐&#xff1f; CRM管理软件是现代企业必不可少的管理软件之一&#xff0c;很多企业都会选择CRM管理软件来经营客户资源&#xff0c;但能够精准地选择到适合自己企业的CRM管理软件则是困难的。 中型企业需要与自己业务流程…

数据可视化之finebi和tableau电力系统分析实现对比

通过一个电力系统简单案例&#xff0c;尝试实际执行finebi和Tableau数据可视化设计的各项基本步骤&#xff0c;以熟悉Tableau和finebi数据可视化设计技巧&#xff0c;提高大数据可视化应用能力。 一、工具/准备工作 在开始本实验之前&#xff0c;请认真阅读课程的相关内容。 …

写给小白的TensorFlow的入门课

文章目录前言学习AI的必要性和业务的关系最简单的例子要做什么&#xff1f;数据图形化展示构建计算图形计算图形最小化误差MacOS 中配置运行环境安装验证安装简单模型训练识别数字图片的模型训练Softmax Regression算法大概步骤大致算法实现结语参考链接前言 深度学习就是从大…

抖音电商发布2023年食品健康行业8大趋势,新减负、新养生等成为关键词

2022抖音电商食品健康峰会暨年货盛典在杭州成功举行。抖音电商食品健康行业还联合欧睿共同发布了《2023年度食品健康行业趋势洞察报告》。图片来源&#xff1a;抖音电商抖音电商食品健康行业负责人白华在会上透露&#xff0c;过去一年&#xff0c;抖音电商食品健康行业呈现出有…

虚拟机数据库改密码ERROR 1396 (HY000): Operation ALTER USER failed for ‘root‘@‘localhost‘

注&#xff1a;原因为MySql 8.0.11 换了新的身份验证插件&#xff08;caching_sha2_password&#xff09;, 原来的身份验证插件为&#xff08;mysql_native_password&#xff09;。而客户端工具Navicat Premium12 中找不到新的身份验证插件&#xff08;caching_sha2_password&a…

Java实现多线程

目录 基本概念 1、程序、进程、线程 2、使用线程的优点 3、线程的分类 4、线程的生命周期 多线程的实现方法 1、继承Thread类 2、实现Runnable接口 3、实现Callable接口 4、使用线程池 线程同步 1、同步代码、同步方法 2、同步机制中的锁 3、锁&#xff08;Lock&…

【电商】电商后台---采购管理模块

从供应商的管理到合同的管理&#xff0c;再到商品系统的模块的介绍、商品价格与税率维护策略&#xff0c;不知不觉已经完成了几篇文章&#xff0c;前期的准备工作完成后&#xff0c;接下来就应该进入到采购管理模块了。 几天来一直在构思如何写&#xff0c;写的内容让大家看过觉…

使用天地图加载Geoserver的图层

一、写在前面 在项目中往往使用地图作为底图(比如 天地图卫星图等)&#xff0c;再其上覆盖你的通过geoserver发布自定义图层。本文记录了我的实现方法。 二、过程 2.1 我遇到的难题 遇到难题1&#xff1a;使用无人机拍摄制作的正射影像图有几百MB甚至1个G&#xff0c;直接展示图…

YOLO系列目标检测算法——PP-YOLOE

YOLO系列目标检测算法目录 - 文章链接 YOLO系列目标检测算法总结对比- 文章链接 YOLOv1- 文章链接 YOLOv2- 文章链接 YOLOv3- 文章链接 YOLOv4- 文章链接 Scaled-YOLOv4- 文章链接 YOLOv5- 文章链接 YOLOv6- 文章链接 YOLOv7- 文章链接 PP-YOLO- 文章链接 …

深入浅出面向对象设计模式(Java)

设计模式是什么 设计模式是面向对象的一种思想。 设计模式的基本原则&#xff1f; 单一职责原则开放封闭原则里氏替换原则接口隔离原则依赖翻转原则 基本分类和为什么分为3类&#xff1f; 创建型&#xff08;怎么优雅创建对象&#xff09; 结构性&#xff08;对象的结构&am…

巧用Hibernate 完成多数据库的DDL脚本创建

巧用Hibernate 完成多数据库的DDL脚本创建 spring boot jpa 默认的orm框架就是Hibernate。 由hibernate完成数据库的读写也是主流的方式之一。但是不同数据库之间&#xff0c;建表、建索引的方言语句都有较多差别&#xff0c;很难做到一套SQL在所有数据库上进行执行。 那么Hibe…

C++11之线程库

文章目录一、thread二、mutex三、lock_guard 与 unique_lock1. lock_guard2. unique_lock四、atomic五、condition_variable在 C11 之前&#xff0c;涉及到多线程问题&#xff0c;都是和平台相关的&#xff0c;比如 Windows 和 Linux 下各有自己的接口&#xff0c;这使得代码的…