PyTorch深度学习(三)【Logistic Regression、处理多维特征的输入】

news2024/12/23 22:19:57

Logistic Regression  这个名字叫做回归,做的是分类。

线性和logistic的模型:

使用的损失函数:二分类交叉熵

(这个也叫做BCELoss)

logistic要做的事:

代码:

import torch# import torch.nn.functional as F# prepare datasetx_data = torch.Tensor([[1.0], [2.0], [3.0]])   #数据准备y_data = torch.Tensor([[0], [0], [1]])         #第0类,第1类# design model using classclass LogisticRegressionModel(torch.nn.Module):def __init__(self):   #这里跟线性模型是一样的,没什么区别。因为没有参数,在构造函数里面不用初始化        super(LogisticRegressionModel, self).__init__()        self.linear = torch.nn.Linear(1, 1)def forward(self, x):# y_pred = F.sigmoid(self.linear(x))        y_pred = torch.sigmoid(self.linear(x))    #先用linear做一下线性变换,再把sigmoid函数应用到计算出来的结果上面作为最后的输出。线性模型和logistic多了sigmoid这一步return y_predmodel = LogisticRegressionModel()# construct loss and optimizer# 默认情况下,loss会基于element平均,如果size_average=False的话,loss会被累加。criterion = torch.nn.BCELoss(size_average=False)              #损失也用得不一样与线性模型(MSE)相比optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, updatefor epoch in range(1000):    y_pred = model(x_data)    loss = criterion(y_pred, y_data)    print(epoch, loss.item())    optimizer.zero_grad()    loss.backward()    optimizer.step()print('w = ', model.linear.weight.item())print('b = ', model.linear.bias.item())x_test = torch.Tensor([[4.0]])y_test = model(x_test)print('y_pred = ', y_test.data)

关于上述的代码产生的结果进行可视化:

import numpy as npimport matplotlib. pyplot as pltx=np. linspace (0,10,200)x_t = torch.Tensor(x).view((200,1))y_t = model(x_t)y = y_t.data. numpy ()plt.plot(x,y)plt.plot([0,10],[0.5,0.5],c='r')plt.xlabel(' Hours')plt.ylabel(' Probability of Pass')plt.grid()plt.show()

练习:(关于BCELoss函数)

import mathimport torchpred = torch.tensor([[-0.2],[0.2],[0.8]])target = torch.tensor([[0.0],[0.0],[1.0]])sigmoid = torch.nn.Sigmoid()pred_s = sigmoid(pred)print(pred_s)"""pred_s 输出tensor([[0.4502],[0.5498],[0.6900]])0*math.log(0.4502)+1*math.log(1-0.4502)0*math.log(0.5498)+1*math.log(1-0.5498)1*math.log(0.6900) + 0*log(1-0.6900)"""result = 0i=0for label in target:if label.item() == 0:        result +=  math.log(1-pred_s[i].item())else:        result += math.log(pred_s[i].item())    i+=1result /= 3print("bce:", -result)loss = torch.nn.BCELoss()print('BCELoss:',loss(pred_s,target).item())

处理多维特征的输入

数据集:文末分享。

模型:

损失和优化:(BCE)

代码:

import numpy as npimport torchimport matplotlib.pyplot as plt# prepare datasetxy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)   #“delimiter”分隔符x_data = torch.from_numpy(xy[:, :-1])  # 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要y_data = torch.from_numpy(xy[:, [-1]])  # [-1] 最后得到的是个矩阵# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)  # 输入数据x的特征是8维,x有8个特征;输出维度是6self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)  # 三个线性模型self.sigmoid = torch.nn.Sigmoid()     # 将其看作是网络的一层,而不是简单的函数使用。nn.Sigmoid()是一个模块,,继承自Module,没有参数;用它来做计算图#激活函数可以更改,比如改为“torch.nn.ReLU()”def forward(self, x):        x = self.sigmoid(self.linear1(x))     #如果前面的激活函数改为了ReLU,这里就需要改为“self.activate(self.linear1(x)) ”        x = self.sigmoid(self.linear2(x))        x = self.sigmoid(self.linear3(x))  # y hatreturn xmodel = Model()# construct loss and optimizer# criterion = torch.nn.BCELoss(size_average = True)criterion = torch.nn.BCELoss(reduction='mean')optimizer = torch.optim.SGD(model.parameters(), lr=0.1)epoch_list = []loss_list = []# training cycle forward, backward, updatefor epoch in range(100):    y_pred = model(x_data)      #所有的数据加载进来    loss = criterion(y_pred, y_data)    print(epoch, loss.item())    epoch_list.append(epoch)    loss_list.append(loss.item())    optimizer.zero_grad()    loss.backward()    optimizer.step()   #更新plt.plot(epoch_list, loss_list)plt.ylabel('loss')plt.xlabel('epoch')plt.show()

ok。

可以使用不同的激活函数去尝试,比如:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1024732.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java基于SpringBoot的校园疫情防控系统

文章目录 第一章2.主要技术第三章第四章 系统设计4.1功能结构4.2 数据库设计4.2.1 数据库E/R图4.2.2 数据库表 第五章 系统功能实现5.1系统功能模块5.2后台功能模块5.2.1管理员功能 源码咨询 第一章 springboot校园疫情防控系统演示录像2022 一个好的系统能将校园疫情防控的管理…

VB求平均值

VB求平均值 Private Function pj(x() As Integer) As SingleDim m%, n%, i%, s%m LBound(x): n UBound(x)For i m To ns s x(i)Next ipj s / (n - m 1) End Function Private Sub Command1_Click()Dim a%(1 To 10), i%, aver!For i 1 To 10a(i) Int(Rnd() * 10) 随机…

IMX6ULL移植篇-Linux内核编译

一. Linux内核 Linux 官网为 https://www.kernel.org ,所以你想获取最新的 Linux 版本就可以在这个网站上下载。 Linux-4.x 版本 的 Linux 和 5.x 版本没有本质上的区别, 5.x 更多的是加入了一些新的平台、新的外设驱动而已。 NXP 会从网址 …

提升科研可复现性:和鲸聚焦 AI for Science 全生命周期管理

今年三月,国家科技部会同自然科学基金委正式启动“人工智能驱动的科学研究(AI for Science)”专项部署工作。数据驱动的科学研究长期以来面临诸多困境,针对传统科研工作流中过度依赖人类专家经验与体力的局限性,AI4S 旨…

优化软件系统,解决死锁问题,提升稳定性与性能 redis排队下单

项目背景: 随着用户数量的不断增加,我们的速卖通小管家软件系统面临了一个日益严重的问题:在从存储区提供程序的数据读取器中进行读取时,频繁出现错误。系统报告了一个内部异常: 异常信息如下: 从存储区提供程序的数…

nvme各模块间的关系总结

目录:driver/host/nvme/makefile # SPDX-License-Identifier: GPL-2.0 ccflags-y -I$(src)obj-$(CONFIG_NVME_CORE) nvme-core.o obj-$(CONFIG_BLK_DEV_NVME) nvme.o obj-$(CONFIG_NVME_FABRICS) nvme-fabrics.o obj-$(CONFIG_NVME_RDMA) nvme-rdma.…

02、Servlet核心技术(下)

目录 1 ServletJDBC应用(重点) 2 重定向和转发(重点) 2.1 重定向的概述 2.2 转发的概述 3 Servlet线程安全(重点) 4 状态管理(重点 ) 5 Cookie技术(重点&#xf…

26 环形链表II

环形链表 II 题解1 哈希表题解2 双指针 给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表…

pgzrun 拼图游戏制作过程详解(10)

10. 拼图游戏继续升级——多关卡拼图 初始化列表Photos用来储存拼图文件名,Photo_ID用来统计当下是第几张拼图,Squares储存当下拼图的24张小拼图的文件名,Gird储存当下窗口上显示的24个小拼图及坐标。 Photos["girl_","boy_…

“顽固”——C语言用栈实现队列

解题图解: 1、 先用stack1存储push来的数据 2、每当要pop数据时,从stack2中取,如果 stack2为空,就先从stack1中“倒”数据到stack2。 这就是用栈实现队列的基本操作 这道题看起来比较容易,但是!如果你用C语…

jupyter notebook插件安装及插件推荐

安装插件 安装插件选择的工具栏 pip install jupyter_contrib_nbextensions将插件工具栏添加到jupyter notebook页面 jupyter contrib nbextension installdisable configuration for nbextensions without explicit compatibility (they may break your notebook environme…

《Kubernetes部署篇:Ubuntu20.04基于containerd部署kubernetes1.25.14集群(多主多从)》

一、架构图 如下图所示: 二、环境信息 1、资源下载基于containerd部署容器版kubernetes1.25.14集群资源合集 2、部署规划主机名K8S版本系统版本内核版本IP地址备注k8s-master-121.25.14Ubuntu 20.04.5 LTS5.15.0-69-generic192.168.1.12master节点 + etcd节点k8s-master-131.…

【超实用】2023年,学生上班族如何简单快速,低成本的搭建一个博客网站

文章目录 前言实操环节香港虚拟机购买博客搭建ssl证书配置备份设置 总结 前言 因为工作和生活的需要,我一直有博客的搭建需求。我将总结下来,为读者提供参考。  起初,我采用的是香港云虚拟主机,这种虚拟机极其便宜(一…

java接入烽火科技拾音器详细步骤

1 背景 项目中需要拾音器去采集音频数据并保存成mp3这种音频文件,以便以后如果有纠纷后可以作为证据去减少纠纷,于是采购了一台烽火科技的拾音器设备,包括一个采音器及一个处理终端。 2 接线 设备拿过来第一件事是接线,通电&…

WampServer下载安装+cpolar内网穿透实现公网访问本地服务【内网穿透】

文章目录 前言1.WampServer下载安装2.WampServer启动3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 Wamp 是一个 Windows系统下的 Apache PHP Mysql 集成安装环境,是一组常用来…

如何将视频进行分割?这几种分割方法了解一下

当我们将视频分成几段后,可以更好地组织和管理不同的片段,方便后续查找和使用。我们可以根据需要调整视频的长度和内容,满足不同的观看需求。此外,分段视频可以更好地适应不同的观看场景,可以更方便地分享和传播&#…

RocketMQ 源码分析——Consumer

文章目录 消费者启动流程消费者模式集群消费广播消费 Consumer负载均衡集群模式广播模式 并发消费流程获取topic配置信息获取Group的ConsumerList获取Queue的消费Offset拉取Queue的消息更新Queue的消费Offset 顺序消费流程消费存在的问题消费卡死启动之后较长时间才消费 消费者…

操作系统(5-7分)

内容概述 进程管理 进程的状态 前驱图 同步和互斥 PV操作(难点) PV操作由P操作原语和V操作原语组成(原语是不可中断的过程),对信号量进行操作,具体定义如下: P(S)&#…

渗透测试信息收集方法和工具分享

文章目录 一、域名收集1.OneForAll2.子域名挖掘机3.subdomainsBurte4.ssl证书查询 二、获取真实ip1.17CE2.站长之家ping检测3.如何寻找真实IP4.纯真ip数据库工具5.c段,旁站查询 三、端口扫描1.端口扫描站长工具2.masscan(全端口扫描)nmap扫描3.scanport4.端口表5.利…

API(八)cosocket常用SDK

一 同步且非阻塞的底层SDK:cosocket 说明: 本篇章只是对cosocket常用话API的汇总,并没有实际案例加以辅证场景: 许多单机版的中间件都是基于cosocket做的二次开发 OpenResty 的核心和精髓 cosocket ① coscoket常用的指令 个人建议&am…