基于简单神经网络的线性回归

news2025/4/3 1:13:50

一、概述

本代码实现了一个简单的神经网络进行线性回归任务。通过生成包含噪声的线性数据集,定义一个简单的神经网络类,使用梯度下降算法训练网络以拟合数据,并最终通过可视化展示原始数据、真实线性关系以及模型的预测结果。

二、依赖库

  1. numpy:用于数值计算,包括生成数组、进行随机数操作、执行数学运算等。
  2. matplotlib.pyplot:用于数据可视化,绘制散点图和折线图以展示数据和模型的预测结果。

三、代码详解

1. 生成数据集

python

np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪声

  • np.random.seed(42):设置随机数种子,确保每次运行代码时生成的随机数序列相同,从而使结果可复现。
  • np.linspace(-10, 10, 100):生成一个包含 100 个元素的一维数组x,元素均匀分布在 - 10 到 10 之间。
  • x + np.random.normal(0, 1, x.shape):生成因变量y,它基于真实的线性关系y = x,并添加了均值为 0、标准差为 1 的高斯噪声。np.random.normal(0, 1, x.shape)生成与x形状相同的随机噪声数组。

2. 定义神经网络(线性回归)

python

class SimpleNN:
    def __init__(self):
        self.w = np.random.randn()  # 权重
        self.b = np.random.randn()  # 偏置

    def forward(self, x):
        return self.w * x + self.b  # 前向传播

    def loss(self, y_true, y_pred):
        return np.mean((y_true - y_pred) **2)  # 均方误差

    def gradient(self, x, y_true, y_pred):
        dw = -2 * np.mean(x * (y_true - y_pred))  # 权重的梯度
        db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度
        return dw, db

    def train(self, x, y, lr=0.01, epochs=1000):
        for epoch in range(epochs):
            y_pred = self.forward(x)
            dw, db = self.gradient(x, y, y_pred)
            self.w -= lr * dw  # 更新权重
            self.b -= lr * db  # 更新偏置
            if (epoch + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')

  • __init__方法:初始化神经网络的权重self.w和偏置self.b,使用np.random.randn()生成随机的初始值。
  • forward方法:实现前向传播,根据输入x、权重self.w和偏置self.b计算输出y_pred,即y_pred = self.w * x + self.b
  • loss方法:计算预测值y_pred和真实值y_true之间的均方误差(MSE),公式为np.mean((y_true - y_pred) ** 2)
  • gradient方法:计算权重self.w和偏置self.b的梯度。dw是权重的梯度,计算公式为-2 * np.mean(x * (y_true - y_pred))db是偏置的梯度,计算公式为-2 * np.mean(y_true - y_pred)
  • train方法:使用梯度下降算法训练神经网络。在指定的epochs(训练轮数)内,每次迭代进行前向传播计算预测值y_pred,然后计算梯度dwdb,根据学习率lr更新权重self.w和偏置self.b。每 100 轮打印一次当前轮数和损失值。

3. 训练模型

python

model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)

  • SimpleNN():创建一个SimpleNN类的实例model
  • model.train(x, y, lr=0.01, epochs=1000):调用modeltrain方法,使用生成的数据集xy,学习率lr=0.01,训练轮数epochs=1000进行训练。

4. 可视化结果

python

y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

  • model.forward(x):使用训练好的模型model对数据集x进行前向传播,得到预测值y_pred
  • plt.scatter(x, y, label='Data points'):绘制原始数据集的散点图,标签为Data points
  • plt.plot(x, x, color='red', label='y = x'):绘制真实的线性关系y = x的折线图,颜色为红色,标签为y = x
  • plt.plot(x, y_pred, color='green', label='Predicted'):绘制模型预测结果的折线图,颜色为绿色,标签为Predicted
  • plt.legend():显示图例,方便区分不同的图形。
  • plt.show():显示绘制好的图形。

四、注意事项

  1. 本代码实现的是一个简单的线性回归神经网络,实际应用中可能需要更复杂的模型结构和优化方法。
  2. 学习率lr和训练轮数epochs是超参数,可能需要根据具体数据和任务进行调整以获得更好的训练效果。
  3. 代码中使用的均方误差损失函数和梯度计算公式是针对线性回归问题的常见选择,但在其他问题中可能需要使用不同的损失函数和梯度计算方法。

完整代码

import numpy as np
import matplotlib.pyplot as plt

# 1. 生成数据集
np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape)  # 添加噪声

# 2. 定义神经网络(线性回归)
class SimpleNN:
    def __init__(self):
        self.w = np.random.randn()  # 权重
        self.b = np.random.randn()  # 偏置

    def forward(self, x):
        return self.w * x + self.b  # 前向传播

    def loss(self, y_true, y_pred):
        return np.mean((y_true - y_pred) **2)  # 均方误差

    def gradient(self, x, y_true, y_pred):
        dw = -2 * np.mean(x * (y_true - y_pred))  # 权重的梯度
        db = -2 * np.mean(y_true - y_pred)       # 偏置的梯度
        return dw, db

    def train(self, x, y, lr=0.01, epochs=1000):
        for epoch in range(epochs):
            y_pred = self.forward(x)
            dw, db = self.gradient(x, y, y_pred)
            self.w -= lr * dw  # 更新权重
            self.b -= lr * db  # 更新偏置
            if (epoch + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')

# 3. 训练模型
model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)

# 4. 可视化结果
y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2326066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【nvidia】Windows 双 A6000 显卡双显示器驱动更新问题修复

问题描述:windows自动更新nvidia驱动会导致只检测得到一个A6000显卡。 解决方法 下载 A6000 驱动 572.83-quadro-rtx-desktop-notebook-win10-win11-64bit-international-dch-whql.exehttps://download.csdn.net/download/qq_18846849/90554276 不要直接安装。如…

《SRv6 网络编程:开启IP网络新时代》第2章、第3章:SRv6基本原理和基础协议

背景 根据工作要求、本人掌握的知识情况,仅针对《SRv6 网络编程:开启IP网络新时代》书籍中涉及的部分知识点进行总结梳理,并与工作小组进行分享,不涉及对原作的逐字搬运。 问题 组内同事提出的问题:本文缺扩展头描述…

如何将AI模型返回的字符串转为html元素?

场景&#xff1a; 接入deepseek模型的api到我们平台&#xff0c;返回的字符串需要做下格式化处理。 返回的数据是这样的&#xff1a; {"role": "assistant","content": "<think>\n嗯&#xff0c;用户问的是“星体是什么”。首先&am…

【PCIE711-214】基于PCIe总线架构的4路HD-SDI/3G-SDI视频图像模拟源

产品概述 PCIE711-214是一款基于PCIE总线架构的4路SDI视频模拟源。该板卡为标准的PCIE插卡&#xff0c;全高尺寸&#xff0c;适合与PCIE总线的工控机或者服务器&#xff0c;板载协议处理器&#xff0c;可以通过PCIE总线将上位机的YUV 422格式视频数据下发通过SDI接口播放出去&…

突破反爬困境:SDK开发,浏览器模块(七)

声明 本文所讨论的内容及技术均纯属学术交流与技术研究目的&#xff0c;旨在探讨和总结互联网数据流动、前后端技术架构及安全防御中的技术演进。文中提及的各类技术手段和策略均仅供技术人员在合法与合规的前提下进行研究、学习与防御测试之用。 作者不支持亦不鼓励任何未经授…

rce操作

Linux命令长度突破限制 源码 <?php $param $_REQUEST[param];if ( strlen($param) < 8 ) {echo shell_exec($param); } echo执行函数&#xff0c;$_REQUEST可以接post、get、cookie传参 源码中对参数长度做了限制&#xff0c;小于8位&#xff0c;可以利用临时函数&…

LabVIEW高效溢流阀测试系统

开发了一种基于LabVIEW软件和PLC硬件的溢流阀测试系统。通过集成神经网络优化的自适应PID控制器&#xff0c;该系统能自动进行压力稳定性、寿命以及动静态性能测试。该设计不仅提升了测试效率&#xff0c;还通过智能化控制提高了数据的精确性和操作的便捷性。 ​ 项目背景&…

DataGear 5.3.0 制作支持导出表格数据的数据可视化看板

DataGear 内置表格图表底层采用的是DataTable表格组件&#xff0c;默认并未引入导出数据的JS支持库&#xff0c;如果有导出表格数据需求&#xff0c;则可以在看板中引入导出相关JS支持库&#xff0c;制作具有导出CSV、Excel、PDF功能的表格数据看板。 在新发布的5.3.0版本中&a…

Web网页内嵌 Adobe Pdf Reader 谷歌Chrome在线预览编辑PDF文档

随着数字化办公的普及&#xff0c;PDF文档已成为信息处理的核心载体&#xff0c;虽然桌面端有很多软件可以实现预览编辑PDF文档&#xff0c;而在线在线预览编辑PDF也日益成为一个难题。 作为网页内嵌本地程序的佼佼者——猿大师中间件&#xff0c;之前发布的猿大师办公助手&am…

Sentinel[超详细讲解]-1

定义一系列 规则 &#x1f47a;&#xff0c;对资源进行 保护 &#x1f47a;&#xff0c; 如果违反的了规则&#xff0c;则抛出异常&#xff0c;看是否有fallback兜底处理&#xff0c;如果没有则直接返回异常信息&#x1f60e; 1. 快速入门 1.1 引入 Sentinel 依赖 <depend…

如何让 SQL2API 进化为 Text2API:自然语言生成 API 的深度解析?

在过去的十年里&#xff0c;技术的进步日新月异&#xff0c;尤其是在自动化、人工智能与自然语言处理&#xff08;NLP&#xff09;方面。 随着“低代码”平台的崛起&#xff0c;开发者和非技术人员能够更轻松地构建强大而复杂的应用程序。然而&#xff0c;尽管技术门槛降低了&…

OCCT(2)Windows平台编译OCCT

文章目录 一、Windows平台编译OCCT1、准备环境2、下载源码3、下载第三方库4、使用 CMake 配置5、编译OCCT源码6、运行示例 一、Windows平台编译OCCT 1、准备环境 安装工具&#xff1a; Visual Studio&#xff08;推荐 VS2019/2022&#xff0c;选择 C 桌面开发 组件&#xff0…

【蓝桥杯—单片机】通信总线专项 | 真题整理、解析与拓展 (更新ing...)

通信总线专项 前言SPI第十五届省赛题 UART/RS485/RS232UARTRS485RS232第十三届省赛题小结和拓展&#xff1a;传输方式的分类第十三届省赛 其他相关考点网络传输速率第十五届省赛题第十二届省赛题 前言 在本文中我会把 蓝桥杯单片机赛道 历年真题 中涉及到通信总线的题目整理出…

Uni-app页面信息与元素影响解析

获取窗口信息uni.getWindowInfo {pixelRatio: 3safeArea:{bottom: 778height: 731left: 0right: 375top: 47width: 375}safeAreaInsets: {top: 47, left: 0, right: 0, bottom: 34},screenHeight: 812,screenTop: 0,screenWidth: 375,statusBarHeight: 47,windowBottom: 0,win…

CentOS(最小化)安装之后,快速搭建Docker环境

本文以VMware虚拟机中安装最小化centos完成后开始。 1. 检查网络 打开网卡/启用网卡 执行命令ip a查看当前的网络连接是否正常&#xff1a; 如果得到的结果和我一样&#xff0c;有ens网卡但是没有ip地址&#xff0c;说明网卡未打开 手动启用&#xff1a; nmcli device sta…

【身份证证件OCR识别】批量OCR识别身份证照片复印件图片里的文字信息保存表格或改名字,基于QT和腾讯云api_ocr的实现方式

项目背景 在许多业务场景中,需要处理大量身份证照片复印件,手动输入其中的文字信息效率低下且容易出错。利用 OCR(光学字符识别)技术可以自动识别身份证图片中的文字信息,结合 QT 构建图形用户界面,方便用户操作,同时使用腾讯 OCR API 能够保证较高的识别准确率。 界面…

IP属地和发作品的地址不一样吗

在当今这个数字化时代&#xff0c;互联网已经成为人们日常生活不可或缺的一部分。随着各大社交平台功能的不断完善&#xff0c;一个新功能——IP属地显示&#xff0c;逐渐走进大众视野。这一功能在微博、抖音、快手等各大平台上得到广泛应用&#xff0c;旨在帮助公众识别虚假信…

Redis - 概述

目录 ​编辑 一、什么是redis 二、redis能做什么&#xff08;有什么特点&#xff09;&#xff1f; 三、redis有什么优势 四、Redis与其他key-value存储有什么不同 五、Redis命令 六、Redis数据结构 1、基础数据结构 2、高级数据结构 一、什么是redis 1、redis&#x…

vue3 根据城市名称计算城市之间的距离

<template><div class"distance-calculator"><h1>城市距离计算器</h1><!-- 城市输入框 --><div class"input-group"><inputv-model"city1"placeholder"请输入第一个城市"keyup.enter"cal…

html 列表循环滚动,动态初始化字段数据

html <div class"layui-row"><div class"layui-col-md4"><div class"boxall"><div class"alltitle">超时菜品排行</div><div class"marquee-container"><div class"scroll-…