【深度学习】PyTorch深度学习笔记02-线性模型

news2024/12/29 11:35:35

1. 监督学习

2. 数据集的划分

3. 平均平方误差MSE

4. 线性模型Linear Model - y = x * w

用穷举法确定线性模型的参数

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


def forward(x):
    return x * w

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)

w_list = []
mse_list = []

for w in np.arange(0.0, 4.0, 0.1):
    print('w=', w)
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):  
        y_pred_val = forward(x_val)
        loss_val = loss(x_val, y_val)  
        l_sum += loss_val
        print('\t', x_val, y_val, y_pred_val, loss_val)
    print('MSE=', l_sum / len(x_data))  
    w_list.append(w)
    mse_list.append(l_sum / len(x_data))

plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

详细过程

    本课程的主要任务是构建一个完整的线性模型:
        导入numpy和matplotlib库;
        导入数据 x_data 和 y_data;
        定义前向传播函数:
            forward:输出是预测值y_hat
        定义损失函数:
            loss:平方误差
        创建两个空列表,后面绘图的时候要用:
            分别是横轴的w_list和纵轴的mse_list
        开始计算(这里没有训练的概念,只是单纯的计算每一个数据对应的预测值,然后让预测值跟真实y值求MSE):
            外层循环:
                在0.0~4.0之间均匀取点,步长0.1,作为n个横坐标自变量,用w表示;
            内层循环:核心计算内容
                从数据集中,按数据对取出自变量x_val和真实值y_val;
                先调用forward函数,计算y的预测值 w*x
                调用loss函数,计算单个数据的平方误差;
                累加损失;
                打印想要看到的数值;
                在外层循环中,把计算的结果放进之前的空列表,用于绘图;
    在获得了打印所需的数据列表之后,模式化地打印图像:

运行结果

ps:

visdom库可用于可视化

np.meshgrid()可用于绘制三维图

5. 线性模型Linear Model - y = x * w + b

有w,b两个参数,穷举最小值

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

x_data = [1.0, 2.0, 3.0]
y_data = [3.0, 4.0, 6.0]

def forward(x, w, b):
    return x * w + b

def loss(x, y, w, b):
    y_pred = forward(x, w, b)
    loss = (y_pred - y) * (y_pred - y)
    return loss

w_list = np.arange(0.0, 4.1, 0.1)
b_list = np.arange(-2.0, 2.1, 0.1)

# np.zeros(): 返回给定维度的全零数组; mse_matrix用于存储不同 w,b 组合下的均方误差损失
mse_matrix = np.zeros((len(w_list), len(b_list)))

for i, w in enumerate(w_list):
    for j, b in enumerate(b_list):
        l_sum = 0
        for x_val, y_val in zip(x_data, y_data):
            l_sum += loss(x_val, y_val, w, b)
        mse_matrix[i, j] = l_sum / len(x_data)

W, B = np.meshgrid(w_list, b_list)
fig = plt.figure('Linear Model Cost Value')
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(W, B, mse_matrix.T, cmap='viridis')
ax.set_xlabel('w')
ax.set_ylabel('b')
ax.set_zlabel('loss')
plt.show()

可以得出,穷举法算法的时间复杂度 随着参数的个数增大 而变得很大,因此使用穷举法找到最优解,很不合理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925674.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【原创】springboot+mysql图书共享交流平台设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

HTTP请求走私漏洞原理与利用手段分析

文章目录 前言Http请求走私1.1 漏洞诞生场景1.2 漏洞基本原理1.3 HTTP1.1与2.0 请求走私分类2.1 CL.TE类型实例2.2 TE.CL类型实例2.3 TE.TE混淆实例2.4 漏洞检测工具? 请求走私利用3.1 绕过前端安全控制3.2 揭示前端请求重写3.3 捕获他人请求内容3.4 走私构造反射XS…

用Java链接MySQL数据库的总结

✨个人主页: 不漫游-CSDN博客 前言 在日常开发中,使用Java连接MySQL数据库是一个常见的任务,涉及多个步骤。接着我就带着大家细细看来~ 一.下载.jar 包文件 1.什么是.jar 文件 通俗点讲就是一个压缩包,不过里面存放的都是由Java代…

实验2——基于NAT技术的实验(基于实验1)

目录 实验拓扑图​ 实验要求: 实验思路 基于NAT的简单知识点: 实验步骤 1. 给路由器R1配置IP 2.创建区域 2.1 电信: 2.2 移动: 3.办公区的NAT策略 3.1 服务器映射(移动链路)​编辑 3.2 写一条分公…

【算法/数列】等差数列子序列算术序列

概念: 等差数列:任意两项的差总等于同一个常数 子数组 :是数组中的一个连续序列。 子序列:是通过从原序列删除零个或多个元素并在不改变顺序的情况下排列其余元素而获得的序列 算术序列:是一个数字列表,其中…

HyperSD - 会画草图就能玩AI绘画,AI一键手绘,实时同步 本地一键整合包下载

字节跳动的Lightning团队发布的新图像模型蒸馏算法Hyper-SD,是一项在图像处理和机器学习领域的重要进展。这项技术通过创新的方法提升了模型在不同推理步骤下的性能,同时保持了模型大小的精简。 基于这个算法模型,一个很实用的功能出现了&am…

Linux RTL8111/RTL8168 不能联网 / 最新版驱动下载安装

注: 机翻,未校对。 如何让 Realtek RTL8111/RTL8168 在 Linux 下工作 这篇文章于 2016 年 8 月在我原来的博客上发布。尽管如今 Linux 下的 RTL8111/RTL8168 网络接口的情况变得越来越稳定,但它们仍然会导致数据包丢失或网络连接不稳定等问题…

【错题集】ruby 和薯条(排序 + 二分 / 双指针)

牛客对应题目链接:ruby和薯条 (nowcoder.com) 一、分析题目 1、解法一:排序 二分。 先排序,然后枚举较⼤值,在 [1, i - 1] 区间找差值的左右端点即可。 2、解法二:排序 前缀和 双指针。 先排序; …

数据结构(Java):力扣Stack集合OJ题

1、括号匹配问题 . - 力扣(LeetCode) 1.1 思路分析 根据栈的先进后出原则,我们可以这样解决问题: 遍历字符串,遇见左括号就将左括号push入栈;遇见右括号就pop出栈,将出栈的元素和该右括号比较…

JDK14新特征最全详解

JDK 14一共发行了16个JEP(JDK Enhancement Proposals,JDK 增强提案),筛选出JDK 14新特性。 - 343: 打包工具 (Incubator) - 345: G1的NUMA内存分配优化 - 349: JFR事件流 - 352: 非原子性的字节缓冲区映射 - 358: 友好的空指针异常 - 359: Records…

游戏的无边框模式是什么?有啥用?

现在很多游戏的显示设置中,都有个比较特殊的选项“无边框”。小伙伴们如果尝试过,就会发现这个效果和全屏几乎一毛一样,于是就很欢快地用了起来,不过大家也许会发现,怎么和全屏比起来,似乎有点不够爽快&…

单例模式Singleton

设计模式 23种设计模式 Singleton 所谓类的单例设计模式,就是采取一定的方法保证在整个的软件系统中,对某个类只能存在一个对象实例,并且该类只提供一个取得其对象实例的方法。 饿汉式 public class BankTest {public static void main(…

[图解]SysML和EA建模住宅安全系统-14-黑盒系统规约

1 00:00:02,320 --> 00:00:07,610 接下来,我们看下一步指定黑盒系统需求 2 00:00:08,790 --> 00:00:10,490 就是说,把这个系统 3 00:00:11,880 --> 00:00:15,810 我们的目标系统,ESS,看成黑盒 4 00:00:18,030 --> …

Kafka基础入门篇(深度好文)

Kafka简介 Kafka 是一个高吞吐量的分布式的基于发布/订阅模式的消息队列(Message Queue),主要应用与大数据实时处理领域。 1. 以时间复杂度为O(1)的方式提供消息持久化能力。 2. 高吞吐率。(Kafka 的吞吐量是MySQL 吞吐量的30…

数据结构初阶(C语言)-复杂度的介绍

在学习顺序表之前,我们需要先了解下什么是复杂度: 一,复杂度的概念 我们在进行代码的写作时,通常需要用到许多算法,而这些算法又有优劣之分,区分算法的优劣则是通过算法的时间复杂度和空间复杂度来决定。 …

【眼疾病识别】图像识别+深度学习技术+人工智能+卷积神经网络算法+计算机课设+Python+TensorFlow

一、项目介绍 眼疾识别系统,使用Python作为主要编程语言进行开发,基于深度学习等技术使用TensorFlow搭建ResNet50卷积神经网络算法,通过对眼疾图片4种数据集进行训练(‘白内障’, ‘糖尿病性视网膜病变’, ‘青光眼’, ‘正常’&…

Python+wxauto=微信自动化?

Pythonwxauto微信自动化? 一、wxauto库简介 1.什么是wxauto库 wxauto是一个基于UIAutomation的开源Python微信自动化库。它旨在帮助用户通过编写Python脚本,轻松实现对微信客户端的自动化操作,从而提升效率并满足个性化需求。这一工具的出现&…

SAP PP学习笔记26 - User Status(用户状态)的实例,订单分割中的重要概念 成本收集器,Confirmation(报工)的概述

上面两章讲了生产订单的创建以及生产订单的相关内容。 SAP PP学习笔记24 - 生产订单(制造指图)的创建_sap 工程外注-CSDN博客 SAP PP学习笔记25 - 生产订单的状态管理(System Status(系统状态)/User Status(用户状态)),物料的可用性检查,生…

语音识别概述

语音识别概述 一.什么是语音? 语音是语言的声学表现形式,是人类自然的交流工具。 图片来源:https://www.shenlanxueyuan.com/course/381 二.语音识别的定义 语音识别(Automatic Speech Recognition, ASR 或 Speech to Text, ST…