PyTorch常用工具(2)预训练模型

news2025/1/21 9:35:18

文章目录

  • 前言
  • 2 预训练模型

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

2 预训练模型

除了加载数据,并对数据进行预处理之外,torchvision还提供了深度学习中各种经典的网络结构以及预训练模型。这些模型封装在torchvision.models中,包括经典的分类模型:VGG、ResNet、DenseNet及MobileNet等,语义分割模型:FCN及DeepLabV3等,目标检测模型:Faster RCNN以及实例分割模型:Mask RCNN等。读者可以通过下述代码使用这些已经封装好的网络结构与模型,也可以在此基础上根据需求对网络结构进行修改:

from torchvision import models
# 仅使用网络结构,参数权重随机初始化
mobilenet_v2 = models.mobilenet_v2()
# 加载预训练权重
deeplab = models.segmentation.deeplabv3_resnet50(pretrained=True)

下面使用torchvision中预训练好的实例分割模型Mask RCNN进行一次简单的实例分割:

In: from torchvision import models
    from torchvision import transforms as T
    from torch import nn
    from PIL import Image
    import numpy as np
    import random
    import cv2

    # 加载预训练好的模型,不存在的话会自动下载
    # 预训练好的模型保存在 ~/.torch/models/下面
    detection = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    detection.eval()
    def predict(img_path, threshold):
        # 数据预处理,标准化至[-1, 1],规定均值和标准差
        img = Image.open(img_path)
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])
        img = transform(img)
        # 对图像进行预测
        pred = detection([img])
        # 对预测结果进行后处理:得到mask与bbox
        score = list(pred[0]['scores'].detach().numpy())
        t = [score.index(x) for x in score if x > threshold][-1]
        mask = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
        pred_boxes = [[(i[0], i[1]), (i[2], i[3])] \
                      for i in list(pred[0]['boxes'].detach().numpy())]
        pred_masks = mask[:t+1]
        boxes = pred_boxes[:t+1]
        return pred_masks, boxes

Transforms中涵盖了大部分对Tensor和PIL Image的常用处理,这些已在上文提到,本节不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如transf = transforms.Normalize(mean=x, std=y);第二步:执行转换操作,例如output = transf(input)。另外还可以将多个处理操作用Compose拼接起来,构成一个处理转换流程。

In: # 随机颜色,以便可视化
    def color(image):
        colours = [[0, 255, 255], [0, 0, 255], [255, 0, 0]]
        R = np.zeros_like(image).astype(np.uint8)
        G = np.zeros_like(image).astype(np.uint8)
        B = np.zeros_like(image).astype(np.uint8)
        R[image==1], G[image==1], B[image==1] = colours[random.randrange(0,3)]
        color_mask = np.stack([R,G,B],axis=2)
        return color_mask
In: # 对mask与bounding box进行可视化
    def result(img_path, threshold=0.9, rect_th=1, text_size=1, text_th=2):
        masks, boxes = predict(img_path, threshold)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for i in range(len(masks)):
            color_mask = color(masks[i])
            img = cv2.addWeighted(img, 1, color_mask, 0.5, 0)
            cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(255,0,0), thickness=rect_th)
        return img
In: from matplotlib import pyplot as plt
    img=result('data/demo.jpg')
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    img_result = plt.imshow(img)

TensorBoard界面

上述代码完成了一个简单的实例分割任务。如上图所示,Mask RCNN能够分割出该图像中的部分实例,读者可考虑对预训练模型进行微调,以适应不同场景下的不同任务。注意:上述代码均在CPU上进行,速度较慢,读者可以考虑将数据与模型转移至GPU上,具体操作可以参考第4节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1347772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Centos8之更换DNF源

一、DNF包管理器简介 DNF(Dandified Yum)是一个用于Fedora、CentOS和RHEL等Linux发行版的包管理器。它是Yum(Yellowdog Updater, Modified)的下一代版本,旨在提供更快、更可靠的软件包管理体验。以下是一些DNF包管理器…

【YOLO系列】yolo V1 ,V3,V5,V8 解释

文章目录 yolo V1 模型结构图通道数 的 物理意义是什么?输出 7730 怎么理解?YOLO v1 损失函数LOSS yolo V3yolo V5yolo V8 视频来源:https://www.bilibili.com/video/BV13K411t7Zs/ AI视频小助理 一、YOLO系列的目标检测算法,其中…

文章解读与仿真程序复现思路——中国电机工程学报EI\CSCD\北大核心《兼顾捕碳强度与可再生能源消纳的储能容量配置优化方法》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主的专栏栏目《论文与完整程序》 这个标题涉及到两个主要方面:捕碳强度和可再生能源的消纳,以及与之相关的储能容量配置的优化方法。下面我会逐一解读这两个方面&…

ES6之生成器(Generator)

✨ 专栏介绍 在现代Web开发中,JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性,还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言,JavaScript具有广泛的应用场景&#x…

OpenCV-11颜色通道的分离与合并

本次我们使用两个比较重要的API split(mat)将图像的通道进行分割。 merge((ch1,ch2,ch3))将多个通道进行融合。 示例代码如下: import cv2 import numpy as npimg np.zeros((480, 640, 3),…

HTML使用JavaScript的三种方式

要使用 JavaScript&#xff0c;你可以在 HTML 文件中的 <script> 标签中编写代码&#xff0c;或者将代码保存到一个单独的 .js 文件中并在 HTML 文件中引入。以下是一些常用的 JavaScript 使用方式&#xff1a; 内联 JavaScript&#xff1a;在 HTML 文件的 <script&g…

CodeWave赋能创新的全功能技术平台

目录 前言1 应用中心2 资产中心&#xff1a;汇聚创新能量&#xff0c;提供开发加速3 集成中心3.1 API管理3.2 报表管理 4 运维中心4.1 资源监控4.2 用户管理4.3 权限管理4.4 日志与监控 5 配置中心5.1 源码配置5.2 镜像仓库配置5.3 数据库配置5.4 报表配置5.5 资产配置5.6 品牌…

【小沐学NLP】Python实现K-Means聚类算法(nltk、sklearn)

文章目录 1、简介1.1 机器学习1.2 K 均值聚类1.2.1 聚类定义1.2.2 K-Means定义1.2.3 K-Means优缺点1.2.4 K-Means算法步骤 2、测试2.1 K-Means&#xff08;Python&#xff09;2.2 K-Means&#xff08;Sklearn&#xff09;2.2.1 例子1&#xff1a;数组分类2.2.2 例子2&#xff1…

有道翻译web端 爬虫, js

以下内容写于2023-12-28, 原链接为:https://fanyi.youdao.com/index.html#/ 1 在输入框内输入hello world进行翻译,通过检查发出的网络请求可以看到翻译文字的http接口应该是: 2 复制下链接最后的路径,去js文件中搜索下: 可以看到这里是定义了一个函数B来做文字的翻译接口函数…

IDEA JAVA Spring Boot运行Hello World(1.8)

参考资料&#xff1a; Spring Boot运行Hello World - 知乎https://blog.csdn.net/weixin_44005516/article/details/108293228(解决bug)SpringBoot入门第一章&#xff1a;Hello World-java教程-PHP中文网 (仅参考如何运行程序)java 8安装教程 java 8安装教程_java8安装-CSDN博…

开发Chrome插件获取当前页面Cookie

前言 看《重来》的时候有提到&#xff0c;把自己的需求做成产品&#xff0c;给更多人提供价值。 就是本篇的文章的由来。 我的需求场景&#xff0c;因为要用postman测公司开发的接口&#xff0c;公司接口通过cookie做鉴权&#xff0c; 所以我每次都要f12&#xff0c;然后从Ne…

Python筛选出批量下载的多时相遥感影像文件中缺失的日期

本文介绍批量下载大量多时相的遥感影像文件后&#xff0c;基于Python语言与每一景遥感影像文件的文件名&#xff0c;对这些已下载的影像文件加以缺失情况的核对&#xff0c;并自动统计、列出未下载影像所对应的时相的方法。 批量下载大量遥感影像文件对于RS学生与从业人员可谓十…

阿里后端实习一面面经

阿里后端实习一面面经 项目中使用到了es&#xff0c;es的作用&#xff1f; elasticsearch是一款非常强大的开源搜索引擎&#xff0c;具备非常多强大功能&#xff0c;可以帮助我们从海量数据中快速找到需要的内容 es中的重要概念&#xff1f; 群集&#xff1a;一个或多个节点…

JavaScript元素根据父级元素宽高缩放

/*** 等比缩放* param wrap 外部容器* param container 待缩放的容器* returns {{width: number, height: number}}* 返回值&#xff1a;width:宽度, height:高度*/aspectRatio(wrap: any, container: any) {// w h / ratio, h w * ratioconst wrapW wrap.width;const wrapH…

PyTorch中常用的工具(3)TensorBoard

文章目录 前言3 可视化工具3.1 TensorBoard 前言 在训练神经网络的过程中需要用到很多的工具&#xff0c;最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块&#xff0c;合理使用这些工具可以极大地提高编程效率。 由于内容较多&#xff0c…

Unity坦克大战开发全流程——游戏场景——游戏界面——设置界面复用

游戏场景——游戏界面——设置界面复用 先将开始场景当中的设置面板复制过来 由于设置面板挂载的脚本都是相同的&#xff0c;在BeginScene中关闭设置面板时不会报空&#xff0c;而在GameScene中关闭设置面板时却会报空&#xff0c;这是因为监听事件中的单例模式调用的实例是Beg…

【时钟】分布式时钟HLC|Logical Time|Vector Clock|True Time

目录 简略 详细 附录 1 分布式系统不能使用NTP的原因 简略 分布式系统中不同于单机系统不能使用NTP(网络时间协议&#xff08;Network Time Protocol&#xff09;)来获取时间&#xff0c;所以我们需要一个特别的方式来获取分布式系统中的时间&#xff0c;mvcc也是使用time保证读…

Debezium发布历史40

原文地址&#xff1a; https://debezium.io/blog/2018/09/20/materializing-aggregate-views-with-hibernate-and-debezium/ 欢迎关注留言&#xff0c;我是收集整理小能手&#xff0c;工具翻译&#xff0c;仅供参考&#xff0c;笔芯笔芯. 使用 Hibernate 和 Debezium 实现聚合…

Linux安装Oracle调用dbca无响应和密码问题

Linux服务器下调用dbca无响应&#xff0c;或弹出如下提示&#xff1a; 则需要在Linux命令行窗口&#xff0c;输入如下命令即可 export DISPLAYip:0.0 注意&#xff1a;该ip应该为可显示图形桌面的机器ip地址。 该桌面需要已经安装了Xmanager-Passive&#xff08;比如 Xmanag…

Langchain-Chatchat开源库使用的随笔记(一)

笔者最近在研究Langchain-Chatchat&#xff0c;所以本篇作为随笔记进行记录。 最近核心探索的是知识库的使用&#xff0c;其中关于文档如何进行分块的详细&#xff0c;可以参考笔者的另几篇文章&#xff1a; 大模型RAG 场景、数据、应用难点与解决&#xff08;四&#xff09;R…