Python(PyTorch)物理变化可微分神经算法

news2024/9/20 13:07:15

🎯要点

🎯使用受控物理变换序列实现可训练分层物理计算 | 🎯多模机械振荡、非线性电子振荡器和光学二次谐波生成神经算法验证 | 🎯训练输入数据,物理系统变换产生输出和可微分数字模型估计损失的梯度 | 🎯多模振荡对输入数据进行可控卷积 | 🎯物理神经算法数学表示、可微分数学模型 | 🎯MNIST和元音数据集评估算法

🍪语言内容分比

在这里插入图片描述
在这里插入图片描述

🍇PyTorch可微分优化

假设张量 x x x是元参数, a a a是普通参数(例如网络参数)。我们有内部损失 L in  = a 0 ⋅ x 2 L ^{\text {in }}=a_0 \cdot x^2 Lin =a0x2 并且我们使用梯度 ∂ L in  ∂ a 0 = x 2 \frac{\partial L ^{\text {in }}}{\partial a_0}=x^2 a0Lin =x2 更新 a a a a 1 = a 0 − η ∂ L in  ∂ a 0 = a 0 − η x 2 a_1=a_0-\eta \frac{\partial L ^{\text {in }}}{\partial a_0}=a_0-\eta x^2 a1=a0ηa0Lin =a0ηx2。然后我们计算外部损失 L out  = a 1 ⋅ x 2 L ^{\text {out }}=a_1 \cdot x^2 Lout =a1x2。因此外部损失到 x x x 的梯度为:
∂ L out  ∂ x = ∂ ( a 1 ⋅ x 2 ) ∂ x = ∂ a 1 ∂ x ⋅ x 2 + a 1 ⋅ ∂ ( x 2 ) ∂ x = ∂ ( a 0 − η x 2 ) ∂ x ⋅ x 2 + ( a 0 − η x 2 ) ⋅ 2 x = ( − η ⋅ 2 x ) ⋅ x 2 + ( a 0 − η x 2 ) ⋅ 2 x = − 4 η x 3 + 2 a 0 x \begin{aligned} \frac{\partial L ^{\text {out }}}{\partial x} & =\frac{\partial\left(a_1 \cdot x^2\right)}{\partial x} \\ & =\frac{\partial a_1}{\partial x} \cdot x^2+a_1 \cdot \frac{\partial\left(x^2\right)}{\partial x} \\ & =\frac{\partial\left(a_0-\eta x^2\right)}{\partial x} \cdot x^2+\left(a_0-\eta x^2\right) \cdot 2 x \\ & =(-\eta \cdot 2 x) \cdot x^2+\left(a_0-\eta x^2\right) \cdot 2 x \\ & =-4 \eta x^3+2 a_0 x \end{aligned} xLout =x(a1x2)=xa1x2+a1x(x2)=x(a0ηx2)x2+(a0ηx2)2x=(η2x)x2+(a0ηx2)2x=4ηx3+2a0x
鉴于上述分析解,让我们使用 TorchOpt 中的 MetaOptimizer 对其进行验证。MetaOptimizer 是我们可微分优化器的主类。它与功能优化器 torchopt.sgdtorchopt.adam 相结合,定义了我们的高级 API torchopt.MetaSGDtorchopt.MetaAdam

首先,定义网络。

from IPython.display import display

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def forward(self, x):
        return self.a * (x**2)

然后我们声明网络(由 a 参数化)和元参数 x。不要忘记为 x 设置标志 require_grad=True

net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)

接下来我们声明元优化器。这里我们展示了定义元优化器的两种等效方法。

optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))
optim = torchopt.MetaSGD(net, lr=1.0)

元优化器将网络作为输入并使用方法步骤来更新网络(由a参数化)。最后,我们展示双层流程的工作原理。

inner_loss = net(x)
optim.step(inner_loss)

outer_loss = net(x)
outer_loss.backward()
# x.grad = - 4 * lr * x^3 + 2 * a_0 * x
#        = - 4 * 1 * 2^3 + 2 * 1 * 2
#        = -32 + 4
#        = -28
print(f'x.grad = {x.grad!r}')

输出:

x.grad = tensor(-28.)

让我们从与模型无关的元学习算法的核心思想开始。该算法是一种与模型无关的元学习算法,它与任何使用梯度下降训练的模型兼容,并且适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练模型,以便它仅使用少量训练样本即可解决新的学习任务。

更新规则定义为:

给定微调步骤的学习率 α \alpha α θ \theta θ 应该最小化
L ( θ ) = E T i ∼ p ( T ) [ L T i ( θ i ′ ) ] = E T i ∼ p ( T ) [ L T i ( θ − α ∇ θ L T i ( θ ) ) ] L (\theta)= E _{ T _i \sim p( T )}\left[ L _{ T _i}\left(\theta_i^{\prime}\right)\right]= E _{ T _i \sim p( T )}\left[ L _{ T _i}\left(\theta-\alpha \nabla_\theta L _{ T _i}(\theta)\right)\right] L(θ)=ETip(T)[LTi(θi)]=ETip(T)[LTi(θαθLTi(θ))]
我们首先定义一些与任务、轨迹、状态、动作和迭代相关的参数。

import argparse
from typing import NamedTuple

import gym
import numpy as np
import torch
import torch.optim as optim

import torchopt
from helpers.policy import CategoricalMLPPolicy


TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10

STATE_DIM = 10
ACTION_DIM = 5

GAMMA = 0.99
LAMBDA = 0.95

outer_iters = 500
inner_iters = 1

接下来,我们定义一个名为 Traj 的类来表示轨迹,其中包括观察到的状态、采取的操作、采取操作后观察到的状态、获得的奖励以及用于贴现未来奖励的伽玛值。

class Traj(NamedTuple):
    obs: np.ndarray
    acs: np.ndarray
    next_obs: np.ndarray
    rews: np.ndarray
    gammas: np.ndarray

评估函数用于评估策略在不同任务上的性能。它使用内部优化器来微调每个任务的策略,然后计算微调前后的奖励。

def evaluate(env, seed, task_num, policy):
    pre_reward_ls = []
    post_reward_ls = []
    inner_opt = torchopt.MetaSGD(policy, lr=0.1)
    env = gym.make(
        'TabularMDP-v0',
        num_states=STATE_DIM,
        num_actions=ACTION_DIM,
        max_episode_steps=TRAJ_LEN,
        seed=args.seed,
    )
    tasks = env.sample_tasks(num_tasks=task_num)
    policy_state_dict = torchopt.extract_state_dict(policy)
    optim_state_dict = torchopt.extract_state_dict(inner_opt)
    for idx in range(task_num):
        for _ in range(inner_iters):
            pre_trajs = sample_traj(env, tasks[idx], policy)
            inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
            inner_opt.step(inner_loss)
        post_trajs = sample_traj(env, tasks[idx], policy)

        pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
        post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())

        torchopt.recover_state_dict(policy, policy_state_dict)
        torchopt.recover_state_dict(inner_opt, optim_state_dict)
    return pre_reward_ls, post_reward_ls

在主函数中,我们初始化环境、策略和优化器。策略是一个简单的 MLP,它输出动作的分类分布。内部优化器用于在微调阶段更新策略参数,外部优化器用于在元训练阶段更新策略参数。性能通过微调前后的奖励来评估。每次外部迭代都会记录并打印训练过程。

def main(args):

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    env = gym.make(
        'TabularMDP-v0',
        num_states=STATE_DIM,
        num_actions=ACTION_DIM,
        max_episode_steps=TRAJ_LEN,
        seed=args.seed,
    )

    policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
    inner_opt = torchopt.MetaSGD(policy, lr=0.1)
    outer_opt = optim.Adam(policy.parameters(), lr=1e-3)
    train_pre_reward = []
    train_post_reward = []
    test_pre_reward = []
    test_post_reward = []

    for i in range(outer_iters):
        tasks = env.sample_tasks(num_tasks=TASK_NUM)
        train_pre_reward_ls = []
        train_post_reward_ls = []

        outer_opt.zero_grad()

        policy_state_dict = torchopt.extract_state_dict(policy)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss)
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)
            # Logging
            train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
            train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
        outer_opt.step()

        test_pre_reward_ls, test_post_reward_ls = evaluate(env, args.seed, TASK_NUM, policy)

        train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
        train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
        test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
        test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)

        print('Train_iters', i)
        print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM)
        print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM)
        print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM)
        print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM)

👉参阅、更新:计算思维 | 亚图跨际

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2058002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu命令大全

查看系统版本 lsb_release -a

C++模板方法TemplateMethod

23种设计模式分为九类 1.组件协作 2.单一职责 3.对象创建 4.对象性能 5.接口隔离 6.状态变化 7.数据结构 8.行为变化 9.领域问题 什么时候、什么地点用设计模式 才是最重要的。 关键的重构技法: 静态-----动态 早绑定—晚绑定 继承-----组合 编译时依赖------运行…

计算机毕业设计--基于深度学习(PSPNet、空洞卷积Atrous Convolutions)的多类型图像通用分割模型

基于深度学习(PSPNet、空洞卷积Atrous Convolutions)的多类型图像通用分割模型 更多基于深度学习的毕业设计请关注专栏 --- 计算机毕业设计 ✨ 动物图分割(使用训练集DIS5K-TR,DIS-TEs,DUTS-TR_TE ) ✨自然与人类图像分割&#xf…

支持最新 mysql9的workbench8.0.39 中文汉化教程来了

之前在 B 站上发布了 mysql8 workbench 汉化教程,一年多来帮助很多初学者解决了不熟悉英文的烦恼。 汉化视频可以访问: 2024最新版mysql8.0.39中文版mysql workbench汉化 中文升级 旧版汉化报错解决_哔哩哔哩_bilibili MySql Workbench汉化_哔哩哔哩_bi…

C++ 左值引用与右值引用超详解

目录 一 左值与右值 1.左值 2.右值 3.总结 二 左值引用与右值引用 1.左值引用 2.右值引用 3.总结与探究 3.1右值引用可以修改么?取地址么? 3.2左值引用与右值引用转化 左值引用 引用 右值 右值引用 引用 左值 3.3左值引用与右值引用相同之处 3.4左…

MySQL基础:函数

💎所属专栏:MySQL 函数是指一段可以直接被另一段程序调用的程序或代码,在MySQL中也内置了许多函数供开发者去调用,例如之前提到的聚合函数,本节再去介绍一些其他常用的函数 字符串函数 函数功能CONCAT(S1,S2...Sn)字…

开源的量化交易领域平台vn.py(VeighNa)

一:vn.py(VeighNa)下的工具以及社区版和Elite版的区别 vn.py是一款广泛应用于量化交易领域的开源软件,它主要有以下用途和功能: 1. 交易系统开发框架:vn.py提供了一个完整的交易系统开发框架,可…

桶排序算法及优化(java)

目录 1.1 引言 1.2 桶排序的历史 1.3 桶排序的基本原理 1.3.1 工作流程 1.3.2 关键步骤 1.4 桶排序的Java实现 1.4.1 简单实现 1.4.2 优化实现 1.4.3 代码解释 1.5 桶排序的时间复杂度 1.5.1 分析 1.5.2 证明 1.6 桶排序的稳定性 1.7 著名案例 1.7.1 应用场景 …

基于GPT-SoVITS的API实现批量克隆声音

目标是将每一段声音通过GPT-SoVITS的API的API进行克隆,因为拼在一起的整个片段处理会造成内存或者缓存溢出。 将目录下的音频文件生成到指定目录下,然后再进行拼接。 通过AI工具箱生成的数据文件是这样的结构,temp目录下是没个片段生成的部分,connect_是正常拼接的音频文件…

笨鸟先飞(疯狂的小鸟)小游戏自制分享

《Flappy Bird》是一款由越南独立游戏开发者阮哈东(Dong Nguyen)制作并发布的移动端小游戏。该游戏最初于2013年上线,在2014年初迅速走红,成为全球范围内的热门现象。 游戏的玩法非常简单,玩家只需通过点击屏幕来控制…

Python | Leetcode Python题解之第355题设计推特

题目: 题解: class Twitter:class Node:def __init__(self):self.followee set()self.tweet list()def __init__(self):self.time 0self.recentMax 10self.tweetTime dict()self.user dict()def postTweet(self, userId: int, tweetId: int) ->…

基于人工智能、三维视觉、混合现实等技术的智慧能源开源了

一、简介 AI视频监控平台, 是一款功能强大且简单易用的实时算法视频监控系统。愿景在最底层打通各大芯片厂商相互间的壁垒,省去繁琐重复的适配流程,实现芯片、算法、应用的全流程组合,减少企业级应用约 95%的开发成本,在强大视频算…

AI学习记录 - LSTM详细拆解

拒绝熬夜,一点点写,拆解LSTM计算过程和最后的总结 遗忘门的计算流程 拼接词向量,前面来的,现在输入的 然后进行计算:

浅谈移动端车牌识别技术的实现过程及应用场景

随着移动互联技术的飞速发展和智能设备的普及,Android、iOS平台上的车牌识别技术逐渐成熟并广泛应用于各个领域。该技术通过智能手机的摄像头捕捉车牌图像,利用先进的图像处理与机器学习算法,实现车牌号码的自动识别。相比传统的人工录入或固…

opencv中Core中的Norm函数解释

1. Norm的类型 NORM_L1: L1 范数(曼哈顿范数)。数组中所有元素绝对值之和。 NORM_L2: L2 范数(欧几里得范数)。数组中所有元素平方和的平方根。 NORM_INF:无穷范数(最大绝对值范数&…

Nginx的7大调度算法详解

Nginx的7大调度算法详解 一、Sticky二、Round-Robin(RR)三、Weight四、Least_conn五、IP_hash六、Fair七、URL_hash总结 💖The Begin💖点点关注,收藏不迷路💖 Nginx作为一款高性能的HTTP和反向代理服务器&a…

Linux虚拟机磁盘管理-添加磁盘

添加磁盘--添加前请选关闭虚拟机 添加步骤: 1.编辑虚拟机设置 2.选择硬盘 3.选择SCSI 4.创建新虚拟磁盘 5.设置磁盘大小 6.点击完成 开机的时候会去读取有几块硬盘,总共我们是有4块硬盘,sda\sdb\sdc\sdd 注意:新加的硬盘实际我们…

VScode相关使用、配置

VScode 拉取新分支 点击左下角分支会出现这个 选择创建新分支依据… 选择一个分支为从这个分支拉新分支 输入新分支的名称即可 VScode 合并分支 切到最终要合并到的分支,通过快捷键 shiftctrlp 出现框中 ,选择 git 合并分支 选择要合并过来的分…

【Docker】Docker Consul

docker consul Docker Consul 是一个用于服务发现和配置的开源工具,它是 HashiCorp 公司推出的一个项目。Consul 提供了一个中心化的服务注册和发现系统,可以帮助开发人员轻松地在 Docker 容器和集群之间进行服务发现和配置管理。 Consul 使用基于 HTT…

位运算使用

在写代码过程中&#xff0c;适当的位运算是一种提高代码质量的有效手段。 0 位运算 常用的运算符共 6 种&#xff0c;分别为按位与&、按位或|、按位异或^、按位取反~、左移位<<、右移位>>。 0.1 按位与&、按位或|、按位异或^ 按位与&、按位或|、按…