PINN神经网络源代码解析(pyTorch)

news2024/11/18 7:42:52

参考文献

PINN(Physics-informed Neural Networks)的原理部分可参见https://maziarraissi.github.io/PINNs/

考虑Burgers方程,如下图所示,初始时刻u符合sin分布,随着时间推移在x=0处发生间断.
这是一个经典问题,可使用pytorch通过PINN实现对Burgers方程的求解。
在这里插入图片描述

源代码与注释

源代码共含有三个文件,来源于Github https://github.com/jayroxis/PINNs

在这里插入图片描述
network.py文件用于定义神经网络的结构
train.py文件用于训练神经网络
evaluate.py文件用于测试训练好的模型绘制结果图

建议使用Anaconda构建运行环境,需要安装pytorch和一些辅助包

1、network.py 文件

import torch
import torch.nn as nn
from collections import OrderedDict

# 定义神经网络的架构
class Network(nn.Module):
    # 构造函数
    def __init__(
        self,
        input_size, # 输入层神经元数
        hidden_size, # 隐藏层神经元数
        output_size, # 输出层神经元数
        depth, # 隐藏层数
        act=torch.nn.Tanh, # 输入层和隐藏层的激活函数
    ):
        super(Network, self).__init__()#调用父类的构造函数

        # 输入层
        layers = [('input', torch.nn.Linear(input_size, hidden_size))]
        layers.append(('input_activation', act()))

        # 隐藏层
        for i in range(depth):
            layers.append(
                ('hidden_%d' % i, torch.nn.Linear(hidden_size, hidden_size))
            )
            layers.append(('activation_%d' % i, act()))

        # 输出层
        layers.append(('output', torch.nn.Linear(hidden_size, output_size)))

        #将这些层组装为神经网络
        self.layers = torch.nn.Sequential(OrderedDict(layers))

    # 前向计算方法
    def forward(self, x):
        return self.layers(x)

2、train.py 文件

import math
import torch
import numpy as np
from network import Network

# 定义一个类,用于实现PINN(Physics-informed Neural Networks)
class PINN:
    # 构造函数
    def __init__(self):
        # 选择使用GPU还是CPU
        device = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        
        # 定义神经网络
        self.model = Network(
            input_size=2,  # 输入层神经元数
            hidden_size=16,  # 隐藏层神经元数
            output_size=1,  # 输出层神经元数
            depth=8,  # 隐藏层数
            act=torch.nn.Tanh  # 输入层和隐藏层的激活函数
        ).to(device)  # 将这个神经网络存储在GPU上(若GPU可用)

        self.h = 0.1  # 设置空间步长
        self.k = 0.1  # 设置时间步长
        x = torch.arange(-1, 1 + self.h, self.h)  # 在[-1,1]区间上均匀取值,记为x
        t = torch.arange(0, 1 + self.k, self.k)  # 在[0,1]区间上均匀取值,记为t

        # 将x和t组合,形成时间空间网格,记录在张量X_inside中
        self.X_inside = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T

        # 边界处的时空坐标
        bc1 = torch.stack(torch.meshgrid(x[0], t)).reshape(2, -1).T  # x=-1边界
        bc2 = torch.stack(torch.meshgrid(x[-1], t)).reshape(2, -1).T  # x=+1边界
        ic = torch.stack(torch.meshgrid(x, t[0])).reshape(2, -1).T  # t=0边界
        self.X_boundary = torch.cat([bc1, bc2, ic])  # 将所有边界处的时空坐标点整合为一个张量

        # 边界处的u值
        u_bc1 = torch.zeros(len(bc1))  # x=-1边界处采用第一类边界条件u=0
        u_bc2 = torch.zeros(len(bc2))  # x=+1边界处采用第一类边界条件u=0
        u_ic = -torch.sin(math.pi * ic[:, 0])  # t=0边界处采用第一类边界条件u=-sin(pi*x)
        self.U_boundary = torch.cat([u_bc1, u_bc2, u_ic])  # 将所有边界处的u值整合为一个张量
        self.U_boundary = self.U_boundary.unsqueeze(1)

        # 将数据拷贝到GPU
        self.X_inside = self.X_inside.to(device)
        self.X_boundary = self.X_boundary.to(device)
        self.U_boundary = self.U_boundary.to(device)
        self.X_inside.requires_grad = True  # 设置:需要计算对X的梯度

        # 设置准则函数为MSE,方便后续计算MSE
        self.criterion = torch.nn.MSELoss()

        # 定义迭代序号,记录调用了多少次loss
        self.iter = 1

        # 设置lbfgs优化器
        self.lbfgs = torch.optim.LBFGS(
            self.model.parameters(),
            lr=1.0,
            max_iter=50000,
            max_eval=50000,
            history_size=50,
            tolerance_grad=1e-7,
            tolerance_change=1.0 * np.finfo(float).eps,
            line_search_fn="strong_wolfe",
        )

        # 设置adam优化器
        self.adam = torch.optim.Adam(self.model.parameters())

    # 损失函数
    def loss_func(self):
        # 将导数清零
        self.adam.zero_grad()
        self.lbfgs.zero_grad()

        # 第一部分loss: 边界条件不吻合产生的loss
        U_pred_boundary = self.model(self.X_boundary)  # 使用当前模型计算u在边界处的预测值
        loss_boundary = self.criterion(
            U_pred_boundary, self.U_boundary)  # 计算边界处的MSE

        # 第二部分loss:内点非物理产生的loss
        U_inside = self.model(self.X_inside)  # 使用当前模型计算内点处的预测值

        # 使用自动求导方法得到U对X的导数
        du_dX = torch.autograd.grad(
            inputs=self.X_inside,
            outputs=U_inside,
            grad_outputs=torch.ones_like(U_inside),
            retain_graph=True,
            create_graph=True
        )[0]
        du_dx = du_dX[:, 0]  # 提取对第x的导数
        du_dt = du_dX[:, 1]  # 提取对第t的导数

        # 使用自动求导方法得到U对X的二阶导数
        du_dxx = torch.autograd.grad(
            inputs=self.X_inside,
            outputs=du_dX,
            grad_outputs=torch.ones_like(du_dX),
            retain_graph=True,
            create_graph=True
        )[0][:, 0]
        loss_equation = self.criterion(
            du_dt + U_inside.squeeze() * du_dx, 0.01 / math.pi * du_dxx)  # 计算物理方程的MSE

        # 最终的loss由两项组成
        loss = loss_equation + loss_boundary

        # loss反向传播,用于给优化器提供梯度信息
        loss.backward()

        # 每计算100次loss在控制台上输出消息
        if self.iter % 100 == 0:
            print(self.iter, loss.item())
        self.iter = self.iter + 1
        return loss

    # 训练
    def train(self):
        self.model.train()  # 设置模型为训练模式

        # 首先运行5000步Adam优化器
        print("采用Adam优化器")
        for i in range(5000):
            self.adam.step(self.loss_func)
        # 然后运行lbfgs优化器
        print("采用L-BFGS优化器")
        self.lbfgs.step(self.loss_func)

# 实例化PINN
pinn = PINN()

# 开始训练
pinn.train()

# 将模型保存到文件
torch.save(pinn.model, 'model.pth')

运行该文件后模型结果保存在model.pth文件中

3、evaluate.py 文件

import torch
import seaborn as sns
import matplotlib.pyplot as plt

# 选择GPU或CPU
device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

# 从文件加载已经训练完成的模型
model_loaded = torch.load('model.pth', map_location=device)
model_loaded.eval()  # 设置模型为evaluation状态

# 生成时空网格
h = 0.01
k = 0.01
x = torch.arange(-1, 1, h)
t = torch.arange(0, 1, k)
X = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T
X = X.to(device)

# 计算该时空网格对应的预测值
with torch.no_grad():
    U_pred = model_loaded(X).reshape(len(x), len(t)).cpu().numpy()

# 绘制计算结果
plt.figure(figsize=(5, 3), dpi=300)
xnumpy = x.numpy()
plt.plot(xnumpy, U_pred[:, 0], 'o', markersize=1)
plt.plot(xnumpy, U_pred[:, 20], 'o', markersize=1)
plt.plot(xnumpy, U_pred[:, 40], 'o', markersize=1)
plt.figure(figsize=(5, 3), dpi=300)
sns.heatmap(U_pred, cmap='jet')
plt.show()

运行该文件后,可绘制u场的结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/886810.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

系统架构设计师-信息安全技术(1)

目录 一、信息安全基础 1、信息安全五要素 2、网络安全漏洞 3、网络安全威胁 4、安全措施的目标 二、信息加解密技术 1、对称加密 2、非对称加密 3、加密算法对比 三、密钥管理技术 1、数字证书 2、PKI公钥体系 四、访问控制技术 1、访问控制基本模型 2、访问控制的实现技术…

群晖安装 frpc

群晖安装 frpc 博主博客 https://blog.uso6.comhttps://blog.csdn.net/dxk539687357 写该文章之前, 我尝试过使用 “任务计划” 设置开机启动 frpc, 但是失败了。 最后尝试使用 docker 开机启动 frpc 才成功, 因此本文主要介绍使用 docker …

爬虫逆向实战(十三)--某课网登录

一、数据接口分析 主页地址:某课网 1、抓包 通过抓包可以发现登录接口是user/login 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现有一个password加密参数,还有一个browser_key这个可以写死不需要关心 请求头…

Lodash——使用与实例

1. 简介 Lodash是一个一致性、模块化、高性能的JavaScript实用库。Lodash通过降低array、number、objects、string等等的使用难度从而让JavaScript变得简单。Lodash的模块方法,非常适用于: 遍历array、object 和 string对值进行操作和检测创建符合功能的…

注册中心Eureka和Nacos,以及负载均衡Ribbon

1.初识微服务 1.1.什么是微服务 微服务,就是把服务拆分成为若干个服务,降低服务之间的耦合度,提供服务的独立性和灵活性。做到高内聚,低耦合。 1.2.单体架构和微服务架构的区别: 单体架构:简单方便&#…

【干货】通过Bootstrap框架添加下拉框到导航栏

最终效果展示详细步骤及代码1、获取相关代码2、引入CSS和JavaScript文件3、全部代码 最终效果展示 详细步骤及代码 1、获取相关代码 https://v3.bootcss.com/components/#navbar-default 本文用到代码为 <nav class"navbar navbar-default"><div class&…

第六阶|见道明心的笔墨(上)从书法之美到生活之美——林曦老师的线上直播书法课

如果你有需要&#xff0c;可以找我的&#xff0c;我这边有老师的所有课程 如果你有需要&#xff0c;可以找我的&#xff0c;我这边有老师的所有课程

零售行业供应链管理核心KPI指标(二) – 线上订单履行周期

一般品牌零售商有一个大的渠道就是全国连锁的商超、大卖场&#xff0c;非常重要的渠道&#xff0c;要去铺货。同类型的产品都在竞争这个大渠道&#xff0c;但商超、大卖场在这类产品的容量是有限的&#xff0c;所以各个品牌就要去争夺整个容量&#xff0c;看谁在有限的容量里占…

最小二乘线性拟合FC(SCL计算源代码)

采用PLC等微控制器采集一些线性传感器数据时&#xff0c;如果已知线性关系&#xff0c;我们可以利用直线方程求解。具体的算法公式和讲解大家可以查看下面相关文章&#xff1a; PLC模拟量采集算法数学基础&#xff08;线性传感器&#xff09;_plc3秒采集一次模拟量_RXXW_Dor的…

expert systems with applications latex使用、投稿、合集(超详细)

目录 一、main.tex 1、框架 2、图片 3、表格 4、公式 5、文献引用 6、引用文献高亮 及 其他需要导入的包 7、特殊符号 比如 ✔ ∈ 二、投稿 1、orcid 身份码 2、.bib 文件设置为 manuscript 3、Cover Letter 4、declaration of interest statement 5、模板及其…

分享一个恒流源和恒压源电路,可实现恒压、恒流充放电

通过控制输出DA-IOUT1,DA-VOUT1电流和电压DA的大小&#xff0c;及继电器控制和CH1I,CH1V的采用反馈&#xff0c;该电路可实现&#xff0c;恒流充电&#xff0c;恒压充电&#xff0c;恒流恒压充电&#xff0c;恒流放电&#xff0c;恒阻充电&#xff0c;恒功率充电等充放电模式&a…

01_YS_LED_USART1_KEY_Test_Demo

1.硬件设计参考图 参考&#xff1a;00_YS_硬件电路图_往事不可追_来日不方长的博客-CSDN博客 2.配置LED 2.1代码部分 代码初始化部分如下图MX自动生成&#xff1a; // main.h 中/* Private defines -----------------------------------------------------------*/ #define…

Wordcloud | 风中有朵雨做的‘词云‘哦!~

1写在前面 今天可算把key搞好了&#xff0c;不得不说&#x1f3e5;里手握生杀大权的人&#xff0c;都在自己的能力范围内尽可能的难为你。&#x1f602; 我等小大夫也是很无奈&#xff0c;毕竟奔波霸、霸波奔是要去抓唐僧的。 &#x1f910; 好吧&#xff0c;今天是词云&#x…

Vue3 Router路由单页面跳转简单应用

去官网学习→介绍 | Vue Router cd 到项目 安装 Router &#xff1a; cnpm install --save vue-router 或着 创建项目时勾选Router vue create vue-demo <i> to invert selection, and <enter> to proceed)(*) Babel( ) TypeScript(*) Progressive Web …

《vue3实战》运用radio单选按钮或Checkbox复选框实现单选多选的试卷制作

文章目录 目录 系列文章目录 1.《Vue3实战》使用axios获取文件数据以及走马灯Element plus的运用 2.《Vue3实战》用路由实现跳转登录、退出登录以及路由全局守护 3.《vue3实战》运用Checkbox复选框实现单选多选的试卷展现&#xff08;本文&#xff09; 文章目录 前言 radio是什…

【探索Linux】—— 强大的命令行工具 P.3(Linux开发工具 vim)

阅读导航 前言vim简介概念特点 vim的相关指令vim命令模式(Normal mode)相关指令插入模式(Insert mode)相关指令末行模式(last line mode)相关指令 简单vim配置&#xff08;附配置链接&#xff09;温馨提示 前言 前面我们讲了C语言的基础知识&#xff0c;也了解了一些数据结构&…

observer与qt信号槽的区别

observer类图(应用) 定义/区别/注意事项 点击截图后可放大显示,也可图片另存为&#xff0c;这个技术讨论是来接受批评的。 参考&#xff1a;

专访 BlockPI:共建账户抽象未来的新一代 RPC 基础设施

在传统 RPC 服务板块上&#xff0c;开发者一直饱受故障风险、运行环境混乱等难题的折磨。实现 RPC 服务的去中心化&#xff0c;且保持成本优势和可扩展性&#xff0c;始终是区块链基础设施建设的重要命题之一。从 2018 年观察中心化 RPC 供应商服务现状开始&#xff0c;BlockPI…

设计模式之构建器(Builder)C++实现

构建器提出 在软件功能开发中&#xff0c;有时面临“一个复杂对象”的创建工作&#xff0c;该对象的每个功能接口由于需求的变化&#xff0c;会使每个功能接口发生变化&#xff0c;但是该对象使用每个功能实现一个接口的流程是稳定的。构建器就是解决该类现象的。构建就是定义…

C语言好题解析(一)

目录 选择题1选择题2选择题3选择题4编程题一 选择题1 执行下面程序&#xff0c;正确的输出是&#xff08; &#xff09;int x 5, y 7; void swap() {int z;z x;x y;y z; } int main() {int x 3, y 8;swap();printf("%d,%d\n",x, y);return 0; }A: 5,7 B: …