2 线性模型

news2024/11/25 23:27:37

文章目录

    • 一般流程
    • 问题引入
    • 数据集与测试集
      • 过拟合与泛化
      • 开发集
      • 监督学习和非监督学习
    • 问题分析
      • 训练集、验证集、测试集
    • 模型设计
      • 模拟训练过程
    • 课程代码
      • 课后习题代码

课程来源: 链接
文档参考: 链接
以及 BirandaのBlog!

一般流程

对于一般的线性模型来说,分析问题的格式一般为:

  1. 搜集数据
  2. 选择模型
  3. 神经训练
  4. 开始推理

问题引入

对于某同学平时花费x小时,期末得到的分数y的图表:

x(小时)y(分)
12
24
36
4

求问在平时花费4小时的情况下,最终的成绩为?

数据集与测试集

详细见:链接

数据集拿到后一般划分为两部分,训练集和测试集,然后使用训练集的数据来训练模型,用测试集上的误差作为最终模型在应对现实场景中的泛化误差。

我们可以使用训练集的数据来训练模型,然后用测试集上的误差作为最终模型在应对现实场景中的泛化误差。有了测试集,我们想要验证模型的最终效果,只需将训练好的模型在测试集上计算误差,即可认为此误差即为泛化误差的近似,我们只需让我们训练好的模型在测试集上的误差最小即可。

过拟合与泛化

下面拿小猫图像识别做例子,说明一下过拟合和泛化的概念;

过拟合: 在训练集上匹配度很好,但是太过了,把噪声什么的也学进来了。

泛化能力: 对于没见过的图像也能进行识别,这是我们所需要的。

开发集

有时候无法看到测试集,我们又人为地把数据集划分一部分出来作为验证评估,称为“开发集”。

监督学习和非监督学习

有监督学习方法必须要有训练集与测试样本。在训练集中找规律,而对测试样本使用这种规律。而非监督学习没有训练集,只有一组数据,在该组数据集内寻找规律。

有监督学习的方法就是识别事物,识别的结果表现在给待识别数据加上了标签。因此训练样本集必须由带标签的样本组成。而非监督学习方法只有要分析的数据集的本身,预先没有什么标签。如果发现数据集呈现某种聚集性,则可按自然的聚集性分类,但不予以某种预先分类标签对上号为目的。

问题分析

数据集需要交付给算法模型进行训练,利用所训练的模型,在获得新的数据时可以获得相应的输出。(监督学习)

训练集、验证集、测试集

按照上面的介绍很简单就可以得出1和2当作训练集,3当作验证集,4当作测试集。

模型设计

线性模型的基本模型 y ^ = ω x + b \widehat y=\omega x+b y =ωx+b,其中的 ω \omega ω b b b是模型中的参数,训练模型的过程即为确定模型中参数的过程

在本模型中设置成 y ^ = ω x \widehat y=\omega x y =ωx,对于不同的 ω \omega ω有不同的线性模型及图像与之对应。

在这里插入图片描述

模拟训练过程

在模型训练中会先随机取得一个值,继而计算其和标准量之间的偏移量,从而判断当前模型是否符合预期。
记实际值为 y ( x ) y(x) y(x),模型对应的预测值为 y ^ ( x ) \widehat y(x) y (x),则其中的偏移量为 ∣ y ^ ( x ) − y ( x ) ∣ \left|\widehat y(x)-y(x)\right| y (x)y(x),以此来代表模型估计值对原值的误差。
通常,该公式定义为Training Loss (Error)
l o s s = ( y ^ − y ) 2 = ( ω x − y ) 2 loss = (\widehat y - y)^2 = (\omega x - y)^2 loss=(y y)2=(ωxy)2
原题目中的几种 ω \omega ω所对应的Loss如下

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
其中的每行为 w w w不同时的单个样本的损失,最后一行为平均损失。
对于单个样本,有loss可用于指代样本误差。对于所有样本,可同理用(MSE)来指代整体样本的平均平方误差(均方差cost)
c o s t = 1 N ∑ n = 1 N ( y ^ n − y n ) 2 cost = \frac{1}{N} \displaystyle\sum_{n=1}^{N}(\widehat y_n-y_n)^2 cost=N1n=1N(y nyn)2
由cost的计算公式可知,当平均损失为0时,模型最佳,但由于仅当数据无噪声且模型完美贴合数据的情况下才会出现这种情况,因此模型训练的目的应当是尽可能小,而非找到误差为0的情况。

课程代码

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#前馈计算
def forward(x):
    return x * w
#求loss
def loss(x, y):
    y_pred = forward(x)
    return (y_pred-y)*(y_pred-y)

w_list = []
mse_list = []
#从0.0一直到4.1以0.1为间隔进行w的取样
for w in np.arange(0.0,4.1,0.1):
    print("w=", w)
    l_sum = 0
    for x_val,y_val in zip(x_data,y_data):
        y_pred_val = forward(x_val)
        loss_val = loss(x_val,y_val)
        l_sum += loss_val
        print('\t',x_val,y_val,y_pred_val,loss_val)
    print("MSE=",l_sum/3)
    w_list.append(w)
    mse_list.append(l_sum/3)

#绘图
plt.plot(w_list,mse_list)
plt.ylabel("Loss")
plt.xlabel('w')
plt.show()

输出结果:

在这里插入图片描述

课后习题代码

前提知识:3d绘图链接

import numpy as np
import matplotlib.pyplot as plt;
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

#线性模型
def forward(x,w,b):
    return x * w+ b

#损失函数
def loss(x, y,w,b):
    y_pred = forward(x,w,b)
    return (y_pred - y) * (y_pred - y)


def mse(w,b):
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        y_pred_val = forward(x_val,w,b)
        loss_val = loss(x_val, y_val,w,b)
        l_sum += loss_val
        print('\t', x_val, y_val, y_pred_val, loss_val)
    print('MSE=', l_sum / 3)
    return  l_sum/3

#迭代取值,计算每个w取值下的x,y,y_pred,loss_val
mse_list = []



##画图

##定义网格化数据
b_list=np.arange(-30,30,0.1)
w_list=np.arange(-30,30,0.1);

##生成网格化数据
xx, yy = np.meshgrid(b_list, w_list,sparse=False, indexing='xy')

##每个点的对应高度
zz=mse(xx,yy)

fig = plt.figure()
ax = Axes3D(fig)
x = np.arange(-4, 4, 0.25)
y = np.arange(-4, 4, 0.25)
x, y = np.meshgrid(x, y)
r = np.sqrt(x**2 + y**2)
z = np
ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, cmap=cm.viridis)
plt.show()

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/181210.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序 Springboot校园招聘求职系统

基于微信小程序的校园求职系统的设计基于现有的手机,可以实现首页、个人中心、岗位类型管理、用户管理、企业管理、招聘信息管理、应聘信息管理、系统管理等功能。方便用户对首页、招聘信息、我的等详细的了解及统计分析 一个基本的程序包含app.json、project.confi…

谈谈SpringBoot(二)

1. Spring Boot缓存 1.1 JSR-107 Spring从3.1开始定义了org.springframework.cache.Cache 和org.springframework.cache.CacheManager接口来统一不同的缓存技术; 并支持使用JCache(JSR-107)注解简化我们开发。 Cache接口为缓存的组件规范定义…

day23|93.复原IP地址、78.子集、90.子集II

93.复原IP地址 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址,但是 &q…

layui框架学习(2:颜色、图标、动画)

B站教学视频中对Layui的颜色没有专门介绍,而Layui官方教程中虽然有颜色章节,但也只是简单介绍了基色调、辅色调、中性的颜色的概念及用途,最后说明layui 内置了七种背景色,以便用于各种元素中,如:徽章、分割…

Go语言基础入门第二章

Go语言环境安装 下载地址:https://golang.google.cn/dl/ 下载完安装包直接安装即可,安装完毕后,打开cmd控制台,输入”go version“查看是否安装成功以及对应安装版本。 配置环境变量Go语言需要一个安装目录,还需要一个…

Spring Cloud_Eureka服务注册与发现

目录一、Eureka基础知识1.什么是服务治理2.什么是服务注册3.Eureka两组件二、单机Eureka构建步骤1.IDEA生成eurekaServer端服务注册中心2.服务提供者3.服务消费者代码链接 https://github.com/lidonglin-bit/cloud 一、Eureka基础知识 1.什么是服务治理 SpringCloud封装了Ne…

金融风控09

迁移学习 为什么要? 源域样本与目标域样本分布有区别,目标域样本量不够 平时建模用的迁移学习场景 1、新开某个消费分期场景样本量少,需要用其他场景的数据建模 2、业务被迫停滞3个月再重启,大部分训练样本比较老旧&#xff…

含分布式光伏的配电网集群划分和集群电压协调控制(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

激光在大气中传输特性

在光纤通信中,单模光纤在波长1.55μm窗口具有巨大的潜在带宽和极低的损耗,传输数字信号的容量已能达到10Tb/s,每信道光源功率仅需1mW左右,无中间放大的距离超过100km。而且,光纤作为光波导,红外…

PostSharp Ultimate添加模式和线程安全

PostSharp Ultimate添加模式和线程安全 PostSharpUltimate允许开发人员通过将重复的工作从人身上转移到机器上,从而消除样板代码。它包含最常见模式的现成实现,并为您提供了为自己的模式构建自动化的工具。开发人员通常根据设计模式进行思考,…

Go语言基础入门第一章

Go语言基础入门 Go语言的logo 为什么需要一个新的语言最近十年来,C/C在计算机领域得到了很好的发展,并没有新的系统编程语言出现。对开发程度和执行效率在很多情况下并不能兼得。要么是执行效率高,但是低效的开发和编译,如C&…

Redux Toolkit + React + Tailwind CSS 学习心得

Redux Toolkit React Tailwind CSS 学习心得 预览地址:https://goldenaarcher.com/movie-app-home-only,只实现了一个简单的首页功能,API 用的是 the Movie Database,不想用 API 的也可以装一个 faker-js/faker 用来随机生成伪…

学生护眼灯怎么选择?分享适合学生的护眼灯

现阶段的青少年与儿童的近视率非常高,选择一款好的台灯能够保证双眼的健康,首先先看亮度是否合理,不能刺眼,选择三基色灯管,光很柔和,看频闪,好的护眼台灯可以做到无可视频闪,是的视…

移动web适配和Less

移动web适配和Lessrem 适配rem 单位媒体查询flexible.js如何把设计稿的px转换为remLESSLess注释less 运算less 嵌套less 变量less导入less 导出控制当前Less文件导出less 禁止导出小结rem 适配 rem 单位 rem 是一个相对单位,1rem 就是 html 文字的大小 比如 /* …

Java基础10:常用API

Java基础10:常用API一、Math二、System1. currentTimeMillis2. arraycopy三、Runtime四、Object1. toString2. equals3. clone五、Objects六、BigInteger1. 构造方法(获取BigInteger)2. 常用方法七、BigDecimal1. 构造方法(获取Bi…

计算机相关专业混体制的解决方案(考公务员)

文章目录序:编制介绍1、公务员报考要求2、公务员工作待遇3、公务员工作内容4、公务员报考复习序:编制介绍 编制介绍:编制,也就是常说的铁饭碗。 编制的诞生为了控制吃财政饭的人员数量无限膨胀而设置的,所以名额有限受…

密码学基本概念

密码学简介 密码是经过加密过后的口令,是指用特定的变换对数据信息进行加密保护或者安全身份认证的物质和技术,密码学是对安全通信技术的研究,要能够有效的防范潜在攻击,也就是对信息加密解密的过程。 密码基本性质 密码学的发展…

CSS3 选择器 :nth-child 与 :nth-of-type 区别

一、:nth-child 1.1 说明 :nth-child(n) 选择器匹配属于其父元素的第 N 个子元素&#xff0c;不论元素的类型。n 可以是数字、关键词或公式。 注意&#xff1a;如果第 N 个子元素与选择的元素类型不同则样式无效&#xff01; 1.2 示例 <style> div>p:nth-child(2…

1行Python代码识别身份证信息,还能自动告警,YYDS

大家好&#xff0c;这里是程序员晚枫。 录入身份证信息是一件繁琐的工作&#xff0c;如果可以自动识别并且录入系统&#xff0c;那可真是太好了。 今天我们就来学习一下&#xff0c;如何自动识别身份证信息并且录入系统~ 识别身份证信息 识别身份证信息的代码最简单&#x…

【金融量化】CTA策略之VeighNa量化实战笔记(1)

量化投资实战笔记 1 基本概念 1、一手股票&#xff1a;100支股票 2、收盘比开盘上涨的百分比&#xff1a;&#xff08;收盘-开盘&#xff09;/开盘 3、开盘比前日收盘的百分比&#xff1a;&#xff08;开盘-前日收盘&#xff09;/前日收盘 4、从dataframe中取每个月的第一天 …