pytorch入门3--线性回归以及许多python,pytorch函数的用法

news2024/9/24 6:51:46

先补充一些知识点,这里不一定用得到,后面的学习过程中可能用得到。
1.batch表示批量,就是一批数据集的意思;
2.batch_size表示数据集(样本集、训练集)的大小(数据的个数);
3.iteration是迭代的意思,即一个iteration表示迭代一次;
4.epoch可以看成训练次数,一个epoch指用训练集中的全部样本训练一次,最初训练DNN采用一次对全体训练集的样本进行训练(即用一个epoch),但样本集较大时占用内存大,目前常用随机梯度下降SGD来训练,将训练集分为多个mini_bach(即batch),一次迭代训练一个minibatch(batch_size个样本),根据该batch数据的loss更新权值。
举个例子,训练集有1000个样本,batch_size=10,那么训练完整个样本集需要100次iteration,1次epoch。epoc指把训练集所有的数据都跑一遍,当然也有将训练集的所有数据跑多遍的算法,即多个epoch。
以房价预测问题为例.
假设1:影响房价的关键因素是卧室个数,卫生间个数和居住面积,记为x1,x2,x3.
假设2:成交价是关键因素的加权和:
y = w1x1+w2x2+w3*x3+b(其中,w1,w2,w3是权重,b是偏差)
一、理论
1.拓展到线性模型:
给定n维输入:
在这里插入图片描述
线性模型有一个n维权重和一个标量偏差:
在这里插入图片描述
说明:<w,x>是张量的内积,和向量的内积是一个意思。
2.衡量预估质量:
比较真实值和预估值,例如房屋售价和估价。假设y是真实值,y^是估计值,可以比较:
在这里插入图片描述
这个叫做平方损失。
3.训练数据
收集一些数据点来决定参数(权重和偏差)。例如过去6个月卖的房子,这被称之为训练数据。假如我们有n个样本,记:
在这里插入图片描述
4.参数学习
在这里插入图片描述
最小化损失函数就是找到使损失函数l(x,y,w,b)最小的w,b的值。
5.显示解
在这里插入图片描述
线性回归可以看成单层神经网络。
在这里插入图片描述
二、线性回归代码实现
1.零基础代码

%matplotlib inline
import random
import torch
from d2l import torch as d2l

# 用于生成人造数据集
def synthetic_data(w,b,num_examples):
    x = torch.normal(0,1,(num_examples,len(w))) # torch.normal()表示正太分布,有三个参数torch.normal(mean,std,size),三个参数
    y = torch.matmul(x,w)+b
    y += torch.normal(0,0.01,y.shape) #加入一个均值噪音
    return x,y.reshape(-1,1) #生成列向量返回(后面有说明)

true_w = torch.tensor([2,-3.4])
true_b = 4.2
features,labels = synthetic_data(true_w,true_b,1000)
# 绘制点图显示
d2l.set_figsize()
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1);

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])
        yield features[batch_indices],labels[batch_indices]
 
batch_size = 10
for x,y in data_iter(batch_size,features,labels):
    print(x,'\n',y)
    break
# 定义初始化模型参数
w = torch.normal(0,0.01,size=(2,1),requires_grad=True)
b =torch.zeros(1,requires_grad=True)

def linreg(x,w,b):
    # 线性回归模型
    return torch.matmul(x,w)+b

# 定义损失函数
def squared_loss(y_hat,y):
    return (y_hat - y.reshape(y_hat.shape))**2/2

# 定义优化算法(小批量随机梯度下降)sgd在深度学习中表示随机梯度下降法(Stochastic Gradient Descent)
def sgd(params,lr,batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_() # 把梯度设为零
                
#训练过程
lr = 0.03 # 学习率,即步长
num_epochs = 3
net =  linreg
loss = squared_loss
for epoch in range(num_epochs):
    for x,y in data_iter(batch_size,features,labels):
        l =loss(net(x,w,b),y)
        l.sum().backward()
        sgd([w,b],lr,batch_size)
    with torch.no_grad():
        train_l = loss(net(features,w,b),labels)
        print(f'epoch{epoch +1},loss{float(train_l.mean()):f}')

说明:(1)len()函数返回对象的长度,注意不是length()函数。
在这里插入图片描述
(2)torch.normal(mean,std,size)用于产生正太分布的一个张量,参数mean是正太分布的均值,参数std是标准差,size是张量的shape.
在这里插入图片描述
(3)torch.matmul用于实现两个张量相乘。(这里满足之前说过的张量相乘的不同情况。
两个一维向量实现内积:
在这里插入图片描述
两个多维张量实现的是线性代数里的矩阵相乘:
1
报错原因为a的列不等于b的行。
在这里插入图片描述
在这里插入图片描述
(4)random.shuffle()用于将一个列表中的元素打乱顺序,元素个数和值不变
在这里插入图片描述
注意:该函数没有返回值,而是将参数值改变并传回给参数本身,即使参数值改变。
(5)reshape()函数中参数为-1表示无意义。reshape(-1,1)表示将二维数组重整为一个一列的数组,reshape(1,-1)表示把一个二维数组重整为一个一行的数组。
在这里插入图片描述
(6)range(n)用于生成一个从0到n-1的整数序列,list()用于将序列转化为列表。
在这里插入图片描述
(7)zero_()可用于将张量清零
在这里插入图片描述
(8)yield
python中含有yield的函数相当于一个迭代器,即iterator,生成的迭代器(函数执行的返回结果)可以用于for循环。
在这里插入图片描述
注意:当执行g=gen(5)时,gen中的代码并没有执行(节省内存空间),只是创建了一个生成器对象(generator),然后,执行for i in g,每次执行一次循环就会执行到yield处,返回一次yield的值。
在这里插入图片描述
从上例可以看出,yield也是使函数返回,暂停执行,且下一次执行是从上次被暂停的地方,而不是重新来一轮。

2.使用pytorch的深度学习框架来实现线性回归模型

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

# 生成人工数据集
true_w = torch.tensor([2,-4.3])
true_b = 4.2
features,labels =d2l.synthetic_data(true_w,true_b,1000# 调用框架中的现有的API来读取数据
def load_array(data_arrays,batch_size,is_train=True):
     dataset = data.TensorDataset(*data_arrays)
     return data.DataLoader(dataset,batch_size,shuffle=is_train)
batch_size = 10
data_iter = load_array((features,labels),batch_size)
next(data_iter)

# 使用预定义好的层
from torch import nn # nn是神经网络的缩写
net = nn.Sequential(nn.Linear(2,1))

# 初始化模型参数
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

# 计算均方误差使用的是MSELoss类,也称为平方范数
loss = nn.MSELoss()

# 实例化SGD实例
trainer = torch.optim.SGD(net.parameters(),lr=0.03)

# 训练数据
num_epochs = 3
for epoch in range(num_epochs):
    for x,y in data_iter:
        l = loss(net(x),y) # net中自带模型参数,不用再传
        trainer.zero_grad()
        l.backward()
        trainer.step() # step()函数进行一次模型的更新
   l=loss(net(features),labels)
   print(f'epoch{epoch +1},loss{l:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/376683.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

进程与线程的区别

进程和线程 进程 一个在内存中运行的应用程序。每个进程都有自己独立的一块内存空间&#xff0c;一个进程可以有多个线程&#xff0c;比如在Windows系统中&#xff0c;一个运行的xx.exe就是一个进程。 线程 进程中的一个执行任务&#xff08;控制单元&#xff09;&#xf…

深入理解跳表及其在Redis中的应用

前言跳表可以达到和红黑树一样的时间复杂度 O(logN)&#xff0c;且实现简单&#xff0c;Redis 中的有序集合对象的底层数据结构就使用了跳表。其作者威廉普评价&#xff1a;跳跃链表是在很多应用中有可能替代平衡树的一种数据结构。本篇文章将对跳表的实现及在Redis中的应用进行…

蓝桥杯:染色时间

蓝桥杯&#xff1a;染色时间https://www.lanqiao.cn/problems/2386/learning/?contest_id80 问题描述 输入格式 输出格式 样例输入输出 样例输入 样例输出 评测用例规模与约定 解题思路&#xff1a;优先队列 AC代码(Java)&#xff1a; 问题描述 小蓝有一个 n 行 m 列…

华为OD机试题,用 Java 解【任务混部】问题

最近更新的博客 华为OD机试题,用 Java 解【停车场车辆统计】问题华为OD机试题,用 Java 解【字符串变换最小字符串】问题华为OD机试题,用 Java 解【计算最大乘积】问题华为OD机试题,用 Java 解【DNA 序列】问题华为OD机试 - 组成最大数(Java) | 机试题算法思路 【2023】使…

本地docker部署mysql,IDEA直连实战

1、安装mysql镜像 前文中我们安装了docker和redis镜像&#xff0c;并在idea中成功连接&#xff0c;现在安装mysql镜像 docker pull mysql &#xff0c;默认最新版本 ps:可以参考https://www.runoob.com/docker/docker-install-mysql.html 2、启动mysql 打开powershell&…

快速掌握 Flutter 图片开发核心技能

大家好&#xff0c;我是 17。 在 Flutter 中使用图片是最基础能力之一。17 做了精心准备&#xff0c;满满的都是干货&#xff01;本文介绍如何在 Flutter 中使用图片&#xff0c;尽量详细&#xff0c;示例完整&#xff0c;包会&#xff01; 使用网络图片 使用网络图片超级简…

【035】基于java的进销库存管理系统(Vue+Springboot+Mysql)前后端分离项目,附万字课设论文

1.3 系统实现的功能 本次设计任务是要设计一个超市进销存系统&#xff0c;通过这个系统能够满足超市进销存系统的管理及员工的超市进销存管理功能。系统的主要功能包括&#xff1a;首页、个人中心、员工管理、客户管理、供应商管理、承运商管理、仓库信息管理、商品类别管理、 …

【知识图谱】架构-特点-缺点简介

架构物联网、云计算、人工智能等新一代信息技术的迅猛发展&#xff0c;带来了制造业的新一轮突破&#xff0c;推动着制造系统向智能化方向发展&#xff0c;驱动着未来制造模式的创新。其中数据和知识是实现制造业与新一代信息技术融合的基础&#xff0c;是实现智能制造的保障。…

PyQt5(二) python程序打包成.exe文件

目录一、安装 **pyinstaller**二、pyinstaller 打包2.1 pyinstaller 打包机制参考链接前言我们在 pycharm 上写的程序在发送到一台没有安装 python 解释器的机器上是不能运行的&#xff0c;甚至还要安装程序中所使用的第三方包&#xff0c;这样极其不方便。 但是 PC 是可以直接…

【C++】哈希——unordered系列容器|哈希冲突|闭散列|开散列

文章目录一、unordered系列关联式容器二、哈希概念三、哈希冲突四、哈希函数五、解决哈希冲突1.闭散列——开放定址法2.代码实现3.开散列——开链法4.代码实现六、结语一、unordered系列关联式容器 在C98中&#xff0c;STL提供了底层为红黑树结构的一系列关联式容器&#xff0c…

MySQL 横表和竖表相互转换

一 竖表转横表 1. 首先创建竖表 create table student ( id varchar(32) primary key, name varchar (50) not null, subject varchar(50) not null, result int); 2. 插入数据 insert into student (id, name, subject, result) values (0001, 小明, 语文, 83); insert into…

RK系列(RK3568) 收音机tef6686芯片驱动,i2c驱动

SOC:RK3568模块&#xff1a;tef6686系统&#xff1a;Android121.首先目前tef6686只有单片机才有驱动&#xff0c;Linux要集成只需要控制模块内部的i2c地址的顺序从github下载tef6686 Andruino的代码 https://github.com/tehniq3/TEF6686解压进入TEF6686-master\TEF6686_1602i2c…

华为OD机试用Python实现 -【任务混部】(2023-Q1 新题)

华为OD机试题 华为OD机试300题大纲任务混部题目输入输出示例一输入输出说明示例二输入输出说明备注Code代码编写思路华为OD机试300题大纲 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址:blog.csdn.net/hihell/ca…

Google Guice 5:AOP

1. AOP 1.1 实际开发中面临的问题 在实际开发中&#xff0c;经常需要打印一个方法的执行时间&#xff0c;以确定是否存在慢操作 最简单的方法&#xff0c;直接修改已有的方法&#xff0c;在finnally语句中打印耗时 Override public Optional<Table> getTable(String da…

中级嵌入式系统设计师2014下半年下午试题与答案解析

中级嵌入式系统设计师2014下半年下午试题与答案解析 试题一 阅读下列说明和图,回答下列问题。 [说明] ATM自动取款机系统是一个由终端机、ATM系统、数据库组成的应用系统,具有提取现金、查询账户余额、修改密码及转账等功能。ATM自动取款机系统用例图如图1所示。

win11开始菜单增强工具:StartAllBack

StartAllBack是一款Windows11开始菜单增强工具&#xff0c;在任务栏上为Windows 11恢复经典样式的Windows 7主题风格开始菜单&#xff0c;主要功能包括&#xff1a;恢复和改进开始菜单样式、个性化任务栏、资源管理器等功能。软件功能恢复和改进任务栏在任务图标上显示标签调整…

【Spring】通过JdbcTemplate实现CRUD操作

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ 通过JdbcTemplate实现 增删查改一、添加相关依…

python-下载某短视频平台视频(高清无水印)

python-下载某短视频平台音视频&#xff08;高清无水印&#xff09;前言1、获取视频 url2、发送请求3、数据解析4、本地保存5、完整代码前言 1、Cookie中文名称为小型文本文件&#xff0c;指某些网站为了辨别用户身份而储存在用户本地终端&#xff08;Client Side&#xff09;…

RTMP的工作原理及优缺点

一.什么是RTMP&#xff1f;RTMP&#xff08;Real-Time Messaging Protocol&#xff0c;实时消息传输协议&#xff09;是一种用于低延迟、实时音视频和数据传输的双向互联网通信协议&#xff0c;由Macromedia&#xff08;后被Adobe收购&#xff09;开发。RTMP的工作原理是&#…

Windows 11 网卡MAC地址 | 机器地址 | 网络地址 为 0 | 00-00-00-00-00-00?手动修复……

一位同事反映&#xff0c;他的电脑今天上班开机无法上网&#xff0c;上周末还正常&#xff0c;请我帮忙检修。该同事的电脑安装的是Windows 11&#xff0c;检查网络连接的详细信息&#xff0c;发现IP地址、网关、DNS参数都正常&#xff0c;但物理地址为00-00-00-00-00-00。另外…