使用Pytorch构建自定义层并在模型中使用

news2024/11/28 19:01:50

使用Pytorch构建自定义层并在模型中使用

继承自nn.Module类,自定义名称为NoisyLinear的线性层,并在新模型定义过程中使用该自定义层。完整代码可以在jupyter nbviewer中在线访问。

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mlxtend.plotting import plot_decision_regions
print(torch.__version__)
print(np.__version__)
2.0.1+cu118
1.24.4
创建一个包含有噪声的线性层
class NoisyLinear(nn.Module):
    def __init__(self, input_size, output_size, noise_stddev=0.1):
        super().__init__()
        w = torch.Tensor(input_size, output_size)
        self.w = nn.Parameter(w)
        nn.init.xavier_uniform_(self.w)
        b = torch.Tensor(output_size).fill_(0)
        self.b = nn.Parameter(b)
        self.noise_stddev = noise_stddev

    def forward(self, x, training=False):
        if training:
            noise = torch.normal(0.0, self.noise_stddev, x.shape)
            x_new = torch.add(x, noise)
        else:
            x_new = x
        return torch.add(torch.mm(x_new, self.w), self.b)

这段代码定义了一个名为 NoisyLinear 的类,它继承自 nn.Module,表示一个包含噪声的线性层。

class NoisyLinear(nn.Module):

定义一个名为 NoisyLinear 的类,它继承自 PyTorch 的 nn.Module 类。这意味着它可以被用作一种神经网络层。

    def __init__(self, input_size, output_size, noise_stddev=0.1):

初始化方法 __init__ 接受三个参数:输入大小 input_size,输出大小 output_size,以及噪声的标准差 noise_stddev(默认值为 0.1)。

        super().__init__()

调用父类 nn.Module 的初始化方法,以确保父类的相关属性和方法被正确初始化。

        w = torch.Tensor(input_size, output_size)

创建一个形状为 (input_size, output_size) 的张量 w,用于存储权重。

        self.w = nn.Parameter(w)

将权重 w 包装为 nn.Parameter,这意味着在训练过程中,PyTorch 会自动将其视为可学习参数。

        nn.init.xavier_uniform_(self.w)

使用 Xavier 均匀分布对权重 self.w 进行初始化。这是一种常用的初始化方法,有助于保持神经网络中信号的方差。

        b = torch.Tensor(output_size).fill_(0)

创建一个形状为 (output_size,) 的张量 b,并将其填充为 0,用于存储偏置。

        self.b = nn.Parameter(b)

将偏置 b 包装为 nn.Parameter,使其在训练过程中也是可学习的。

        self.noise_stddev = noise_stddev

将噪声的标准差 noise_stddev 存储为类的一个属性,用于后续的噪声计算。

    def forward(self, x, training=False):

定义前向传播方法 forward,接受输入 x 和一个布尔参数 training,指示当前是否在训练模式下。

        if training:

检查当前是否处于训练模式。

            noise = torch.normal(0.0, self.noise_stddev, x.shape)

如果是训练模式,则创建一个与输入 x 形状相同的噪声张量 noise,其服从均值为 0、标准差为 self.noise_stddev 的正态分布。

            x_new = torch.add(x, noise)

将噪声添加到输入 x 上,得到新的输入 x_new

        else:

如果不是训练模式,则执行以下代码。

            x_new = x

在非训练模式下,x_new 直接设置为输入 x,即没有添加噪声。

        return torch.add(torch.mm(x_new, self.w), self.b)

计算输出:首先用 torch.mm 进行矩阵乘法(x_new 和权重 self.w),然后将偏置 self.b 添加到结果中。最后返回计算出的输出。

总结来说,这个类实现了一个带噪声的线性变换,在线性层中可以根据训练模式选择性地添加噪声。

# 上述层的使用示例.
# 1、实例化这个层,并调用三次.
torch.manual_seed(1)

noisy_layer = NoisyLinear(4, 2)
x = torch.zeros((1, 4))
print(noisy_layer(x, training=True))

print(noisy_layer(x, training=True))

print(noisy_layer(x, training=False))
tensor([[ 0.1154, -0.0598]], grad_fn=<AddBackward0>)
tensor([[ 0.0432, -0.0375]], grad_fn=<AddBackward0>)
tensor([[0., 0.]], grad_fn=<AddBackward0>)
在一个示例数据上,构建一个包含该自定义层的模型
# 生成一个示例数据.
np.random.seed(1)
torch.manual_seed(1)
x = np.random.uniform(low=-1, high=1, size=(200, 2))
y = np.ones(len(x))
y[x[:, 0] * x[:, 1]<0] = 0

n_train = 100
x_train = torch.tensor(x[:n_train, :], dtype=torch.float32)
y_train = torch.tensor(y[:n_train], dtype=torch.float32)
x_valid = torch.tensor(x[n_train:, :], dtype=torch.float32)
y_valid = torch.tensor(y[n_train:], dtype=torch.float32)

fig = plt.figure(figsize=(6, 6))
plt.plot(x[y==0, 0], 
         x[y==0, 1], 'o', alpha=0.75, markersize=10)
plt.plot(x[y==1, 0], 
         x[y==1, 1], '<', alpha=0.75, markersize=10)
plt.xlabel(r'$x_1$', size=15)
plt.ylabel(r'$x_2$', size=15)
plt.tight_layout()
plt.show()

在这里插入图片描述

# 创建一个DataLoader.
train_ds = TensorDataset(x_train, y_train)
batch_size = 2
torch.manual_seed(1)

# 使用DataLoader加载数据,batchsize为2.
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
# 创建一个新的模型,并且调用上述的自定义层.
class MyNoiseModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = NoisyLinear(2, 4, 0.07)
        self.a1 = nn.ReLU()
        self.l2 = nn.Linear(4, 4)
        self.a2 = nn.ReLU()
        self.l3 = nn.Linear(4, 1)
        self.a3 = nn.Sigmoid()

    def forward(self, x, training=False):
        x = self.l1(x, training)
        x = self.a1(x)
        x = self.l2(x)
        x = self.a2(x)
        x = self.l3(x)
        x = self.a3(x)
        return x
    
    def predict(self, x):
        self.eval()
        with torch.no_grad():
            x = torch.tensor(x, dtype=torch.float32)
            pred = self.forward(x)[:, 0]
            return (pred>=0.5).float()
# 模型实例化.
torch.manual_seed(1)
model = MyNoiseModule()
model
MyNoiseModule(
  (l1): NoisyLinear()
  (a1): ReLU()
  (l2): Linear(in_features=4, out_features=4, bias=True)
  (a2): ReLU()
  (l3): Linear(in_features=4, out_features=1, bias=True)
  (a3): Sigmoid()
)
# 3.在训练training batch上计算预测结果.
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.015)
# 模型训练,设置epochs=200
torch.manual_seed(1)
num_epochs = 200

def train(model, num_epochs, train_dl, x_valid, y_valid):
    loss_hist_train = [0] * num_epochs
    acc_hist_train = [0] * num_epochs

    loss_hist_valid = [0] * num_epochs
    acc_hist_valid = [0] * num_epochs

    for epoch in range(num_epochs):
        for x_batch, y_batch in train_dl:
            pred = model(x_batch, True)[:, 0]
            loss = loss_fn(pred, y_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_hist_train[epoch] += loss.item()
            is_correct = ((pred>=0.5).float() == y_batch).float()
            acc_hist_train[epoch] += is_correct.mean()
        loss_hist_train[epoch] /= n_train/batch_size
        acc_hist_train[epoch] /= n_train/batch_size

        pred = model(x_valid)[:, 0]
        loss = loss_fn(pred, y_valid)
        loss_hist_valid[epoch] = loss.item()
        is_correct = ((pred>=0.5).float() == y_valid).float()
        acc_hist_valid[epoch] += is_correct.mean()
    return loss_hist_train, loss_hist_valid, \
            acc_hist_train, acc_hist_valid

history = train(model, num_epochs, train_dl, x_valid, y_valid)
# 绘制决策边界.
fig = plt.figure(figsize=(16, 4))
ax = fig.add_subplot(1, 3, 1)
plt.plot(history[0], lw=4)
plt.plot(history[1], lw=4)
plt.legend(['Train loss', 'Validation loss'], fontsize=15)
ax.set_xlabel('Epochs', size=15)

ax = fig.add_subplot(1, 3, 2)
plt.plot(history[2], lw=4)
plt.plot(history[3], lw=4)
plt.legend(['Train acc.', 'Validation acc.'], fontsize=15)
ax.set_xlabel('Epochs', size=15)

ax = fig.add_subplot(1, 3, 3)
plot_decision_regions(X=x_valid.numpy(), 
                      y=y_valid.numpy().astype(np.int64),
                      clf=model)
ax.set_xlabel(r'$x_1$', size=15)
ax.xaxis.set_label_coords(1, -0.025)
ax.set_ylabel(r'$x_2$', size=15)
ax.yaxis.set_label_coords(-0.025, 1)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LSM6DSV16X基于MLC智能笔动作识别(2)----MLC数据采集

LSM6DSV16X基于MLC智能笔动作识别.2--MLC数据采集 概述视频教学样品申请源码下载输出速率执行流程速率设置量程设置检测状态数据单位采集数据静止(Steady)闲置(Idle)书写(Writing)其他(other) 概述 MLC 是“机器学习核心”&#xff08;Machine Learning Core&#xff09;的缩写…

全球购的智能引擎:AI与RPA如何重塑跨境电商帝国?

在全球化的大潮中&#xff0c;跨境电商已成为连接世界的桥梁。随着人工智能&#xff08;AI&#xff09;和机器人流程自动化&#xff08;RPA&#xff09;技术的飞速发展&#xff0c;跨境电商领域的运作模式正在经历一场革命性的变革。 一、跨境电商的挑战 随着互联网技术的普及…

D3.js中国地图可视化

1、项目介绍 该项目来自Github&#xff0c;基于D3.js中国地图可视化。 D3.js is a JavaScript library for manipulating documents based on data. It uses HTML, SVG, and CSS to display data. The full name of D3 is "Data-Driven Documents," which means it a…

在VirtualBox中安装OpenEuler操作系统保姆级教程

前言 OpenEuler是一个由中国华为公司主导开发和维护的开源操作系统项目&#xff0c;旨在打造一个开放、可信且可扩展的企业级操作系统&#xff0c;适用于多种应用场景。 该项目致力于通过开放和协作的方式推动操作系统的创新与发展。OpenEuler采用开源软件模型&#xff0c;允…

多模态:Florence2论文详解

文章目录 前言一、介绍二、方法1.模型结构1&#xff09;Vision encoder2&#xff09;Multi-modality encoder decoder3&#xff09;Optimization objective 2.数据工程1&#xff09;Image Collection2&#xff09;Data Annotation3&#xff09;Data filtering and enhancement4…

spring学习日记-day8-声明式事务

一、学习目标 声明式事务是Spring框架提供的一种事务管理方式&#xff0c;其主要特点是通过声明&#xff08;而非编程&#xff09;的方式来处理事务。这种方式让事务管理不侵入业务逻辑代码&#xff0c;从而提高了代码的可维护性和可读性。 定义&#xff1a;声明式事务…

[3.4]【机器人运动学MATLAB实战分析】PUMA560机器人逆运动学MATLAB计算

PUMA560是六自由度关节型机器人,其6个关节都是转动副,属于6R型操作臂。各连杆坐标系如图1,连杆参数如表1所示。 图1 PUMA560机器人的各连杆坐标系 表1 PUMA560机器人的连杆参数 用代数法对其进行运动学反解。具体步骤如下: 1、求θ1 PMUMA56

【数据结构笔记13】

408数据结构答题规范 原视频 视频参考&#xff0c;以下为视频的笔记 需要写的部分 如果题目要求了函数名、参数列表、返回值类型就按题目的来 函数的类型可以是返回值类型或者void类型&#xff0c;如果函数名不清楚里面的功能是什么&#xff0c;在函数title下面最好写一行注…

磁盘存储和文件系统管理【1.9】

磁盘存储和文件系统管理【1.9】 12、磁盘存储和文件系统12.1.管理存储12.1.1.新加10G硬盘并识别12.1.2.备份查看MBR分区表二进制信息12.1.3.删除破坏分区表12.1.4.恢复MBR分区表12.1.5.完整步骤12.1.6.fdisk分区12.1.7.gdisk分区12.2.文件系统12.2.1.查看支持的文件系统格式12.…

音视频入门基础:FLV专题(11)——FFmpeg源码中,解析SCRIPTDATASTRING类型的ScriptDataValue的实现

一、引言 从《音视频入门基础&#xff1a;FLV专题&#xff08;9&#xff09;——Script Tag简介》中可以知道&#xff0c;根据《video_file_format_spec_v10_1.pdf》第80到81页&#xff0c;SCRIPTDATAVALUE类型由一个8位&#xff08;1字节&#xff09;的Type和一个ScriptDataV…

Java类的生命周期-连接阶段

Java类的生命周期-连接阶段 上篇讲述了类的加载阶段&#xff0c;通过类加载器读取字节码文件后在方法区与堆区生成对应的存放类信息的对象&#xff0c;本篇将讲解他的下一阶段-连接阶段 上篇说到类加载的五大阶段&#xff1a; #mermaid-svg-6YmaEnIO4rCKbIZg {font-family:&quo…

Cpp::STL—vector类的模拟实现(11)

文章目录 前言一、各函数接口总览二、默认成员函数vector();vector(size_t n, const T& val T( ));template< class InputIterator> vector(InputIterator first, InputIterator last);vector(const vector<T>& v);vector<T>& operator(const v…

腾讯云SDK基本概念

本文旨在介绍您在使用音视频终端 SDK&#xff08;腾讯云视立方&#xff09;产品过程中可能会涉及到的基本概念。 音视频终端 SDK&#xff08;腾讯云视立方&#xff09; 应用 音视频终端 SDK&#xff08;腾讯云视立方&#xff09;通过应用的形式来管理您的项目&#xff08;Ap…

C/C++进阶(一)--内存管理

更多精彩内容..... &#x1f389;❤️播主の主页✨&#x1f618; Stark、-CSDN博客 本文所在专栏&#xff1a; 学习专栏C语言_Stark、的博客-CSDN博客 其它专栏&#xff1a; 数据结构与算法_Stark、的博客-CSDN博客 ​​​​​​项目实战C系列_Stark、的博客-CSDN博客 座右铭&a…

免费录屏软件工具:助力高效屏幕录制

录屏已经成为了一项非常实用且广泛应用的技术。无论是制作教学视频、记录游戏精彩瞬间&#xff0c;还是进行软件操作演示等&#xff0c;我们都常常需要一款可靠的录屏软件。今天&#xff0c;就让我们一起来探索那些功能强大录屏软件免费版&#xff0c;看看它们是如何满足我们多…

ARTS Week 42

Algorithm 本周的算法题为 2283. 判断一个数的数字计数是否等于数位的值 给你一个下标从 0 开始长度为 n 的字符串 num &#xff0c;它只包含数字。 如果对于 每个 0 < i < n 的下标 i &#xff0c;都满足数位 i 在 num 中出现了 num[i]次&#xff0c;那么请你返回 true …

【数据结构强化】应用题打卡

应用题打卡 数组的应用 对称矩阵的压缩存储 注意&#xff1a; 1. 2.上三角的行优先存储及下三角的列优先存储与数组的下表对应 上/下三角矩阵的压缩存储 注意&#xff1a; 上/下三角压缩存储是将0元素统一压缩存储&#xff0c;而不是将对角线元素统一压缩存储 三对角矩阵的…

接口隔离原则在前端的应用

什么是接口隔离 接口隔离原则&#xff08;ISP&#xff09;是面向对象编程中的SOLID原则之一&#xff0c;它专注于设计接口。强调在设计接口时&#xff0c;应该确保一个类不必实现它不需要的方法。换句话说&#xff0c;接口应该尽可能地小&#xff0c;只包含一个类需要的方法&am…

SKD4(note上)

微软提供了图形的界面API&#xff0c;叫GDI 如果你想画某个窗口&#xff0c;你必须拿到此窗口的HDC #include <windows.h> #include<tchar.h> #include <stdio.h> #include <strsafe.h> #include <string>/*鼠标消息 * 键盘消息 * Onkeydown * …

实验 3 存储器实验

实验 3 存储器实验 1、实验目的 掌握静态随机存储器 RAM 的工作特性。掌握静态随机存储器 RAM 的读写方法。 2、实验要求 (1)做好实验预习&#xff0c;熟悉MEMORY6116 芯片各引脚的功能和连接方式&#xff0c;熟悉其他实验元器件的功能特性和使用方法&#xff0c;看懂电路图…