深度学习_18_模型的下载与读取

news2024/11/16 9:27:47

在深度学习的过程中,需要将训练好的模型运用到我们要使用的另一个程序中,这就需要模型的下载与转移操作

代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)


# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def l2_penalty(w):
    w = w[0].weight
    return torch.sum(w.pow(2)) / 2


def train(train_features, test_features, train_labels, test_labels, lambd,
          num_epochs=100):
    loss = d2l.squared_loss
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []
    total_loss = 0
    total_samples = 0
    for epoch in range(num_epochs):
        for X, y in train_iter:
            out = net(X)
            y = y.reshape(-1, 1)  # 确保y是二维的
            l = loss(out, y) + lambd * l2_penalty(net)

            # 反向传播和优化器更新
            l.sum().backward()
            d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)
            total_loss += l.sum().item()  # 统计所有元素损失
            total_samples += y.numel()  # 统计个数
        a = total_loss / total_samples  # 本次训练的平均损失
        train_losses.append(a)
        test_loss = evaluate_loss(net, test_iter, loss)
        test_losses.append(test_loss)
        total_loss = 0
        total_samples = 0
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"训练损失: {a:.4f}   测试损失: {test_loss:.4f} ")
    print(net[0].weight)

    torch.save(net.state_dict(), "NetSave")  # 存模型
    net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    print("net_try")
    print(net_try[0].weight)
    net_try.load_state_dict(torch.load("NetSave"))
    net_try.eval()  # 评估模式
    print("net_try_load")
    print(net_try[0].weight)
    # 绘制损失曲线
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 1)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:], 0)

##  net.parameters() 是一个 PyTorch 模型的方法,用于返回模型所有参数的迭代器。这个迭代器产生模型中所有可学习的参数(例如权重和偏置)。

上述代码的模型是简单线性模型

net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型

此模型的下载与储存如下

    torch.save(net.state_dict(), "NetSave")  # 存模型
    net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 搭建模型框架
    print("net_try")
    print(net_try[0].weight)
    net_try.load_state_dict(torch.load("NetSave"))  # 下载模型
    net_try.eval()  # 评估模式
    print("net_try_load")
    print(net_try[0].weight)

效果
在这里插入图片描述

所以说要想在另一个程序中将训练好的模型加载到上面去,首先要保存训练好的模型,另一个程序必须有和本模型一样的框架,再将训练好的模型权重加载到另一个程序框架内即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1490671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

私有化部署自己的ChatGPT,免费开源的chatgpt-next-web搭建

随着AI的应用变广,各类AI程序已逐渐普及,尤其是在一些日常办公、学习等与撰写/翻译文稿密切相关的场景,大家都希望找到一个适合自己的稳定可靠的ChatGPT软件来使用。 ChatGPT-Next-Web就是一个很好的选择。它是一个Github上超人气的免费开源…

新零售SaaS架构:订单履约系统的概念模型设计

订单履约系统的概念模型 订单:客户提交购物请求后,生成的买卖合同,通常包含客户信息、下单日期、所购买的商品或服务明细、价格、数量、收货地址以及支付方式等详细信息。 子订单:为了更高效地进行履约,大订单可能会被…

安卓开发:计时器

一、新建模块 二、填写应用名称和模块名称 三、选择模块,Next 四、可以保持不变,Finish 五、相关目录文件 六、相关知识 七、?

正大国际:期货结算价是如何理解呢?结算价有什么作用?

如何理解期货结算价: 什么是商品期货当日结算价, 商品期货当日结算价是指某一期货合约当日交易期间成交价格按成交量的加权平均价。当日 无成交的,当日结算价按照交易所相关规定确定。 股指期货当日结算价是指某一期货合约当日交易期间最后一…

采购软件是如何改善采购周期?

采购是一个复杂的职能重叠网络,由市场分析、供应商选择、发布 RPF/RFQ、合同谈判等多个工作流程组成。此外,时间紧迫、满足客户期望等压力也使这项工作极具挑战性。因此,如果企业在采购过程中采取短视的方法,没有遵循适当的结构&a…

Pygame教程02:图片的加载+缩放+旋转+显示操作

------------★Pygame系列教程★------------ Pygame教程01:初识pygame游戏模块 Pygame教程02:图片的加载缩放旋转显示操作 Pygame教程03:文本显示字体加载transform方法 Pygame教程04:draw方法绘制矩形、多边形、圆、椭圆、弧…

海王星(Neptune)系列和大禹(DAYU)系列OpenHarmony智能硬件配置解决方案

海王星(Neptune)系列和大禹(DAYU)系列OpenHarmony智能硬件对OS的适配、部件拼装配置、启动配置和文件系统配置等。产品解决方案的源码路径规则为:vendor/{产品解决方案厂商}/{产品名称}_。 解决方案的目录树规则如下&…

React__ 二、React状态管理工具Redux的使用

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言redux状态管理安装redux创建文件 并使用传参action 总结 前言 redux状态管理插件的使用 提示:以下是本篇文章正文内容,下面案例可供参考…

Typescript 哲学 morn on funtion

函数重载 overload 有一些编程语言(eg:java)允许不同的函数参数,对应不同的函数实现。但是,JavaScript 函数只能有一个实现,必须在这个实现当中,处理不同的参数。因此,函数体内部就…

【系统需求分析报告-项目案例直接套用】

软件需求分析报告 软件开发要求项目建设内容物理设计安全系统设计安全网络安全设计应用安全设计用户安全管理性能设计稳定性设计安全性设计兼容性设计易操作性设计可维护行设计 软件开发全套精华资料过去进主页领取。

10亿数据如何快速插入MySQL

最快的速度把10亿条数据导入到数据库,首先需要和面试官明确一下,10亿条数据什么形式存在哪里,每条数据多大,是否有序导入,是否不能重复,数据库是否是MySQL? 有如下约束 10亿条数据,每条数据 1 Kb 数据内容是非结构化的用户访问日志,需要解析后写入到数据库 数据存放在…

2024新版SonarQube+JenKins+Github联动代码扫描(2)-SonarQube代码扫描

文章目录 前言一、docker方式安装sonar二、启动容器三、创建数据库四、启动sonarqube五、访问sonar六、如果访问报错-通过sonar日志定位问题七、修改密码八、汉化(看个人选择)九、扫描十、我遇到的Sonar报错以及解决办法 总结 前言 这是2024新版SonarQu…

【OpenGL编程手册08】 摄像机

一、说明 前面的教程中我们讨论了观察矩阵以及如何使用观察矩阵移动场景(我们向后移动了一点)。OpenGL本身没有摄像机(Camera)的概念,但我们可以通过把场景中的所有物体往相反方向移动的方式来模拟出摄像机,产生一种我们在移动的感…

关于python函数参数传递

参数传递 在 python 中,类型属于对象,对象有不同类型的区分,变量是没有类型的: 在下面的代码示例重,[1,2,3] 是 List 类型,“qayrup” 是 String 类型,而变量 a 是没有类型,它仅仅…

PyTorch深度学习实战(38)——StyleGAN详解与实现

PyTorch深度学习实战(38)——StyleGAN详解与实现 0. 前言1. StyleGAN1.1 模型介绍1.2 模型策略分析 2. 实现 StyleGAN2.1 生成图像2.2 风格迁移 小结系列链接 0. 前言 StyleGAN (Style-Generative Adversarial Networks) 是生成对抗网络 (Generative Ad…

使用AI创建令人惊叹的3D模型

老子云平台《《《《《 使内容创作者能够在一分钟内毫不费力地将文本和图像转换为引人入胜的 3D 资产。 文本转 3D 我们的文本转 3D 工具使创作者(包括那些没有 3D 经验的创作者)能够使用文本输入在短短一分钟内生成 3D 模型。 一句话生成3D模型 老子…

Day31|贪心算法1

贪心的本质是选择每一阶段的局部最优,从而达到全局最优。 无固定套路,举不出反例,就可以试试贪心。 一般解题步骤: 1.将问题分解成若干子问题 2.找出适合的贪心策略 3.求解每一个子问题的最优解 4.将局部最优解堆叠成全局最…

Unity2023.1.19_ECS_DOTS

Unity2023.1.19_ECS_DOTS 盲学-盲目的学习: 懒着自己整理就看看别人整理的吧,整合一下逻辑通了不少: DOTS/data oriented technology stack-面向数据的技术栈 ECS/Entities-Component-System Unity-Entities包 Entities提供ECS架构面向数…

C语言操作符详解(一)

一、操作符的分类 • 算术操作符&#xff1a; 、- 、* 、/ 、% • 移位操作符:<< >> • 位操作符: & | ^ • 赋值操作符: 、 、 - 、 * 、 / 、% 、<< 、>> 、& 、| 、^ • 单⽬操作符&#xff1a; &#xff01;、、--、&、*、、…

蓝桥杯练习系统(算法训练)ALGO-987 强力党逗志芃

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 问题描述 逗志芃励志要成为强力党&#xff0c;所以他将身上所以的技能点都洗掉了重新学技能。现在我们可以了解到&#xff0c;每个技…