pytorch搭建神经网络(手搓方法)

news2025/1/10 10:20:36

假如我们有一个数据集形状为(348,14)。即有348个记录,每个记录有14个特征值。

我们想要搭建一个如下的神经网络:

import torch
import numpy as np

# 创建数据集: 每个样本有14个特征
x_train = np.array([
    [0.5, -1.2, 0.3, 0.8, 1.0, -0.5, 2.3, 1.2, -0.3, 1.5, -1.1, 0.6, -0.8, 0.7],
    [1.5,  2.2, 1.3, -0.7, 1.1,  0.5, -1.3, 0.4,  1.2, 0.8,  0.3, 0.6,  2.1, 0.2],
    [0.9, -0.2, -0.5, -1.2, 1.3, -1.1, 0.7,  1.5,  0.9, 1.0, -0.4, 0.5, -1.0, 1.4],
    [-0.4, 0.8, 1.2, -0.1, 1.5, 0.2, 0.6, -1.3, 1.0, 1.3, 0.3, -0.9, 1.1, 0.5],
    [1.0, 0.2, -1.4, 0.3, -0.7, 1.1, -0.1, 0.5, 0.6, 1.5, 0.7, -0.5, 0.9, -0.2]
], dtype=np.float64)

y_train = np.array([[5.0], [6.0], [4.0], [7.0], [3.0]], dtype=np.float64)

# 将数据转化为张量
x = torch.tensor(x_train, dtype=torch.float64)
y = torch.tensor(y_train, dtype=torch.float64)

# 权重参数初始化
weights1 = torch.randn((14, 128), dtype=torch.float64, requires_grad=True)  # 输入维度是14, 第一层有128个神经元
biases1 = torch.randn((1, 128), dtype=torch.float64, requires_grad=True)
weights2 = torch.randn((128, 1), dtype=torch.float64, requires_grad=True)
biases2 = torch.randn((1, 1), dtype=torch.float64, requires_grad=True)

learning_rate = 0.001  # 学习率
losses = []  # 用于存储损失值

for i in range(1000):  # 这里遍历1000次
    net1 = x.mm(weights1) + biases1  # 如果有5条记录那么结构为[5,128]
    out1 = torch.relu(net1)  # 通过激活函数[5,128]

    predictions = out1.mm(weights2) + biases2  # 输出,预测值[5,1]

    # 计算损失
    loss = torch.mean((predictions - y) ** 2)
    losses.append(loss.detach().numpy())

    if i % 100 == 0:
        print("loss:", loss.item())

    # 反向传播计算
    loss.backward()

    # 更新权重
    with torch.no_grad():  # 使用no_grad避免梯度跟踪
        weights1 -= learning_rate * weights1.grad
        biases1 -= learning_rate * biases1.grad
        weights2 -= learning_rate * weights2.grad
        biases2 -= learning_rate * biases2.grad

    # 每次迭代清空梯度累加值
    weights1.grad.zero_()
    biases1.grad.zero_()
    weights2.grad.zero_()
    biases2.grad.zero_()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2185596.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在Ubuntu 20.04中安装CARLA

0. 引言 CARLA (Car Learning to Act) 是一款开源自动驾驶模拟器,其支持自动驾驶系统全管线的开发、训练和验证(Development, Training, and Validation of autonomous driving systems)。Carla提供了丰富的数字资产,例如城市布局…

前端编程艺术(2)----CSS

目录 1.CSS 2.CSS引入 3.选择器 1.标签选择器 2.类选择器 3.id选择器 4.属性选择器 5.后代选择器 5.直接子元素选择器 6.伪类选择器 链接相关 动态伪类 结构化伪类 否定伪类 其他伪类 UI元素状态伪类 4.字体 1.font-family 2.font-size 3.font-style 4.fo…

Linux查找隐藏病毒进程

工具连接 下载工具不要分,随便下 下载后修改工具名:如修改为lsof、ps、top等并为工具加入执行权限 2、 直接执行即可,与正常命令用法一致(截图如下)

足球预测推荐软件:百万数据阐述百年足球历史-大数据模型量化球员成就值

我开始创建这个模型是从梅西22世界杯夺冠第二天开始准备的,当时互联网上充斥了太多了个人情感的输出,有的人借题对C罗冷嘲热讽,有的人质疑梅西的阿根廷被安排夺冠不配超越马拉多纳做GOAT。作为一个从2002年开始看球的球迷,说实话有…

linux自用小手册

一、GDB常用命令 想用gdb调试C或C程序,编译时需要加-g选项,编译出的文件为debug状态(如果不加则是release状态),且不可以加-O选项进行优化。 命令简写解释set args 设置程序传递的参数 例:./demo -v value…

【MySQL报错】---Data truncated for column ‘age‘ at row...

目录 一、前言二、问题分析三、解决办法 一、前言 欢迎大家来到权权的博客~欢迎大家对我的博客进行指导,有什么不对的地方,我会及时改进哦~ 博客主页链接点这里–>:权权的博客主页链接 二、问题分析 问题一修改表结构 XXX 为 not n…

指针 (5)

目录 1. 字符指针变量 2. 数组指针变量 3. ⼆维数组传参的本质 4. 函数指针变量 5.typedef 关键字 6 函数指针数组 7.转移表 计算器的⼀般实现 1. 字符指针变量 在指针的类型中我们知道有⼀种指针类型为字符指针 char* #include <stdio.h> int main() {char* ch …

ARM assembly: Lesson 10

今天&#xff0c;我们来看一下基于ARM汇编&#xff0c;如何实现函数的调用。 基础知识 在ARM汇编中&#xff0c;函数的前四个参数存放于 R0~R3寄存器中, 剩余的参数存放于栈中&#xff0c;返回值存放于r0。在栈中存放数值&#xff0c;可以避免在调用过程中&#xff0c;数据的…

记一次炉石传说记牌器 Crash 排查经历

大家好这里是 Geek技术前线。最近在打炉石过程中遇到了HSTracker记牌器的一个闪退问题&#xff0c;尝试性排查了下原因。这里简单记录一下 最近炉石国服回归&#xff1b;由于设备限制&#xff0c;我基本只会在 Mac 上打炉石。并且由于主要打竞技场&#xff0c;所以记牌器是必不…

解决问题AttributeError: “safe_load“ has been removed, use

解决问题AttributeError: "safe_load" has been removed, use~ 1. 问题描述2. 解决方法 1. 问题描述 在复现cdvae代码时&#xff0c;运行 python scripts/compute_metrics.py --root_path MODEL_PATH --tasks recon gen opt评估模型时&#xff0c;出现以下问题。 …

Pikachu-Cross-Site Scripting-xss盲打

xss盲打&#xff0c;不是一种漏洞类型&#xff0c;而是一个攻击场景&#xff1b;在前端、或者在当前页面是看不到攻击结果&#xff1b;而是在后端、在别的页面才看到结果。 登陆后台&#xff0c;查看结果&#xff1b;

Custom C++ and CUDA Extensions - PyTorch

0. Abstract 经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容: LibTorch 初步;C Extension 例子; 1. LibTorch …

python-鸡尾酒疗法/图像相似度/第n小的质数

一&#xff1a;鸡尾酒疗法 题目描述 鸡尾酒疗法&#xff0c;原指“高效抗逆转录病毒治疗”&#xff08;HAART&#xff09;&#xff0c;由美籍华裔科学家何大一于 1996 年提出&#xff0c;是通过三种或三种以上的抗病毒药物联合使用来治疗艾滋病。该疗法的应用可以减少单一用药产…

什么是ETL?什么是ELT?怎么区分它们使用场景

ELT和ETL这两种模式从字面上来看就是一个顺序颠倒的问题&#xff0c;每个单词拆开来看其实都是一样的。E代表的是Extract&#xff08;抽取&#xff09;&#xff0c;也就是从源端拉取数据&#xff1b;T代表的是Transform&#xff08;转换&#xff09;&#xff0c;对一些结构化或…

Visual Studio2017编译GDAL3.0.2源码过程

一、编译环境 操作系统&#xff1a;Windows 10企业版 编译工具&#xff1a;Visual Studio 2017旗舰版 源码版本&#xff1a;gdal3.0.2 二、生成解决方案 打开Visual Studio 2017的x64本机生成工具&#xff0c;切换到gdal3.0.2源码根目录&#xff1b;执行generate_vcxproj.b…

D25【 python 接口自动化学习】- python 基础之判断与循环

day25 for 循环 学习日期&#xff1a;20241002 学习目标&#xff1a;判断与循环&#xfe63;-35 for 循环&#xff1a;如何遍历一个对象里的所有元素&#xff1f; 学习笔记&#xff1a; for 循环与while循环的区别 for循环的定义 使用for循环遍历序列 使用for循环遍历字典…

【理论科学与实践技术】数学与经济管理中的学科与实用算法

在现代商业环境中&#xff0c;数学与经济管理的结合为企业提供了强大的决策支持。包含一些主要学科&#xff0c;包括数学基础、经济学模型、管理学及风险管理&#xff0c;相关的实用算法和这些算法在中国及全球知名企业中的实际应用。 一、数学基础 1). 发现人及著名学者 发…

目标检测评价指标

混淆矩阵&#xff08;Confusion Matrix&#xff09; 准确率&#xff08;accuracy&#xff09; 准确率&#xff1a;预测正确的样本数 / 样本数总数 &#xff08;正对角线 / 所有&#xff09; 精度&#xff08;precision&#xff09; 精度&#xff1a;预测正确里面有多少确实是…

深入理解MySQL中的MVCC原理及实现

目录 什么是MVCC&#xff1f; MVCC实现原理 Undo Log 日志 InnoDB行格式 undo日志格式 1. insert undo log格式 2. update undo log格式 事务回滚机制 Read View MVCC案例分析 案例01-读已提交RC隔离级别下的可见性分析 案例02-可重复读RR隔离级别下的可见性分析 什…

英语词汇小程序小程序|英语词汇小程序系统|基于java的四六级词汇小程序设计与实现(源码+数据库+文档)

英语词汇小程序 目录 基于java的四六级词汇小程序设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕设布道师&a…