2. 线性模型

news2024/10/6 0:29:48

b站刘二老师pytorch深度学习课程:https://www.bilibili.com/video/BV1Y7411d7Ys?p=2&vd_source=b17f113d28933824d753a0915d5e3a90

image-20230519145913075
  • 如果每周学习4个小时,那能够获得什么成绩?
    • y已知的是采样得到的数据,属于训练集(training)
    • y未知的是需要通过模型预测得到的,属于测试集(test)

机器学习的过程

  1. 先把数据集(Data Set)交给算法进行训练
  2. 然后把新的输入(Input)输给训练好的算法,能够得到相应的预测(Prediction)结果
image-20230628103509217
  • 在上述问题中,输入是x,输出是y,在训练集中x和y的值都是已知的,该类问题也称为监督学习(Supervised Learning)。即在学习训练过程中,知道输入所对应的输出值是多少,那么就可以在训练过程中得到模型预测值和真实值之间的偏差,从而对模型进行调整,使得该偏差尽可能小。

  • 模型在训练好之后,都要先用测试集(test set)进行测试,以评估模型的性能。

    • overfitting,过拟合:模型在训练集上表现很好,但在测试集中表现较差,即泛化能力差

      image-20230628104858231
    • 通常会将训练集再划分成两块,一块作为训练集,用于模型训练,另一部分作为开发集(验证集),用于模型的评估

      • 如果模型训练后评估的性能比较好,则再将所有训练集数据再丢到模型中训练,然后再用测试集测试
      image-20230628105015726

模型设计

  • 要解决的问题:**对数据而言,什么样的模型是最合适的???**即 f ( x ) f(x) f(x)的形式是什么???

    • 最基本的:线性模型
    image-20230628105758998
    • 线性模型中,训练时关键的就是确定 w w w b b b的值
      • w w w被称为权重, b b b称为偏置

对模型做一个简化,去掉截距 b b b

y ^ \hat{y} y^表示预测结果

image-20230628110404005
  • 不同的权重 w w w,曲线斜率不一样,那么该如何找到最优的权重值???

    • 初始的时候是随机猜测,权重可能大可能小,不一定正好落在真实值,因此需要进行评估

      • 当取了一个权重之后,它所表示的模型和数据集里面的数据之间的偏移程度有多大

        image-20230628110821357
      • 评估模型(loss):评估模型偏差

        • 平均损失降到最低
        image-20230628111137018 image-20230628111242780 image-20230628111336759
  • 如何找到合适的权重值,使得损失最小 ???

    • 损失函数(Loss function)是针对一个样本的,对于整个训练集需要将每个样本的预测值和真实值求差然后计算均方根误差
    image-20230628111803947 image-20230628111923193
    • 穷举法

      image-20230628112044006

实践代码:

image-20230628114247619
import numpy as np
import matplotlib.pyplot as plt

# 训练集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


def forward(x):
    # 定义模型:y_hat = x * w
    return x * w


def loss(x, y):
    # 定义损失函数loss function
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)  # (y_hat - y)的平方


# w_list存储每个权重值,mse_list存储每个权重值对应的损失值
w_list = []
mse_list = []

for w in np.arange(0.0, 4.1, 0.1):
    # 权重w在0到4.1之间进行采样,采样间隔为0.1,[0.0, 0.1, 0.2, ..., 4.0]
    print("w = ", w)
    l_sum = 0

    for x_val, y_val in zip(x_data, y_data):
        # 遍历训练集中的每一个样本,并计算每个样本的损失值
        # zip用法:https://blog.csdn.net/qq_45766916/article/details/125960493
        # zip(numbers, letters)创建一个生成 (x, y) 形式的元组的迭代器,[(numbers[0], letters[0]),…,(numbers[n], letters[n])]
        y_pred_val = forward(x_val)     # x_val的预测值
        loss_val = loss(x_val, y_val)   # x_val的损失值
        l_sum += loss_val               # 叠加每个样本的损失值(此处不是求均值)
        print('\t', x_val, y_val, y_pred_val, loss_val)

    print('MSE = ', l_sum / 3)          # 计算得到权重值w对应的损失值
    w_list.append(w)                    # 将权重w添加到w_list中
    mse_list.append(l_sum / 3)          # 将权重w对应的损失值添加到mse_list中

# loss曲线绘制,x轴是权重w,y轴是loss值,即表示每个权重值w对应的loss值
plt.plot(w_list, mse_list)
plt.ylabel('loss')
plt.xlabel('w')
plt.show()
  • 在深度学习中做训练的时候,loss曲线中一般不是用权重来做横坐标,而是训练轮数(epoch)

    image-20230628114657902

visdom:可视化工具包

  • https://github.com/fossasia/visdom
image-20230628114842294

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/695202.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 css 禁用 input 控件实现 disable 效果

文章目录 需求分析代码 需求 使用 css 禁用 input 控件实现 disable 效果 分析 在 js 中,我们使用以下方法来阻止input,select,checkbox的默认事件,如 javascript event.preventDefault() event.stopPropagation()但在 css 中,我们可以设置对…

Android App的几个核心概念

Application启动 点击桌面图标启动App(如下流程图) 针对以上流程图示: ActivityManagerService#startProcessLocked()Process#start()ActivityThread#main(),入口分析的地方ActivityThread#attach(),这个里面的逻辑很核心 ActivityManagerS…

20230621 taro+vue3+webpack5+antdv时,在vue文件中特定组件时,devH5环境报错

问题 在某个vue文件下 import { notification } from ant-design-vue;然后在终端开始 yarn dev:h5在浏览器看效果 回出现以下错误 Uncaught (in promise) TypeError: __webpack_require__.hmd is not a functionat ./node_modules/.taro/h5/prebundle/ant-design-vue.js原因…

提高客户体验:智能客服外包服务的优势

随着科技的发展,智能客服外包服务越来越受到企业的青睐。这种服务能够帮助企业提高客户体验,减少客服成本,提高工作效率。本文将从技术、用户体验等方面阐述智能客服外包服务的优势。 人工智能技术 智能客服外包服务采用了一系列的技术手段&…

自然语言处理库NLTK的初步环境配置和使用例子

NLTK的基本介绍见此, 了解自然语言处理_bcbobo21cn的博客-CSDN博客 先安装python;然后用pip命令安装nltk; 然后进入python,下载nltk的数据包;输入下图语句,弹出一个框, 一般选择 all&#xff…

PHP 实验室设备系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 实验室设备系统 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 代码下载 https://download.csdn.net/download/qq_41221322/87959348https://downlo…

Kaggle 数据竞赛 | ICR - 鉴定与年龄相关的疾病

文章目录 一、前言二、主要内容1. 评估2. 时间线3. 奖金4. 代码要求 三、总结 🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ 一、前言 使用机器学习技术,通过匿名健康特征的测量数据来检测疾病。 比赛目标 本次比赛的目标是预测一个人是…

华为云Classroom一站式教学实践平台,开启云端教学新征程

随着高考落下帷幕,各高校将迎来新一届大学新生入学,他们的学长学姐们经过四年的学习,也即将步入社会,迈向一段新的人生旅程。 在这里小智先祝大家未来一切顺意,不忘初心,大鹏一日同风起,扶摇直…

win10环境下php安装thinkPHP5的曲线方式

win10环境下php安装thinkPHP5的曲线方式 强调一下在win10环境安装thinkPHP5需要使用Composer。 首先是thinkphp的教程:https://www.kancloud.cn/manual/thinkphp5/118006 你就会发现很坑逼 安装Composer的时候一种报错,就是php.ini文件错误。网上说什么…

Unity内置渲染管线升级URP教程

简介 URP全称为Universal Render Pipeline(通用渲染管线),可以提供更加灵活的渲染方案,通过添加Render Feature实现各种渲染效果。并且可以针对移动平台进行专门的优化,同时还提供了SRPBatcher提高渲染效率。Unity的一些工具,比如…

客户关系管理系统有哪些?5款客户关系管理软件评测

客户关系管理系统是一种企业与客户之间的交互平台,它将客户的需求、市场环境、企业的发展战略等融入到整个业务流程中,在企业和客户之间建立起一个共享的资源库,使企业对客户的了解更深更全面,进而实现与客户的深入互动&#xff0…

C#创建窗体应用程序

1、新建项目,选择窗体应用 2、打开相关视图 工具箱:将工具箱中的控件直接拖拽到界面中,会自动生成对用的控件。 属性:可以设置控件的相关属性,包括事件,双击 3、设计应用界面 4、新建一个交互窗口 5、在登…

MVP(Multi-view Prompting):多视图提示改进了方面情感元组预测

论文题目(Title):MVP: Multi-view Prompting Improves Aspect Sentiment Tuple Prediction 研究问题(Question):多视图提示对方面情感元组检测的影响 研究动机(Motivation)&#x…

AI智能人脸识别,抠图-应用证件照

效果展示: 关键代码: import numpy as np import cv2 import osdef crop_face(source_image_path, output_folder_path, tag_width, tag_height):face_detector cv2.CascadeClassifier(cv2.data.haarcascades haarcascade_frontalface_default.xm…

【算法题解】45. N叉树的层序遍历

这是一道 中等难度 的题 https://leetcode.cn/problems/n-ary-tree-level-order-traversal/ 题目 给定一个 N 叉树,返回其节点值的层序遍历。(即从左到右,逐层遍历)。 树的序列化输入是用层序遍历,每组子节点都由 nu…

fopen,fputs,fgets,fclose

fopen 是打开文件 fputs 往文件里面写内容(里面有2个参数其中第一个是一个char 型 数组用于存放读取的字符串,表示读取 1-n个字符。第二个表示是文件读入指针) fgets 读取文件里面的内容 (里面有三个参数其中 第一个是一个char 型 数组用于存放读取的字符串。第二…

阿里云安全组 IP地址段 设置方法 斜线后面数字含义

比如公司搬家后,我的ip变成了101.83.11.11 但我希望安全组中.只限制ip的前2段,后面两段是多少,都不会限制访问 先登录阿里云的服务器管理后台,找到主机,进安全组,添加 设置方法为 端口1234为自定义的,比如远程桌面访问你的服务器,源:就是ip地址段. 101.83.1.1/16解释 斜杠…

使用 Jetpack Compose 构建 LinearProgressIndicator

欢迎阅读这篇关于如何使用 Jetpack Compose 构建 LinearProgressIndicator(线性进度指示器)的博客。Jetpack Compose 是 Google 推出的一款现代化 UI 工具包,用于构建 Android 界面。其声明式的设计使得 UI 开发更加简洁、直观。 什么是 Line…

mfc100u.dll丢失的各种解决方法分享,探究mfc100u.dll文件

在计算机系统中,有许多重要的文件扮演着关键角色。其中之一就是Mfc100u.dll,一但这个文件丢失了,那么你的电脑就会出现问题,如程序运行不了等等。今天主要来给大家讲讲Mfc100u.dll这个文件,mfc100u.dll丢失的各种解决方…

由spring定时任务@Scheduled(cron = “0 0 0/1 * * ?“)引起的坑

这两天做到的一个功能,定时任务每整点生成一条记录,然后使用的cron表达式是: Scheduled(cron "0 0 0/1 * * ?")意为每整点执行一次。 定时任务执行之后使用new Date() 拿到当前本机时间,作为记录的创建时间&#xf…