pytorch学习(一)线性模型

news2024/11/15 12:49:11

文章目录

    • 线性模型
    • pytorch
    • 使用sklearn训练

pytorch是一个基础的python的科学计算库,它有以下特点:

  • 类似于numpy,但是它可以使用GPU
  • 可以用它来定义深度学习模型,可以灵活的进行深度学习模型的训练和使用

线性模型

线性模型的基本形式为: f ( x ) = w T x + b f(x)=w^Tx+b f(x)=wTx+b,线性模型的参数是w和b,它的学习是通过不断减少损失实现的,其损失一般为均方损失

在这里插入图片描述
pytorch代码实现:

# from tkinter import W
import numpy as np
import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

def forward(x):
    return x*w

def loss(x,y):
    y_pred=forward(x)
    return (y_pred-y)*(y_pred-y)

w_list=[]
mse_list=[]

# 对于不同的w,手动模拟学习的过程
for w in np.arange(0.0,4.1,0.1):
    print("w=",w)
    l_sum=0# 损失值
    # 一次计算MSE的过程
    for x_val,y_val in zip(x_data,y_data):
        # 计算预测值
        y_pred_val=forward(x_val)
        # 计算损失函数
        loss_val=loss(x_val,y_val)
        l_sum+=loss_val
        print("\t",x_val,y_val,y_pred_val,loss_val)
    print("MSE=",l_sum/3)
    w_list.append(w)
    mse_list.append(l_sum/3)

# 画图
plt.plot(w_list,mse_list)
plt.ylabel("Loss")
plt.xlabel("w")
plt.show()

结果截图:
在这里插入图片描述

pytorch

# 使用pytorch      
import torch
import matplotlib.pyplot as plt
# 加载数据
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

# 构造线性模型
# tensor.nn.Linear(in_features, out_features, bias=True)其中in_features表示输入的样本,out_features表示输出的样本
class LinearModule(torch.nn.Module):
    def __init__(self):
        super(LinearModule,self).__init__()
        self.linear=torch.nn.Linear(1,1)
        
    # 定义前馈函数
    def forward(self,x):
        y_pred=self.linear(x)
        return y_pred
    
# 构造对象
model=LinearModule()
# 定义损失函数
# torch.nn.MSELoss(size_average=True, reduce=True)其中size_average是否求均值、reduce是否降维求和
criterion=torch.nn.MSELoss(size_average=False)

# 定义优化器
# SGD表示随机梯度下降
# torch.optim.SGD(params【权重参数】, lr=【学习率】, momentum=0【冲量】)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

epoches=[]
losses=[]

# 训练
for epoch in range(1000):
    # 前馈
    y_pred=model(x_data)
    # 计算损失函数
    loss=criterion(y_pred,y_data)
    print(epoch,loss.item())
    # 获取数据
    epoches.append(epoch)
    losses.append(loss.item())
    
    # 清零
    optimizer.zero_grad()
    # 反馈
    loss.backward()
    # 权重更新
    optimizer.step()
    

# 输出权重和偏移量
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

# 画图
plt.plot(epoches,losses)
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.show()

# 预测值
x_test=torch.tensor([4.0])
y_test=model(x_test)
print('y_pred=',y_test.data)

部分结果截图
在这里插入图片描述

使用sklearn训练

# 使用sklearn训练
from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt

lr=LinearRegression()
x=np.array([1.0,2.0,3.0],dtype='float')
x=x.reshape(-1,1)
print(x.shape)
y=np.array([2.0,4.0,6.0],dtype='float')
y=y.reshape(-1,1)
print(y.shape)
lr.fit(x,y)
print('直线的斜率:',lr.coef_)
print('截距:',lr.intercept_)

# 画图
plt.plot(x,y,'b.')
plt.xlabel('X',fontsize=18)
plt.ylabel('Y',rotation=0,fontsize=18)

plt.plot(x,y,'r-',linewidth=2,label='predictions')
plt.legend(loc="upper left", fontsize=14)
plt.show()

训练结果截图
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1397814.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Studio One2024免费版下载及入门教程分享

众所周知,Studio One是一个专业的音频编辑软件,近几年随着音视频剪辑越来越火,Studio One也逐渐被人们所熟知。最近,就有许多小伙伴私信我,寻求Studio One的入门教程。 这不,今天小编就给大家带来了音频剪…

iphone5s基带部分电源部分主主电源供电及

时序: 1.,基带电源的供电,基带电源也叫pmu。 首先时序图说电池提供供电,电池是J6接口,视频习惯把接口称之为座子。查U2_RF芯片,发现供电信号为PP_BATT_VCC_CONN,但是没查到跟电池座子有关系,电池座子写的是…

Flask框架小程序后端分离开发学习笔记《1》网络知识

Flask框架小程序后端分离开发学习笔记《1》网络知识 Flask是使用python的后端,由于小程序需要后端开发,遂学习一下后端开发。 一、网址组成介绍 协议:http,https (https是加密的http)主机:g.cn zhihu.com之类的网址…

Python使用pyechart分析疫情确诊人数图(2024)

import json from pyecharts.charts import Map from pyecharts import options as opts# 首先打开文件获取数据 f open("/Desktop/python/Project/数据可视化/疫情.txt", "r", encoding"UTF-8") data f.read()# 字符串转化成json数据 data_js…

rust使用protobuf

前言 c,java,go 等直接是用 ,具体就不说了,这章主要讲述rust 使用protobuf 这章主要讲述2种 1 > protoc protoc-gen-rust plugin 2> protoc prost-build 1:环境 win10 rustrover64 25-2 下载地址 https://github.com/protocolbu…

CHS_01.2.2.1+调度的概念、层次

CHS_01.2.2.1调度的概念、层次 调度的概念、层次知识总览调度的基本概念调度的三个层次——高级调度![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/6957fdec179841f69a0508914145da36.png)调度的三个层次——低级调度调度的三个层次——中级调度补充知识&#xff…

unity-声音与声效OLD

声音与声效 基本概念audio clipaudio listeneraudio source 基本操作如何创建音频源(背景音乐)如何在测试的时候关闭声音 常用代码一般流程如何在一个物体上播放多个音效如何在代码中延时播放多个声音如何在代码中停止音频的播放如何判断当前是否在播放音…

Web3解密:区块链技术如何颠覆传统互联网

随着区块链技术的崛起,Web3正逐渐成为新一代互联网的代名词。它不再依赖中心化的权威机构,而是通过去中心化、透明、安全的特性,为用户带来更为开放和公正的互联网体验。本文将深入解密Web3,揭示区块链技术如何颠覆传统互联网的基…

Linux搭建dns主从服务器

一、实验要求 配置Dns主从服务器,能够实现正常的正反向解析 二、知识点 1、DNS简介 DNS(Domain Name System)是互联网上的一项服务,它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便的访问互联网。…

git提交代码到远端仓库的方法详解

一、何为git git就是版本控制器,就比如说你新建了一个git文件夹,里面用于存放你的C语言实习报告,现在要用git对该文件夹进行接管。当你修改了你的C语言实习报告点击保存之后,就用git的相关命令,提交给git,让…

ctfshow命令执行(web29-web52)

目录 web29 web30 web31 web32 web33 web34 web35 web36 web37 web38 web39 web40 web41 web42 web43 web44 web45 web46 web47 web48 web49 web50 web51 web52 web29 <?php error_reporting(0); if(isset($_GET[c])){$c $_GET[c];if(!preg_match…

go语言(十)---- 面向对象封装

面向对象的封装 package mainimport "fmt"type Hero struct {Name stringAd intLevel int }func (this Hero) Show(){fmt.Println("Name ", this.Name)fmt.Println("Ad ", this.Ad)fmt.Println("Level ", this.Level)}func (thi…

模型的 F1 分数

模型的 F1 分数是一个综合评估模型性能的指标&#xff0c;同时考虑了模型的精确率&#xff08;Precision&#xff09;和召回率&#xff08;Recall&#xff09;。F1 分数的计算公式为&#xff1a; 其中&#xff1a; Precision 是模型正确识别为正例的样本数量与所有被模型识别为…

PYG中torch_scatter, torch_sparse等pip安装包错解决

原安装命令&#xff1a; pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch- 1.13.0cu117.html报错&#xff1a; 正确安装命令&#xff1a; pip install --no-index pyg_lib torch_scatter torch_sparse …

JavaWeb:servlet+jsp+mybatis商品管理增删改查

文章目录 1、环境准备1.1、创建数据库1.2、创建项目导入依赖1.3、创建包1.4、创建实体类1.5、准备mybatis环境1.6、编写Mybatis工具类1.7、编写主页面 2、功能实现2.1、查询所有2.2、添加功能2.3、修改数据回显2.4、修改数据2.5、删除数据 1、环境准备 1.1、创建数据库 CREAT…

Python文件自动化处理

os模块 Python标准库和操作系统有关的操作创建、移动、复制文件和文件夹文件路径和名称处理 路径的操作 获取当前Python程序运行路径不同操作系统之间路径的表示方式 windows中采用反斜杠(\)作为文件夹之间的分隔符 Mac和Linux中采用斜杠(/)作为文件夹之间的分隔符 把文件…

【用队列实现栈】【用栈实现队列】Leetcode 232 225

【用队列实现栈】【用栈实现队列】Leetcode 232 225 队列的相关操作栈的相关操作用队列实现栈用栈实现队列 ---------------&#x1f388;&#x1f388;题目链接 用队列实现栈&#x1f388;&#x1f388;------------------- ---------------&#x1f388;&#x1f388;题目链…

React Store及store持久化的使用

1.安装 npm insatll react-redux npm install reduxjs/toolkit npm install redux-persist2. 使用React Toolkit创建counterStore并配置持久化 store/modules/counterStore.ts&#xff1a; import { createSlice } from reduxjs/toolkit// 定义状态类型 interface Action {…

4个值得使用的免费爬虫工具

在信息时代&#xff0c;数据的获取对于各行业都至关重要。而在数据采集的众多工具中&#xff0c;免费的爬虫软件成为许多用户的首选。本文将专心分享四款免费爬虫工具&#xff0c;突出介绍其中之一——147采集软件&#xff0c;为您揭示这些工具的优势和应用&#xff0c;助您在数…

使用Sqoop从Oracle数据库导入数据

在大数据领域&#xff0c;将数据从关系型数据库&#xff08;如Oracle&#xff09;导入到Hadoop生态系统是一项常见的任务。Sqoop是一个强大的工具&#xff0c;可以帮助轻松完成这项任务。本文将提供详细的指南&#xff0c;以及丰富的示例代码&#xff0c;帮助了解如何使用Sqoop…