深度学习基础与线性回归实例

news2025/2/27 8:13:31

1、机器学习基础-线性回归

dffdb631fed54cf193a854e3e6607292.png

介绍:这是一个教育对收入影响的数据,从图像的走势来看,它是具有一个线性关系,即受教育年限越长收入越高,这样我们可以通直线来抽象出它们的关系。

接下来,我们将会介绍一些方法,分别是单变量线性回归算法、成本函数与损失函数、梯度下降算法。

首先要提到的是单变量线性回归算法,我们有这样一个函数f(x)=w*x+b;即x代表,f(x)代表收入我们使用f(x)这个函数来映射输入特征值和输出值。这个时候问题就转化为了,这条直线需要画在什么地方才合适,或者我们说w和b该取什么样的值呢?

然后是,成本函数与损失函数,使用均方差作为成本函数,也就是预测值和真实值之间的平方取均值,我们的优化目标(y代表实际的收入)是找到合适的w和b,使得(f(x)-y)**2越小越好,这样我们获得的直线会更好些。

最后是使用梯度下降算法,转而求解参数w,b,后面我会再着重讲解这里,大家只要明白这里会用到这个方法即可。

2、收入数据集读取与观察

现在,我们可以读取文件,使用matplotlib函数来绘制散点图。

下面是代码,我这里是用的是pandas,当然,如果你自己有能力也可以手写一个。

import torch

import pandas as pd
import numpy
import matplotlib.pyplot as plt

data=pd.read_csv("./dataset/Income1.csv")
data.info()  #返回这个文件的一些信息
print(data)
plt.scatter(data.Education,data.Income)
plt.xlabel("Education"),plt.ylabel("Income")
plt.show()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Unnamed: 0  30 non-null     int64  
 1   Education   30 non-null     float64
 2   Income      30 non-null     float64
dtypes: float64(2), int64(1)
memory usage: 848.0 bytes

 此为info方法获得的信息,下面是用matplotlib所画的散点图。

dffdb631fed54cf193a854e3e6607292.png

3、初始化模型、损失函数和优化方法

from torch import nn
import torch

import pandas as pd
import numpy
import matplotlib.pyplot as plt
import numpy as np

data=pd.read_csv("./dataset/Income1.csv")

# x=data.Education
# print(x)   #返回的是原来的列
# x=data.Education.values
# print(x)


# x=data.Education.values.reshape(-1,1).shape
# print(x)  #(30,1)指,30个数据
#
# x=data.Education.values.reshape(-1,1).astype(np.float64)
# print(x)

X=torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float64))
print(X)

Y=torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float64))
print(Y)

model=nn.Linear(1,1)  #out=w@input+b Linear表示随机生成一个权重(有w,b)
#这时就是指input和out都是1 ,w@input+b等价于 model(input)

#计算均方误差
loss_fn=nn.MSELoss()   #损失函数
#优化算法
opt=torch.optim.SGD(model.parameters(),lr=0.0001)

关于预处理部分,我简单说说,调用data.Education返回的是一个原来csv文件的那一列数据,这并不是我们想要的,所以在这里,我先将其转化为了numpy的数据类型ndarrary数组类型,用的是(data.Education.values)方式,我希望我的数据能够一个输入一个输出,所以又添加了reshape方法,将其shape变为(30,1),而torch.from_numpy显而易见是将numpy的ndarrary类型转化为pytorch所用的tensor类型。

这里对于X,Y处理过后,采用nn.Linear()方式随机生成一个权重,一输入一输出。对于损失函数我们可以用nn,MSELoss(),那么建立模型还有一步是进行优化,pytorch当然也提供了,只不过它是在torch当中,torch.optim.SGD(model.parameters(),lr=0.0001),model.parameters()返回需要优化参数,lr为学习速率,具体是什么,我会在后面讲到,这里只需要知道要用就对了。

4、模型训练与结果可视化

import pandas as pd
import numpy
import matplotlib.pyplot as plt
import numpy as np

from torch import nn
import torch

data=pd.read_csv("./dataset/Income1.csv")
X = torch.from_numpy(data.Education.values.reshape(-1, 1)).type(torch.FloatTensor)
# print(X)

Y = torch.from_numpy(data.Income.values.reshape(-1, 1)).type(torch.FloatTensor)
# print(Y)

model=nn.Linear(1,1)

#计算均方误差
loss_fn=nn.MSELoss()   #损失函数
#优化算法
opt = torch.optim.SGD(model.parameters(), lr=0.0001)

for epoch in range(5000):
    for x,y in zip(X,Y):
        y_pred = model(x)   #使用模型预测
        loss=loss_fn(y,y_pred)  #根据预测结果计算损失
        opt.zero_grad()  #把变量的梯度清零
        loss.backward() #反向传播算法,求解梯度
        opt.step()  #优化模型参数
print(model.weight,model.bias)
plt.scatter(data.Education,data.Income)
plt.xlabel("Education"),plt.ylabel("Income")
plt.plot(X.numpy(),model(X).data.numpy(),c='r')
plt.show()

打印weight,bias值 

tensor([[4.9757]], requires_grad=True) 
tensor([-28.3907], requires_grad=True)

较好的拟合了本次数据,那么这就是创建模型,训练模型和使用模型的过程。

5、资源分享

博客的相关代码与csv文件已上传至

GitHub:pytorch-Learning-and-Practice/Reference code/Deep Learning and Linear Regression at main · Auorui/pytorch-Learning-and-Practice (github.com)

如果大家觉得有用,在GitHub里面点击收藏即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/39283.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java中线程的状态

Java中线程的状态操作系统中线程的状态Java中线程的状态线程状态枚举类操作系统中线程的状态 从操作系统层面来看&#xff0c;线程通常有以下五种状态&#xff0c;前三种是线程的基本状态。 【运行态】&#xff1a;进程正处在处理机上运行&#xff0c;在单处理机环境下&#…

【学习笔记39】获取DOM标签对象

获取DOM标签对象一、认识DOM二、获取非常规DOM(html head body)1、HTML2、head3、body三、获取常规DOM&#xff08;一&#xff09;按照类名、标签名和ID名获取标签1、类名(伪数组)2、标签名(伪数组)3、ID名(唯一性)&#xff08;二&#xff09;按照选择器获取标签1、querySelect…

《人月神话》(The Mythical Man-Month)1 看清问题的本质:如果我们想解决问题,就必须试图先去理解它...

第一章 焦油坑&#xff08;The Tar Pit&#xff09;史前史中&#xff0c;没有比巨兽在焦油坑中垂死挣扎的场面更令人震撼的了。上帝见证着恐龙、猛犸象、剑齿虎在焦油中挣扎。它们挣扎得越是猛烈&#xff0c;焦油纠缠得越紧&#xff0c;没有任何猛兽足够强壮或具有足够的技巧&a…

IDEA注释配置程序员信息(带完整截图步骤,超级详细)

1.配置类注释的程序员信息 效果图 配置截图 模板 &#xff08;可根据需要进行位置调整及个数删除&#xff09; /***BelongsProject: ${PROJECT_NAME}*BelongsPackage: ${PACKAGE_NAME}*ClassName ${NAME}*Author: XUXIAN*CreateTime: ${YEAR}-${MONTH}-${DAY} ${HOUR}:${MINU…

这可能是最权威、最全面的Go语言编码风格规范了!

每种编程语言除了固定的语法之外&#xff0c;都会有属于自己的地道的(idiomatic)写法。其实&#xff0c;自然语言也不例外&#xff0c;你想&#xff0c;你用心想想是不是这样。语言的设计者们希望开发人员都能编写统一风格的地道的代码&#xff0c;这样不仅代码可读性好&#x…

细分图中的可到达节点 : 常规最短路运用题

题目描述 Tag : 「最短路」、「单源最短路」、「Dijkstra」、「SPFA」 给你一个无向图&#xff08;原始图&#xff09;&#xff0c;图中有 n 个节点&#xff0c;编号从 0 到 n - 1 。你决定将图中的每条边 细分 为一条节点链&#xff0c;每条边之间的新节点数各不相同。 图用…

[前端框架]-VUE(上篇)

Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式框架。与其它大型框架不同的是&#xff0c;Vue 被设计为可以自底向上逐层应用。Vue 的核心库只关注视图层&#xff0c;不仅易于上手&#xff0c;还便于与第三方库或既有项目整合。另一方面&#xff0…

链表经典算法题目

1.回文链表 编写一个函数&#xff0c;检查输入的链表是否是回文的。 示例 1&#xff1a; 输入&#xff1a; 1->2 输出&#xff1a; false 示例 2&#xff1a; 输入&#xff1a; 1->2->2->1 输出&#xff1a; true 笔试的写法 重点是快速code,不考虑空间复杂度…

蒙特卡洛法(Monte Carlo)电动汽车负荷预测matlab程序设计

电动汽车充电负荷的时间分布预测 规模化电动汽车充电负荷在未来某一天随时间特性的分布规律是研究电动汽车发展对配 电网影响以及充电站选址定容问题的前提与基础。电动汽车充电负荷的分布情况与车主的行 为特征有关&#xff0c;不同类型的电动汽车车主出行规律以及充电习惯不…

<Linux系统复习>信号

一、本章重点 1、什么是信号&#xff1f; 2、查看信号列表 3、信号捕捉 4、信号产生的5种方式 5、介绍CoreDump 6、信号处理的方式 7、如何理解信号产生到处理的过程 8、sigpending、sigprocmask、sigaction函数的使用 9、信号处理的时机 10、SIGCHLD信号 11、可重入函数 01 什…

Codeforces Round 836 (Div. 2) A - C

A:SSeeeeiinngg DDoouubbllee 题意&#xff1a;给定一个字符串&#xff0c;每个字符串的字符可以出现两次&#xff0c;要求通过重新排列构造一个回文串。 思路&#xff1a;直接暴力可以&#xff0c;每个字符头部一个尾部一个。 #include<cstdio> #include <iostream…

不使用实体类的情况下接收SQL查询结果、@Autowired注入为null解决

目录 一、场景 二、环境 三、使用 1、数据库表以及数据准备 2、项目导入必要依赖 3、添加连接数据库配置文件 4、编写测试方法 5、执行结果 四、将SQL单独提取出来 2.1 定义查询接口方法 2.2 测试 2.3 测试结果 五、问题记录&#xff1a; Autowired注入失败/null的…

b、B、KB、Kib、MB、MiB、GB、GiB、TB、TiB的区别

1024这个数字&#xff0c;想必计算机行业从业人员应该不会陌生&#xff0c;甚至10月24日还被当做程序员日&#xff0c;如果你问一个程序员1GB等于多少MB,他大概率会不假思索回答:1024。 没错&#xff0c;对于稍微对计算机或者网络有了解的人&#xff0c;一般都认为1024是数据容…

最短路算法 - dijkstra

最短路算法 - dijkstra1. 算法介绍2. 实战2.1 Reachable Nodes In Subdivided Graph3 参考1. 算法介绍 算法目的&#xff1a;求图中某点 s 到其余各点的最短距离 算法步骤&#xff1a; 初始化距离数组 dis 和优先级队列&#xff0c;其中 dis[i] 表示 s 点到当前 i 点的最短距…

树莓派上搭建SVN服务器

目录 一、服务端安装步骤 1.安装svn 2.创建目录 3.创建版本仓库 4.修改配置&#xff08;authz,passwd,svnserve.conf&#xff09; 5.启动服务 二、tortoisSVN客户端安装 三、结束 一、服务端安装步骤 1.安装svn sudo apt-get install subversion 2.创建目录 sudo m…

品RocketMQ 源码,学习并发编程三大神器

这篇文章&#xff0c;笔者结合 RocketMQ 源码&#xff0c;分享并发编程三大神器的相关知识点。 1 CountDownLatch 实现网络同步请求 CountDownLatch 是一个同步工具类&#xff0c;用来协调多个线程之间的同步&#xff0c;它能够使一个线程在等待另外一些线程完成各自工作之后&…

selenium--获取页面信息和截图

获取页面信息namecurrent_urltitlecurrent_window_handlewindow_handlespage_source简单用法—— 判断页面截图1.get_screenshot_as_png2.get_screenshot_as_file获取页面信息 主要方法如下图&#xff1a; 介绍一下常用的方法&#xff1a; name 获取浏览器名字 current_u…

Packet Tracer 实验 - 排除多区域 OSPFv3 故障

地址分配表 设备 接口 IPv6 全局单播地址 IPv6 本地链路地址 默认网关 ISP GigabitEthernet0/0 2001:DB8:C1:1::1/64 FE80::C1 不适用 ASBR GigabitEthernet0/0 2001:DB8:C1:1::2/64 FE80::7 不适用 Serial0/0/0 2001:DB8:A8EA:F0A::1 FE80::7 不适用 S…

如何通过 kubectl 进入 node shell

概述 假设这样一个场景&#xff1a; 生产环境中&#xff0c;Node 都需要通过堡垒机登录&#xff0c;但是 kubectl 是可以直接在个人电脑上登录的。 这种场景下&#xff0c;我想要通过 kubectl 登录到 K8S 集群里的 Node&#xff0c;可以实现吗&#xff1f; 可以的&#xff…

LinkedList与链表

目录 1.链表 2.链表的模拟实现 3.LinkedList的模拟实现 4.LinkedList的使用 4.1 什么是LinkedList 4.2 LinkedList的使用 5.ArrayList和LinkedList的区别 我的GitHub&#xff1a;Powerveil GitHub 我的Gitee&#xff1a;Powercs12 (powercs12) - Gitee.com 皮卡丘每天学…