PyTorch实现逻辑回归

news2025/1/19 20:43:22

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))


#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导

n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)

x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):
    ax = plt.gca()
    sc = ax.scatter(x, y, color='black')
    paths = []
    for i in range(len(x)):
        if c[i].item() == 0:
            marker_obj = mmarkers.MarkerStyle('o')
        else:
            marker_obj = mmarkers.MarkerStyle('x')
        path = marker_obj.get_path().transformed(marker_obj.get_transform())
        paths.append(path)
    sc.set_paths(paths)
    return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化b

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):
    #前向传播
    loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()
    #反向传播
    loss.backward()
    #更新参数
    b.data.sub_(lr*b.grad) # b = b - lr*b.grad
    w.data.sub_(lr*w.grad) # w = w - lr*w.grad
    #绘图
    if iteration % 3 == 0:
        plot(x, y, c)
        yy = w*xx + b
        plt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)
        plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})
        plt.xlim(-4,4)
        plt.ylim(-4,4)
        plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))
        plt.show()

        if loss.data.numpy() < 0.03:  # 停止条件
            break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))


#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b


n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)

x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)


def plot(x, y, c):
    ax = plt.gca()
    sc = ax.scatter(x, y, color='black')
    paths = []
    for i in range(len(x)):
        if c[i].item() == 0:
            marker_obj = mmarkers.MarkerStyle('o')
        else:
            marker_obj = mmarkers.MarkerStyle('x')
        path = marker_obj.get_path().transformed(marker_obj.get_transform())
        paths.append(path)
    sc.set_paths(paths)
    return sc
plot(x, y, c)
plt.show()


#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化b

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):
    #前向传播
    loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()
    #反向传播
    loss.backward()
    #更新参数
    b.data.sub_(lr*b.grad) # b = b - lr*b.grad
    w.data.sub_(lr*w.grad) # w = w - lr*w.grad
    #绘图
    if iteration % 3 == 0:
        plot(x, y, c)
        yy = w*xx + b
        plt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)
        plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})
        plt.xlim(-4,4)
        plt.ylim(-4,4)
        plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))
        plt.show()

        if loss.data.numpy() < 0.03:#停止条件
            break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1298733.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Linux] Apache的配置与运用

一、web虚拟主机的构台服务器上运行多个网站&#xff0c;每个网站实际上并不独立占用整个服务器&#xff0c;因此称为"虚拟"虚拟主机的虚拟主机服务可以让您充分利用服务器的硬件资源&#xff0c;大大降低了建立和运营网站的成本 Httpd服务使构建虚拟主机服务器变得容…

[LeetCode周赛复盘] 第 119 场双周赛20231209

[LeetCode周赛复盘] 第 119 场双周赛20231209 一、本周周赛总结100130. 找到两个数组中的公共元素1. 题目描述2. 思路分析3. 代码实现 100152. 消除相邻近似相等字符1. 题目描述2. 思路分析3. 代码实现 100147. 最多 K 个重复元素的最长子数组1. 题目描述2. 思路分析3. 代码实…

数据结构之归并排序及排序总结

目录 归并排序 归并排序的时间复杂度 排序的稳定性 排序总结 归并排序 归并排序大家只需要掌握其递归方法即可&#xff0c;非递归方法由于在某些特殊场景下边界难控制&#xff0c;我们一般很少使用非递归实现归并排序。那么归并排序的递归方法我们究竟是怎样实现呢&#xff…

蓝桥杯航班时间

蓝桥杯其他真题点这里&#x1f448; //飞行时间 - 时差 已过去的时间1 //飞行时间 时差 已过去的时间2 //两个式子相加会发现 飞行时间 两段时间差的和 >> 1import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public cl…

ECS云主机容量大于2TB,初始化Linux数据盘(parted)

本文为您介绍当容量大于2TB时&#xff0c;如何在Linux环境下适用parted分区工具初始化数据盘。 操作场景 本文以“CentOS 7.6 64位”操作系统为例&#xff0c;介绍当磁盘容量大于2TB时&#xff0c;如何使用parted分区工具在Linux操作系统中为数据盘设置分区&#xff0c;操作回…

使用粗糙贴图制作粗纹皮革手提包3D模型

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时&#xff0c;有几种不同的风格&#xf…

对无向图进行邻接矩阵的转化,并且利用DFS(深度优先)和BFS(广度优先)算法进行遍历输出, 在邻接矩阵存储结构上,完成最小生成树的操作。

一 实验目的 1&#xff0e;掌握图的相关概念。 2&#xff0e;掌握用邻接矩阵和邻接表的方法描述图的存储结构。 3&#xff0e;掌握图的深度优先搜索和广度优先搜索遍历的方法及其计算机的实现。 4&#xff0e;理解最小生成树的有关算法 二 实验内容及要求 实验内容&#…

管理类联考——数学——真题篇——按知识分类——数据

文章目录 排列组合2023真题&#xff08;2023-05&#xff09;-数据分析-排列组合-组合-C运算-至少-需反面思考真题&#xff08;2023-08&#xff09;-数据分析-排列组合-相邻不相邻-捆绑法插空法-插空法注意空位比座位多1个&#xff0c;是用A&#xff1b;捆绑法内部排序用A&#…

实现简易的一对一用户聊天

服务端 package 一对一用户;import java.awt.BorderLayout; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.net.ServerSocket; import java.net.Socket; import java.util.Vector…

langchain入门及两种模型的使用

一、简介 1、OpenAi、chatgpt Openai就是开发chatgpt系列AI产品的公司。 chatgpt是一款AI产品&#xff0c;chatgpt plus也是一款AI产品&#xff0c;后者可以看做是前者的会员版/付费版。 chatgpt-3.5、chatgpt-4这俩简单说都是AI技术模型&#xff0c;后者可以看做是前者的升…

linux下的进程程序替换

进程程序替换 替换概念替换函数execl()execv()execvp()/execlp()execle()/execvpe() 如何在C/C程序里面执行别的语言写的程序。小tips 替换概念 当进程调用一种exec函数时&#xff0c;该进程的用户空间代码和数据完全被新程序替换&#xff0c;从新程序的代码部分开始运行。调用…

案例005:基于小程序的电子点菜系统开发设计与实现

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…

【机器学习】亚马逊云科技基础知识:以推荐系统为例。你知道机器学习的关键所在么?| 机器学习管道的各个阶段及工作:以Amazon呼叫中心转接问题为例讲解

有的时候,暂时的失利比暂时胜利要好得多。 ————经典网剧《mao pian》,邵半仙儿 🎯作者主页: 追光者♂🔥 🌸个人简介: 💖[1] 计算机专业硕士研究生💖 🌿[2] 2023年城市之星领跑者TOP1(哈尔滨)🌿 🌟[3] 2022年度博客之星人工智能领域TOP

MongoDB中的sort()排序方法、aggregate()聚合方法和索引

本文主要介绍MongoDB中的sort()排序方法、aggregate()聚合方法和索引。 目录 MongoDB的sort()排序方法MongoDB的aggregate()聚合方法MongoDB的索引 MongoDB的sort()排序方法 在MongoDB中&#xff0c;sort()方法是用来对查询结果进行排序的。sort()方法可以用于在查询语句中对指…

红队攻防实战之ThinkPHP-RCE集锦

你若不勇敢&#xff0c;谁又可以替你坚强&#xff1f; ThinkPHP 2.x RCE漏洞 1、查询phpinfo() 2、任意代码执行 3、Getshell 蚁剑连接&#xff1a; ThinkPHP5 5.0.23 RCE漏洞 发送数据包&#xff1a; 成功执行id命令&#xff1a; 工具验证 ThinkPHP5 SQL注入漏洞 &&am…

Android开发,JNI,NDK,C++层操作java的对象实践

Android开发&#xff0c;JNI&#xff0c;NDK&#xff0c;C层操作java的对象实践 1.数组 在jni中调用数组 extern "C" JNIEXPORT void JNICALL Java_com_example_myapplication_MainActivity_testArr(JNIEnv *env, jobject thiz, jint a, jstring s,jintArray ints…

Axure网页端高交互组件库, 下拉菜单文件上传穿梭框日期城市选择器

作品说明 组件数量&#xff1a;共 11 套 兼容软件&#xff1a;Axure RP 9/10&#xff0c;不支持低版本 应用领域&#xff1a;web端原型设计、桌面端原型设计 作品特色 本作品为「web端组件库」&#xff0c;高保真高交互 (带仿真功能效果)&#xff1b;运用了动态面板、中继…

LNMP网站架构分布式搭建部署

1. 数据库的编译安装 1. 安装软件包 2. 安装所需要环境依赖包 3. 解压缩到软件解压缩目录&#xff0c;使用cmake进行编译安装以及模块选项配置&#xff08;预计等待20分钟左右&#xff09;&#xff0c;再编译及安装 4. 创建mysql用户 5. 修改mysql配置文件&#xff0c;删除…

如何为游戏角色3D模型设置纹理贴图

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时&#xff0c;有几种不同的风格&#xf…

【面试经典150 | 二叉树】从中序与后序遍历序列构造二叉树

文章目录 写在前面Tag题目来源题目解读解题思路方法一&#xff1a;递归 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本题涉及到的数据结构等内容…