【动手学深度学习Pytorch】4. 神经网络基础

news2024/11/23 8:52:41

模型构造

        回顾一下感知机。

nn.Sequential():定义了一种特殊的module。

torch.rand():用于生成具有均匀分布的随机数,这些随机数的范围在[0, 1)之间。它接受一个形状参数(shape),返回一个指定形状的张量(Tensor)。

import torch
from torch import nn
from torch.nn import functional as F
# 定义网络结构
net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
# 生成随机数[0,1),大小为2x20
X = torch.rand(2, 20)
# 输入网络
net(X)

torch.nn.functional.relu():ReLU激活函数

# 自定义MLP模块
class MLP(nn.Module):
    def __init__(self): #初始化类与参数
        super().__init__() #父类Module的init
        self.hidden = nn.Linear(20, 256) #隐藏层
        self.out = nn.Linear(256, 10) #输出层

    def forward(self, X): #进行前向传播
        return self.out(F.relu(self.hidden(X)))

# 实例化MLP类,然后在每次调用前向传播函数的时候调用这些层
net = MLP() # 实例化的时候,()里面的参数是init中的参数
net(X) # 调用的时候,()里面的参数是forward中的参数

.values(): 返回一个字典中所有值

# 自定义顺序块
class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for block in args:
            self._modules[block] = block

    def forward(self, X):
        for block in self._modules.values():
            X = block(X)
        return X

net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 100))
net (X)

 torch.mm:矩阵乘法

# 隐藏层固定的MLP
class FixedHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20, 20), requires_grad = False) #权重固定无需进行训练
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        X = self.linear(X)
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()

net = FixedHiddenMLP()
net(X)
# 混合嵌套MLP
class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),
                                nn.Linear(64, 32), nn.ReLU())
        self.linear = nn.Linear(32, 16)

    def forward(self, X):
        return self.linear(self.net(X))

chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())
chimera(X)

 参数管理

state_dict:一个Python字典,将每一层的参数映射成Tensor张量,只包含具有可学习参数的层,例如卷积层(Conv2d)和全连接层(Linear)。权重和偏置就是Linear层的状态。

import torch
from torch import nn
# 关注具有隐藏层的多感知机
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8,1))
X = torch.rand(size = (2, 4))
net(X)
# 参数访问
print(net[2].state_dict()) #net[2]拿到的是最后的输出层

bias:偏置,包括参数值和梯度值

bias.data:访问参数值

bias.grad:访问梯度值

named_parameters():返回一个包含元组的迭代器,每个元组包含两个元素,参数的名称和参数的值。

# 目标函数
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data) # 参数值
net[2].weight.grad == None # 梯度值

# 一次性访问所有参数
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
# ReLU没有参数所以拿不出
net.state_dict()['2.bias'].data
# 通过名字获取想要的参数

net.add_module(name, module): 添加每一层,并且为每一层增加了一个单独的名字。

# 从嵌套块收集函数
def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}', block1())
    return net

rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)
print(rgnet)

torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torch.no_grad()模式下运行,autograd不会将其考虑在内。

nn.init.normal_(tensor, mean=0.0, std=1.0): 将tensor变成指定mean和std正态分布

torch.nn.init.zeros_(tensor): 将tensor置零

torch.nn.init.constant_(tensor, val): 用val数值填充tensor

# 内置初始化
def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.zeros_(m.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 1)
        nn.init.zeros_(m.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

​ 

 torch.nn.ini.xavier_uniform_(tensor, gain=1.0):

# 对某些块应用不同的初始化方法
def xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

def init_42(m):
    if type(m) == nn.Linear:
        nn. init.constant_(m.weight, 42)

net[0].apply(xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

nn.init.uniform_(tensor, val1, val2): 将tensor变成均匀分布

def my_init(m):
    if type(m) == nn.Linear:
        print(
            "Init",
            *[(name, param.shape) for name, param in m.named_parameters()][0])
        nn.init.uniform_(m.weight, -10, 10)
        m.weight.data *= m.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]
net[0].weight.data[:]+=1
net[0].weight.data[0,0]=42
net[0].weight.data[0]

# 参数绑定
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(), shared, nn.ReLU(), nn.Linear(8, 1))
net(X)
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0,0] = 100
print(net[2].weight.data[0] == net[4].weight.data[0])

自定义层

# 自定义层
import torch
import torch.nn.functional as F
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, X):
        return X - X.mean()

layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
Y.mean()
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)

dense = MyLinear(5, 3)
dense.weight

# 使用自定义层直接执行正向传播计算
dense(torch.rand(2,5))
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

读写文件

import torch
from torch import nn
from torch.nn import funcational as F
# 加载和保存张量
x = torch.arange(4)
torch.save(x, 'x-file')

x2 = torch.load("x-file")
# 存储一个张量列表,读回内存
y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
# 写入或读取字符串映射到张量的字典
mydict = {'x':x, 'y':y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
# 加载和保存模型参数
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
torch.save(net.state_dict(), 'mlp.params')
# 实例化了原始多层感知机模型的一个备份
clone = MLP()
clone.load_state_dict(torch.load("mlp.params"))
clone.eval()
Y_clone = clone(X)
Y_clone == Y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

新版自助笔记-工作记录-2024-11-18

环境&#xff1a; Windows11 .Net 4.5.2 Vs20151.Web <sKey>平台上获取的通讯码</sKey> Web -> 设置 -> 系统设置 -> 通讯密钥<SoftKey>设备身份识别码</SoftKey> Web -> 终端设备管理 -> 身份识别码<ZZUrl>Web服务</ZZUr…

【Linux课程学习】:进程程序替换,execl,execv,execlp,execvp,execve,execle,execvpe函数

&#x1f381;个人主页&#xff1a;我们的五年 &#x1f50d;系列专栏&#xff1a;Linux课程学习 &#x1f337;追光的人&#xff0c;终会万丈光芒 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 ​ ​ 目录 替换原理&#xff1a; 替换函数&…

Bug:引入Feign后触发了2次、4次ContextRefreshedEvent

Bug&#xff1a;引入Feign后发现监控onApplication中ContextRefreshedEvent事件触发了2次或者4次。 【原理】在Spring的文档注释中提示到&#xff1a; Event raised when an {code ApplicationContext} gets initialized or refreshed.即当 ApplicationContext 进行初始化或者刷…

【智谱清言-注册_登录安全分析报告】

前言 由于网站注册入口容易被机器执行自动化程序攻击&#xff0c;存在如下风险&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露&#xff0c;不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 &#xff0c;造成用户无法登陆、注册&#xff0c;大量收到垃圾短信的…

煤炉Mercari新手开店十问十答

在跨境电商的浪潮中&#xff0c;Mercari&#xff08;煤炉&#xff09;作为一个备受瞩目的C2C二手商品交易平台&#xff0c;吸引了众多新手卖家的目光。然而&#xff0c;初次在Mercari开店可能会遇到各种问题和挑战。为此&#xff0c;我特别整理了2024年最新的十问十答指南&…

[面试]-golang基础面试题总结

文章目录 panic 和 recover**注意事项**使用 pprof、trace 和 race 进行性能调试。**Go Module**&#xff1a;Go中new和make的区别 Channel什么是 Channel 的方向性&#xff1f;如何对 Channel 进行方向限制&#xff1f;Channel 的缓冲区大小对于 Channel 和 Goroutine 的通信有…

【Flask+Gunicorn+Nginx】部署目标检测模型API完整解决方案

【Ubuntu 22.04FlaskGunicornNginx】部署目标检测模型API完整解决方案 文章目录 1. 搭建深度学习环境1.1 下载Anaconda1.2 打包环境1.3 创建虚拟环境1.4 报错 2. 安装flask3. 安装gunicorn4. 安装Nginx4.1 安装前置依赖4.2 安装nginx4.3 常用命令 5. NginxGunicornFlask5.1 ng…

一个用纯PHP开发的服务器-workerman

什么是Workerman 简单的说Workerman是一个纯php开发的服务器。 workerman的目标是让PHP开发者更容易的开发出基于socket的高性能的应用服务&#xff0c;而不用去了解PHP socket以及PHP多进程细节。 workerman本身是一个PHP多进程服务器&#xff0c;类似nginxphp-fpm的结合体&am…

如何在Linux上安装Canal同步工具

1. 下载安装包 所用到的安装包 canal.admin-1.1.4.tar.gz 链接&#xff1a;https://pan.baidu.com/s/1B1LxZUZsKVaHvoSx6VV3sA 提取码&#xff1a;v7ta canal.deployer-1.1.4.tar.gz 链接&#xff1a;https://pan.baidu.com/s/13RSqPinzgaaYQUyo9D8ZCQ 提取码&#xff1a;…

Leetcode 组合

使用回溯来解决此问题。 提供的代码使用了回溯法&#xff08;Backtracking&#xff09;&#xff0c;这是一种通过递归探索所有可能解的算法思想。以下是对算法思想的详细解释&#xff1a; 核心思想&#xff1a; 回溯法通过以下步骤解决问题&#xff1a; 路径选择&#xff1a…

PyTorch使用教程-深度学习框架

PyTorch使用教程-深度学习框架 1. PyTorch简介 1.1-什么是PyTorch ​ PyTorch是一个广泛使用的开源机器学习框架&#xff0c;特别适合深度学习的应用。它以其动态计算图而闻名&#xff0c;允许在运行时修改模型&#xff0c;使得实验和调试更加灵活。PyTorch提供了强大的GPU加…

HTML的自动定义倒计时,这个配色存一下

<!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>自定义倒计时</title><style>* {mar…

Spark SQL 之 QueryStage

ExchangeQueryStageExec ExchangeQueryStageExec 分为两种

自由学习记录(23)

Lua的学习 table.concat(tb,";") 如果表里带表&#xff0c;则不能拼接&#xff0c;表里带nil也不能&#xff0c;都会报错 true和false也不可以&#xff0c;数字和字符串可以 if要和一个end配对&#xff0c;所以 if a>b then return true end end 两个end …

GoZero对接GPT接口的设计与实现:问题分析与解决

在本篇文章中&#xff0c;我们将探讨如何在GoZero框架下对接GPT接口&#xff0c;并详细讨论在实现过程中遇到的一些常见问题及其解决方案。特别是遇到的错误信息&#xff0c;如 parse parameter fail,recover: interface conversion: interface {} is nil, not string 和 获取历…

【2024 Optimal Control 16-745】【Lecture 3 + Lecture4】minimization.ipynb功能分析

主要功能-最小化问题 目标函数分析: 定义函数 f ( x ) f(x) f(x) 及其一阶、二阶导数。使用绘图工具可视化函数的形状。 实现数值优化: 使用牛顿法寻找函数的极值点&#xff0c;结合一阶和二阶导数加速收敛。使用正则化牛顿法解决二阶导数矩阵可能不正定的问题。 可视化过程…

实现 UniApp 右上角按钮“扫一扫”功能实战教学

实现 UniApp 右上角按钮“扫一扫”功能实战教学 需求 点击右上角扫一扫按钮(onNavigationBarButtonTap监听)&#xff0c;打开扫一扫页面(uni.scanCode) 扫描后&#xff0c;以网页的形式打开扫描内容(web-view组件)&#xff0c;限制只能浏览带有执行域名的网站&#xff0c;例如…

ThreadLocal 父子线程、线程池、数据传递

讲一下背景&#xff1a;springboot 项目。写了个拦截器&#xff0c;解析请求头 Authorization 中传过来的 token&#xff0c;获取到登录用户信息&#xff0c;然后通过 ThreadLocal 存起来&#xff0c;后面的业务代码从 ThreadLocal 取用户信息。 再说下问题&#xff1a;当业务代…

uniapp 微信小程序地图标记点、聚合点/根据缩放重合点,根据缩放登记显示气泡marik标点

如图&#xff0c;如果要实现上方的效果&#xff1a; 上方两个效果根据经纬度标记点缩放后有重复点会添加数量 用到的文档地址https://developers.weixin.qq.com/miniprogram/dev/api/media/map/MapContext.addMarkers.htmlMapContext.addMarkers(Object object) 添加标记点Ma…

ubuntu下如何使用C语言访问Mysql数据库(详细介绍并附有案例)

一、配置 首先&#xff0c;确保你已经安装了MySQL服务器和MySQL Connector/C库。在Linux上&#xff0c;你可以使用包管理器来安装这些&#xff0c;例如&#xff1a; sudo apt-get install mysql-server libmysqlclient-dev 在ubuntu的机器上&#xff0c;库文件通常保存在 /li…