实现一个简单的线性回归和多项式回归(2)

news2025/1/12 23:10:08

对于多项式回归,可以同样使用前面线性回归中定义的LinearRegression算子、训练函数train、均方误差函数mean_squared_error,生成数据集create_toy_data,这里就不多做赘述咯~

        拟合的函数为

def sin(x):
    y = torch.sin(2 * math.pi * x)
    return y

1.数据集的建立

func = sin
interval = (0,1)
train_num = 15
test_num = 10
noise = 0.5
X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise = noise)
X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise = noise)

X_underlying = torch.linspace(interval[0], interval[1], 100)
y_underlying = sin(X_underlying)

# 绘制图像
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")
#plt.scatter(X_test, y_test, facecolor="none", edgecolor="r", s=50, label="test data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.legend(fontsize='x-large')
plt.savefig('ml-vis2.pdf')
plt.show()

        生成结果: 

2.模型构建

# 多项式转换
def polynomial_basis_function(x, degree=2):
    """
    输入:
       - x: tensor, 输入的数据,shape=[N,1]
       - degree: int, 多项式的阶数
       example Input: [[2], [3], [4]], degree=2
       example Output: [[2^1, 2^2], [3^1, 3^2], [4^1, 4^2]]
       注意:本案例中,在degree>=1时不生成全为1的一列数据;degree为0时生成形状与输入相同,全1的Tensor
    输出:
       - x_result: tensor
    """

    if degree == 0:
        return torch.ones(x.size(), dtype=torch.float32)

    x_tmp = x
    x_result = x_tmp

    for i in range(2, degree + 1):
        x_tmp = torch.multiply(x_tmp, x)  # 逐元素相乘
        x_result = torch.concat((x_result, x_tmp), dim=-1)

    return x_result

        我的理解多项式回归的模型的建立更像是将原来的一个x变成x x^2 x^3 ... x^degree这个过程

3.模型训练

plt.rcParams['figure.figsize'] = (12.0, 8.0)

for i, degree in enumerate([0, 1, 3, 8]):  # []中为多项式的阶数
    model = Linear(degree)
    X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), degree)
    X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), degree)

    model = optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))  # 拟合得到参数

    y_underlying_pred = model(X_underlying_transformed).squeeze()

    print(model.params)

    # 绘制图像
    plt.subplot(2, 2, i + 1)
    plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")
    plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
    plt.plot(X_underlying, y_underlying_pred, c='#f19ec2', label="predicted function")
    plt.ylim(-2, 1.5)
    plt.annotate("M={}".format(degree), xy=(0.95, -1.4))

# plt.legend(bbox_to_anchor=(1.05, 0.64), loc=2, borderaxespad=0.)
plt.legend(loc='lower left', fontsize='x-large')
plt.savefig('ml-vis3.pdf')
plt.show()

         训练结果:

观察可视化结果,红色的曲线表示不同阶多项式分布拟合数据的结果:

* 当 $M=0$$M=1$时,拟合曲线较简单,模型欠拟合;

* 当 $M=8$ 时,拟合曲线较复杂,模型过拟合;

* 当 $M=3$ 时,模型拟合最为合理。

4.模型评估

        下面通过均方误差来衡量训练误差、测试误差以及在没有噪音的加入下sin函数值与多项式回归值之间的误差,更加真实地反映拟合结果。多项式分布阶数从0到8进行遍历。

# 训练误差和测试误差
training_errors = []
test_errors = []
distribution_errors = []

# 遍历多项式阶数
for i in range(9):
    model = Linear(i)

    X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), i)
    X_test_transformed = polynomial_basis_function(X_test.reshape([-1, 1]), i)
    X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), i)

    optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))

    y_train_pred = model(X_train_transformed).squeeze()
    y_test_pred = model(X_test_transformed).squeeze()
    y_underlying_pred = model(X_underlying_transformed).squeeze()

    train_mse = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()
    training_errors.append(train_mse)

    test_mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()
    test_errors.append(test_mse)

    # distribution_mse = mean_squared_error(y_true=y_underlying, y_pred=y_underlying_pred).item()
    # distribution_errors.append(distribution_mse)

print("train errors: \n", training_errors)
print("test errors: \n", test_errors)
# print ("distribution errors: \n", distribution_errors)

# 绘制图片
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.plot(training_errors, '-.', mfc="none", mec='#e4007f', ms=10, c='#e4007f', label="Training")
plt.plot(test_errors, '--', mfc="none", mec='#f19ec2', ms=10, c='#f19ec2', label="Test")
# plt.plot(distribution_errors, '-', mfc="none", mec="#3D3D3F", ms=10, c="#3D3D3F", label="Distribution")
plt.legend(fontsize='x-large')
plt.xlabel("degree")
plt.ylabel("MSE")
plt.savefig('ml-mse-error.pdf')
plt.show()

        可视化结果如下:

由可视化结果可得:

  1. 当阶数较低的时候,模型的表示能力有限,训练误差和测试误差都很高,代表模型欠拟合;
  2. 当阶数较高的时候,模型表示能力强,但将训练数据中的噪声也作为特征进行学习,一般情况下训练误差继续降低而测试误差显著升高,代表模型过拟合。

        对于模型过拟合的情况,可以引入正则化方法,通过向误差函数中添加一个惩罚项来避免系数倾向于较大的取值。

degree = 8 # 多项式阶数
reg_lambda = 0.0001 # 正则化系数

X_train_transformed = polynomial_basis_function(X_train.reshape([-1,1]), degree)
X_test_transformed = polynomial_basis_function(X_test.reshape([-1,1]), degree)
X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1,1]), degree)

model = Linear(degree) 

optimizer_lsm(model,X_train_transformed,y_train.reshape([-1,1]))

y_test_pred=model(X_test_transformed).squeeze()
y_underlying_pred=model(X_underlying_transformed).squeeze()

model_reg = Linear(degree) 

optimizer_lsm(model_reg,X_train_transformed,y_train.reshape([-1,1]),reg_lambda=reg_lambda)

y_test_pred_reg=model_reg(X_test_transformed).squeeze()
y_underlying_pred_reg=model_reg(X_underlying_transformed).squeeze()

mse = mean_squared_error(y_true = y_test, y_pred = y_test_pred).item()
print("mse:",mse)
mes_reg = mean_squared_error(y_true = y_test, y_pred = y_test_pred_reg).item()
print("mse_with_l2_reg:",mes_reg)

# 绘制图像
plt.scatter(X_train, y_train, facecolor="none", edgecolor="#e4007f", s=50, label="train data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.plot(X_underlying, y_underlying_pred, c='#e4007f', linestyle="--", label="$deg. = 8$")
plt.plot(X_underlying, y_underlying_pred_reg, c='#f19ec2', linestyle="-.", label="$deg. = 8, \ell_2 reg$")
plt.ylim(-1.5, 1.5)
plt.annotate("lambda={}".format(reg_lambda), xy=(0.82, -1.4))
plt.legend(fontsize='large')
plt.savefig('ml-vis4.pdf')
plt.show()

可视化结果为:

         多项式回归的难点就是是否真正理解了线性回归中的算子,均方误差等函数的目的和用法,其他的就是简单的函数调用问题啦~不懂的话还是建议多看看线性函数,先把线性函数的看明白最好~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1074724.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3、在 CentOS 8 系统上安装 PostgreSQL 15.4

PostgreSQL,作为一款备受欢迎的开源关系数据库管理系统(RDBMS),已经存在了三十多年的历史。它提供了SQL语言支持,用于管理数据库和执行CRUD操作(创建、读取、更新、删除)。 由于其卓越的健壮性…

Linux网络监控工具 - nethogs

nethogs 是一个基于命令行的网络监控工具,用于实时监视每个进程的网络流量。它可以显示每个进程使用的带宽、连接数和数据包数量等信息。 安装 在大多数Linux发行版中,您可以使用包管理器来安装 nethogs。例如,在Ubuntu/Debian上&#xff0c…

【Java 进阶篇】CSS语法格式详解

在前端开发中,CSS(层叠样式表)用于控制网页的样式和布局。了解CSS的语法格式是学习如何设计和美化网页的关键。本文将深入解释CSS的语法格式,包括选择器、属性和值等基本概念,同时提供示例代码以帮助初学者更好地理解。…

【单片机】18-红外线遥控

一、红外遥控背景知识 1.人机界面 (1)当面操作:按键,旋转/触摸按键,触摸屏 (2)遥控操作:红外遥控,433M/2.4G无线通信【穿墙能力强】,蓝牙-WIFI-Zigbee-LoRa等…

WPFdatagrid结合comboBox

在WPF的DataGrid中希望结合使用ComboBox下拉框,达到下拉选择绑定的效果,在实现的过程中,遇到了一些奇怪的问题,因此记录下来。 网上能够查询到的解决方案: 总共有三种ItemSource常见绑定实现方式: 1.ItemS…

【bug日记】spring项目使用配置类和测试类操作数据库

最近学校课程要求使用spring操作数据库&#xff0c;时间有点久了&#xff0c;操作都不太熟悉了&#xff0c;遇到了很多坑&#xff0c;特此记录一下。 导入依赖 <!-- Spring Framework --> <dependency><groupId>org.springframework</groupId><ar…

用Nginx搭建一个可用的静态资源Web服务器

sudo wget http://dlib.net/files/dlib-19.24.tar.bz2下载需要的文件。 sudo tar jxf dlib-19.24.tar.bz2进行解压。 sudo mkdir /nginx/dlib在nginx安装目录/nginx创建一个新的目录dlib。 配置文件里边的内容如下&#xff1a; worker_processes 1; events {worker_con…

如何批量获取1688商品详情数据接口,1688商品详情数据接口

批量获取1688商品详情数据接口的步骤如下&#xff1a; 获取API接口权限。编写API请求代码。应用爬取下来的数据。 1688商品详情数据接口步骤如下&#xff1a; 注册成为1688开放平台的开发者&#xff0c;并创建一个应用&#xff0c;获取到所需的App Key和App Secret等信息。使…

SpringBoot 如何使用 Prometheus 进行监控

在当今的软件开发世界中&#xff0c;监控是至关重要的一部分。它允许开发人员和运维团队实时跟踪应用程序的性能、可用性和健康状况。Spring Boot是一个流行的Java框架&#xff0c;用于构建微服务和Web应用程序&#xff0c;而Prometheus是一个开源的监控和警报工具。本文将介绍…

数据结构和算法——线性结构

文章目录 前言线性表顺序表链表合并有序链表反转链表 队列循环队列双端队列资源分配问题 栈共享栈表达式求值递归处理迷宫问题 串串的模式匹配BF算法KMP算法next数组的求解next数组的优化 前言 本文所有代码均在仓库中&#xff0c;这是一个完整的由纯C语言实现的可以存储任意类…

spark-08

学习视频&#xff1a; 黑马程序员Spark全套视频教程&#xff0c;4天spark3.2快速入门到精通&#xff0c;基于Python语言的spark教程_哔哩哔哩_bilibili

增强LLM:使用搜索引擎缓解大模型幻觉问题

论文题目&#xff1a;FRESHLLMS:REFRESHING LARGE LANGUAGE MODELS WITH SEARCH ENGINE AUGMENTATION 论文地址&#xff1a;https://arxiv.org/pdf/2310.03214.pdf 论文由Google、University of Massachusetts Amherst、OpenAI联合发布。 大部分大语言模型只会训练一次&#…

Spring Data Redis使用方式

1.导入Spring Data Redis的maven坐标 pom.xml <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-data-redis</artifactId> </dependency> 2. 配置Redis数据源 2.1application.yml文件…

idea compile项目正常,启动项目的时候build失败,报“找不到符号”等问题

1、首先往上找&#xff0c;看能不能找到如下报错信息 You aren’t using a compiler supported by lombok, so lombok will not work and has been disabled. 2、这种问题属于lombok编译失败导致&#xff0c;可能原因是依赖jar包没有更新到最新版本 3、解决方案 1&#xff09…

C语言编程实现只有一个未知数的两个多项式合并的程序

背景&#xff1a; 直接看题目把&#xff01;就是C语言写两个多项式多项式合并 题目要求&#xff1a; 1. 题目&#xff1a; 编程实现只有一个未知数的两个多项式合并的程序。如&#xff1a; 3x^26x7 和 5x^2-2x9合并结果为8x^24x16。 2. 设计要求 &#xff08;1&#xff09…

有哪些值得推荐的Java 练手项目?

大家好&#xff0c;我是 jonssonyan 我是一名 Java 后端程序员&#xff0c;偶尔也会写一写前端&#xff0c;主要的技术栈是 JavaSpringBootMySQLRedisVue.js&#xff0c;基于我学过的技术认真的对每个分享的项目进行鉴别&#xff0c;今天就和大家分享我曾经用来学习的开源项目…

微调codebert、unixcoder、grapghcodebert完成漏洞检测代码

文件结构如下所示&#xff1a; mode.py # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import torch import torch.nn as nn import torch from torch.autograd import Variable import copy from torch.nn import CrossEntropyLoss, MSELosscl…

若依4.7.6 版本任意文件下载(CVE-2023-27025)

CVE-2023-27025 框架说明 若依/ruoyi 是使用java主流框架的一款优秀的国内开源cms&#xff0c; 基于SpringBoot、Shiro、Mybatis的权限后台管理系统。 环境搭建 查询最近的漏洞信息 https://cve.mitre.org/ 搜索ruoyi 代码审计感兴趣的漏洞&#xff1a;CVE-2023-27025 …

MyBatis-Plus 内置雪花算法主键重复问题

Mybatis-Plus 使用ID_WORKER生成主键id重复 问题描述 目前项目使用的id是mybatis-plus 内置的主键生成策略 ID_WORKER &#xff0c;最近测试在做性能压测&#xff0c;部署架构是单服务集群的部署方式&#xff0c;然后就发现了id重复的异常&#xff0c;异常如下 问题分析 首先分…

JSONUtil.parse将java对象转为json时,需要在java对象中设置get、set方法

想要使用JSONUtil.parse将java对象转为json格式&#xff0c;但是一直为空&#xff0c;代码如下 public class MyTest {public static void main(String[] args) {Test3<String> test3 new Test3<>("2","hhhhhhaaa");System.out.println(JSON…