PyTorch基础05_模型的保存和加载

news2025/1/10 17:19:41

目录

一、模型定义组件——重构线性回归

二、模型的加载和保存

2、序列化保存对象和加载

 3、保存模型参数


一、模型定义组件——重构线性回归

回顾之前的手动构建线性回归案例:

1.构建数据集2.加载数据集(数据集转换为迭代器);3.参数初始化4.线性回归(模型函数,前向回归);5.损失函数(均分误差,反向传播对象);6.优化器(梯度更新);7.训练数据集8.预测数据

import math
import torch
import random
from sklearn.datasets import make_regression

def build_data():
    """
    构建数据集
    """
    # 噪声
    noise = random.randint(1,5)
    # 样本数量
    sample = 1000
    # 目标y的真实偏置 bias
    bias = 0.5
    # coef:真实系数 coef=True 表示希望函数返回生成数据的真实系数
    x,y,coef = make_regression(n_samples=sample,n_features=4,bias=bias,noise=noise,coef=True,random_state=666)
    # 数据转换成张量
    x = torch.tensor(x,dtype=torch.float32)
    y = torch.tensor(y,dtype=torch.float32)
    coef = torch.tensor(coef,dtype=torch.float32)

    return x,y,coef,bias


def load_data(x,y):
    """
    加载数据集
    将数据集转换为迭代器,以便在训练过程中进行批量处理。
    """
    # 单批次数量
    batch_size = 16
    # 样本总数量
    n_samples = x.shape[0]
    # 一轮训练的次数
    n_batches = math.ceil(n_samples/batch_size)
    # 构建数据索引
    indices = list(range(n_samples)
    # 打乱索引
    random.shuffle(indices)
    # 从每批次中取出的数据
    for i in range(0,n_batches):
        start = i*batch_size
        end = min((i+1)*batch_size,n_samples)
        # 数据下标切片
        index = indices[start,end]
        # 返回数据
        return x[index],y[index]


def initialize(n_feature):
    """
    参数初始化
    随机初始化权重w, 并将偏置b初始化为1
    """
    torch.manual_seed(66)
    # 权重 正态分布
    w = torch.randn(n_feature,required_grad=True,dtype=torch.float32)
    # 偏置
    b = torch.tensor(0.0,required_grad=True,dtype=torch.float32)
    
    return w,b


def regressor(x,w,b):
    """
    线性回归
    模型函数 "前向传播"
    """
    return x@w + b

def MSE(y_pred,y_true):
    """
    损失函数
    均分误差 反向传播的对象
    """
    return torch.mean((y_pred-y_true)**2)


def optim_step(w,b,dw,db,lr):
    """
    优化器
    梯度更新 向梯度下降的方向更新
    """
    # 修改的不是原tenser而是tensor的data
    w.data -= lr*dw.data
    b.data -= lr*db.data


def train():
    """
    训练数据集
    """
    # 创建数据
    x,y,coef,bias = build_data()
    # 初始化参数
    w,b = initialize(x.shape[0])
    # 设置训练参数
    lr = 0.1 # 学习率
    epoch = 500 # 迭代次数
    # 训练数据
    # 迭代循环
    for i in range(epoch):
        total_loss = 0 # 误差总和
        count = 0 # 训练次数
        # 批次循环
        for batch_x,batch_y_true in load_data(x,y):
            count += 1
            # 代入线性回归得出预测值
            batch_y_pred = regressor(x,w,b)
            # 计算损失函数
            loss = MSE(batch_y_pred,btach_y_true)
            tatol_loss += loss
            # 梯度清零
            if w.grad is not None:
                w.data.zero_()
            if b.grad is not None:
                b.data.zero_()
            # 反向传播 计算梯度
            loss.backward()
            # 梯度更新 得出预测w和b
            w,b = optim_step(w,b,w.grad,b.grad,lr)
        # 打印数据
        print(f'epoch:{i},loss:{total_loss/count}')
    return w.data,b.data,coef,bias


def detect(x,w,b):
    """
    预测数据
    """
    return torch.matmul(x.type(torch.float32),w) + b


if __name__ == "__main__":
    w,b,coef,bias = train()
    print(f'真实系数:{coef},真实偏置:{bias}')
    print(f'预测系数:{w},预测偏置:{b}')
    y_pred = detect(torch.tensor([[4,5,6,6],[7,8,8,9]]),w,b)
    print(f'y_pred:{y_pred}')

 这个手动实现的过程对深度学习的思维很有帮助,现在结合上一篇的官方数据加载器,我们将它重构:

import torch
from sklearn.datasets import make_regression
from torch.utils.data import DataLoader,TensorDataset

def build_dataset():
    """
    构建数据集
    """
    noise = random.randint(1,5)
    bias = 14.5
    X,y,coef = make_regression(n_samples=1000,
                               n_features=4,
                               coef=True,
                               bias=bias,
                               noise=noise,
                               random_state=66)
    X = torch.tensor(X,dtype=torch.float32)
    y = torch.tensor(y,dtype=torch.float32)
    return X,y,coef,bias

def train():
    """
    训练数据集
    """
    # 01 加载数据
    X,y,coef,bias = build_dataset()
    # 02 构建模型
    """
    torch.nn.Linear(in_features,out_features)
    in_features 输入的特征数量——w数量
    out_features 输出的数量——y数量
    """
    model = torch.nn.Linear(X.shape[1],1)
    # 03 初始化参数
    # 若不手动初始化则会自动初始化 这里选择自动初始化
    # 04 构建损失函数
    loss_fn = torch.nn.MSELoss() # 均方误差
    # 05 构建优化器
    sgd = torch.optim.SGD(model.parameter(),lr) # 传入模型参数和学习率
    # 06 训练
    epoch = 500
    # 06.1 循环次数
    for i in range(epoch):
        # 06.2 计算损失
        data_loader = DataLoader(data,batch_size=16,shuffle=True) # 按小批次划分并随机打乱
        total_loss = 0
        count = 0
        for x,y in data_loader:
            count += 1
            y_pred = model(x) # 模型预测的输出值
            loss = loss_fn(y_pred,y)
            total_loss += loss
            # 06.3 梯度清零
            sgd.zero_grad()
            # 06.4 反向传播
            loss.backward()
            # 06.5 更新参数
            sgd.step()
            print(f'epoch:{epoch},loss:{total_loss/count}') # 打印每一批次的结果
    # 07 保存模型参数
    print(f'weight:{model.weight},bias:{model.bias}')
    print(f'true_weight:{coef},true_bias:{bias}')


if __name__ == '__main__':
    train()
            

可见得方便了许多。

 

二、模型的加载和保存

训练一个模型通常需要大量的数据、时间和计算资源。通过保存训练好的模型,可以满足后续的模型部署、模型更新、迁移学习、训练恢复等各种业务需要求。

1、标准网络模型构建

class MyModle(nn.Module):
    """
    标准网络模型构建
    """
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

2、序列化保存对象和加载

import torch
import torch.nn as nn

class MyModle(nn.Module):
    """
    标准网络模型构建
    """
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output


def train01():
    """
    保存
    """
    model = MyModle(10,5)
    # 序列化方式保存模型对象
    torch.save(model, "./data/model.pkl")


def detect01():
    """
    加载
    """
    # 注意设备问题
    model = torch.load("./data/model.pkl", map_location="cpu")
    print(model)



if __name__ == "__main__":
    test01()
    test02()

 3、保存模型参数

更常用的保存和加载方式,只需要保存权重、偏执、准确率等相关参数,都可以在加载后打印观察。

import torch
import torch.nn as nn
import torch.optim as optim


class MyModle(nn.Module):
    """
    标准网络模型构建
    """
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output


def train02():
    model = MyModle(input_size=128, output_size=32)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # 自己构建要存储的模型参数
    save_dict = {
        "init_params": {
            "input_size": 128,  # 输入特征数
            "output_size": 32,  # 输出特征数
        },
        "accuracy": 0.99,  # 模型准确率
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(save_dict, "model_dict.pth")


def detect02():
    save_dict = torch.load("model_dict.pth")
    model = MyModle(
        input_size=save_dict["init_params"]["input_size"],
        output_size=save_dict["init_params"]["output_size"],
    )
    # 初始化模型参数
    model.load_state_dict(save_dict["model_state_dict"])
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # 初始化优化器参数
    optimizer.load_state_dict(save_dict["optimizer_state_dict"])
    # 打印模型信息
    print(save_dict["accuracy"])
    print(model)


if __name__ == "__main__":
    train02()
    detect02()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2248618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《Python基础》之函数的用法

一、简介 在 Python 中,函数是一段可重用的代码块,用于执行特定的任务。函数可以帮助你将代码模块化,提高代码的可读性和可维护性。 函数的用途 代码重用:通过函数,你可以将常用的代码块封装起来,避免重复…

java:aqs实现自定义锁

aqs采用模板方式设计模式,需要重写方法 package com.company.aqs;import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.AbstractQueuedSynchronizer; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock;…

【小白学机器学习34】基础统计2种方法:用numpy的方法np().mean()等进行统计,pd.DataFrame.groupby() 分组统计

目录 1 用 numpy 快速求数组的各种统计量:mean, var, std 1.1 数据准备 1.2 直接用np的公式求解 1.3 注意问题 1.4 用print() 输出内容,显示效果 2 为了验证公式的背后的理解,下面是详细的展开公式的求法 2.1 均值mean的详细 2.2 方差…

vue2 中使用 Ag-grid-enterprise 企业版

文章目录 问题Vue2 引入企业版不生效npm run dev 时卡住了94% after seal 卡在这里了测试打包源 git 解决方案记录 问题 我想用企业版的树状表格 Vue2 引入企业版不生效 编译引入 // vue.config.js module.exports {transpileDependencies: ["ag-grid-enterprise"…

RESTful快速开发

(3)RESTful快速开发 (2)中的控制器仍然存在大量的冗余代码 问题1: 每个方法的RequestMapping注解中都定义了访问路径/users,重复性太高 问题2:每个方法的RequestMapping注解中都要使用method属…

万能门店小程序管理系统 doPageGetFormList SQL注入漏洞复现

0x01 产品简介 万能门店小程序管理系统是一款功能强大的工具,旨在为各行业商家提供线上线下融合的全方位解决方案。是一个集成了会员管理和会员营销两大核心功能的综合性平台。它支持多行业使用,通过后台一键切换版本,满足不同行业商家的个性化需求。该系统采用轻量后台,搭…

【作业九】RNN-SRN-Seq2Seq

点击查看作业内容 目录 1 实现SRN (1)使用numpy实现 (2)在(1)的基础上,增加激活函数tanh (3)使用nn.RNNCell实现 (4)使用nn.RNN实现 2 使用R…

Emgu (OpenCV)

Emgu Github Emgu 环境: Emgu CV 4.9.0 netframework 4.8 1、下载 libemgucv-windesktop-4.9.0.5494.exe 安装后,找到安装路径下的runtime文件夹复制到c#项目Debug目录下 安装目录 c# Debug目录

YOLOv8模型pytorch格式转为onnx格式

一、YOLOv8的Pytorch网络结构 model DetectionModel((model): Sequential((0): Conv((conv): Conv2d(3, 64, kernel_size(3, 3), stride(2, 2), padding(1, 1))(act): SiLU(inplaceTrue))(1): Conv((conv): Conv2d(64, 128, kernel_size(3, 3), stride(2, 2), padding(1, 1))(a…

澳洲房产市场数据清洗、聚类与可视化综合分析

本项目涉及数据清洗及分析时候的思路,如果仅在CSDN中看,可能会显得有些乱,建议去本人和鲸社区对应的项目中去查看,源代码和数据集都是免费下载的。 声明:本项目的成果可无偿分享,用于学习交流。但请勿用于…

IT服务团队建设与管理

在 IT 服务团队中,需要明确各种角色。例如系统管理员负责服务器和网络设备的维护与管理;软件工程师专注于软件的开发、测试和维护;运维工程师则保障系统的稳定运行,包括监控、故障排除等。通过清晰地定义每个角色的职责&#xff0…

go-zero(八) 中间件的使用

go-zero 中间件 一、中间件介绍 中间件(Middleware)是一个在请求和响应处理之间插入的程序或者函数,它可以用来处理、修改或者监控 HTTP 请求和响应的各个方面。 1.中间件的核心概念 请求拦截:中间件能够在请求到达目标处理器之…

Qt Graphics View 绘图架构

Qt Graphics View 绘图架构 "QWGraphicsView.h" 头文件代码如下&#xff1a; #pragma once#include <QGraphicsView>class QWGraphicsView : public QGraphicsView {Q_OBJECTpublic:QWGraphicsView(QWidget *parent);~QWGraphicsView();protected:void mouseM…

【eNSP】动态路由协议RIP和OSPF

动态路由RIP&#xff08;Routing Information Protocol&#xff0c;路由信息协议&#xff09;和OSPF&#xff08;Open Shortest Path First&#xff0c;开放式最短路径优先&#xff09;是两种常见的动态路由协议&#xff0c;它们各自具有不同的特点和使用场景。本篇会对这两种协…

差分 + 模拟,CF 815A - Karen and Game

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 815A - Karen and Game 二、解题报告 1、思路分析 一个经典的差分数组的…

vue3【实战】响应式的登录界面

效果预览 WEB 端效果 移动端效果 技术方案 vue3 vite Element Plus VueRouter UnoCSS TS vueUse AutoImport 技术要点 响应式设计 移动端&#xff1a;图片切换为绝对定位&#xff0c;下移一层&#xff0c;成为背景图片 <el-imageclass"w-screen h-screen lt-md…

加速科技精彩亮相中国国际半导体博览会IC China 2024

11月18日—20日&#xff0c;第二十一届中国国际半导体博览会&#xff08;IC China 2024&#xff09;在北京国家会议中心顺利举办&#xff0c;加速科技携重磅产品及全系测试解决方案精彩亮相&#xff0c;加速科技创始人兼董事长邬刚受邀在先进封装创新发展论坛与半导体产业前沿与…

php反序列化1_常见php序列化的CTF考题

声明&#xff1a; 以下多内容来自暗月师傅我是通过他的教程来学习记录的&#xff0c;如有侵权联系删除。 一道反序列化的CTF题分享_ctf反序列化题目_Mr.95的博客-CSDN博客 一些其他大佬的wp参考&#xff1a;php_反序列化_1 | dayu’s blog (killdayu.com) 序列化一个对象将…

RustDesk 搭建

RustDesk 服务端下载&#xff1a;https://github.com/rustdesk/rustdesk-server/releases RustDesk 客户端下载&#xff1a;https://github.com/rustdesk/rustdesk/releases RustDesk 官方部署教程&#xff1a;https://rustdesk.com/docs/zh-cn/ 1&#xff1a;RustDesk 概览# 1…

Qt读写Usb设备的数据

Qt读写Usb设备的数据 问题:要读取usb设备进行通讯&#xff0c;qt好像没有对应的库支持。解决&#xff1a;libusbwindow下载 :Linux下载: QtUsb 开源的第三方库库里面的函数说明&#xff1a;window版本&#xff1a;Linux中也提供的直接下载测试代码&#xff1a;库下载&#xff1…