Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

news2025/3/5 7:04:50

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

我们知道torch.meshgrid()函数的功能是生成网格,可以用于生成坐标;

在numpy中也有一样的函数np.meshgrid(),但是用法不太一样,我们直接上代码进行解释。

1、两者在用法上的区别

比如:我要生成下图的xy坐标点,看下两者的实现方式:

在这里插入图片描述

np.meshgrid()

>>> import numpy as np
>>> w, h = 4, 2
# 注意,此时输入的是由w和h生成的一维数组
#      此时输出的是网格x的坐标grid_x以及网格y的坐标grid_y
>>> grid_x, grid_y  = np.meshgrid(np.arange(w), np.arange(h)) 

>>> grid_x
array([[0, 1, 2, 3],  
       [0, 1, 2, 3]])
>>> grid_y
array([[0, 0, 0, 0],
       [1, 1, 1, 1]])

torch.meshgrid()

>>> import torch
# 注意,此时输入的是由h和w生成的一维数组(和numpy中的输入顺序相反)
#      此时输出的是网格y的坐标grid_y以及网格x的坐标grid_x(和numpy中的输出顺序相反)
>>> grid_y, grid_x =  torch.meshgrid(
...         torch.arange(h),
...         torch.arange(w)
...     )
>>> grid_x
tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
>>> grid_y
tensor([[0, 0, 0, 0],
        [1, 1, 1, 1]])

2、应用案例

2.1 利用np.meshgrid()来画决策边界

我们可以利用np.meshgrid()来画等高线图

# 等高线图
import numpy as np
import matplotlib.pyplot as plt

# 模拟海拔高度
def fz(x, y):
  z = (1 -x / 2 + x**5 + y**3) * np.exp(-x**2-y**2)
  return z

w = np.linspace(-4, 4, 100)
h = np.linspace(-2, 2, 100)

grid_x, grid_y = np.meshgrid(w, h)
z = fz(grid_x, grid_y)

plt.figure('Contour Chart',facecolor='lightgray')
plt.title('contour',fontsize=16)
plt.grid(linestyle=':')

cntr = plt.contour(
    grid_x, # 网格坐标矩阵的x坐标(2维数组)
    grid_y, # 网格坐标矩阵的y坐标(2维数组)
    z,      # 网格坐标矩阵的z坐标(2维数组)
    8,      # 等高线绘制8部分
    colors = 'black', # 等高线图颜色
    linewidths = 0.5 # 等高线图线宽
)
# 设置标签
plt.clabel(cntr, inline_spacing = 1, fmt='%.2f', fontsize=10)
# 填充颜色  大的是红色  小的是蓝色
plt.contourf(grid_x, grid_y, z, 8, cmap='jet')

plt.legend()
plt.show()

在这里插入图片描述

我们可以利用np.meshgrid()来画决策边界。

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as np

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


# 使用sklearn自带的moon数据
X, y = make_moons(n_samples=100,noise=0.15,random_state=42)

# 绘制生成的数据
def plot_dataset(X,y,axis):
    plt.plot(X[:,0][y == 0],X[:,1][y == 0],'bs')
    plt.plot(X[:,0][y == 1],X[:,1][y == 1],'go')
    plt.axis(axis)
    plt.grid(True,which='both')


# 画出决策边界
def plot_pred(clf,axes):
    w = np.linspace(axes[0],axes[1], 100)
    h = np.linspace(axes[2],axes[3], 100)
    grid_x, grid_y = np.meshgrid(w, h)
    # grid_x 和 grid_y 被拉成一列,然后拼接成10000行2列的矩阵,表示所有点
    grid_xy = np.c_[grid_x.ravel(), grid_y.ravel()]
    # 二维点集才可以用来预测
    y_pred = clf.predict(grid_xy).reshape(grid_x.shape)
    # 等高线
    plt.contourf(grid_x, grid_y,y_pred,alpha=0.2)


ploy_kernel_svm_clf = Pipeline(
    steps=[
        ("scaler",StandardScaler()),
        ("svm_clf",SVC(kernel='poly', degree=3, coef0=1, C=5))
    ]
)


ploy_kernel_svm_clf.fit(X,y)

plot_pred(ploy_kernel_svm_clf,[-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

在这里插入图片描述

2.2 利用torch.meshgrid()生成网格所有坐标的矩阵

在目标检测YOLO中将图像划分为单元网格的部分就用到了torch.meshgrid()函数。

import torch
import numpy as np


def create_grid(input_size, stride=32):
    # 1、获取原始图像的w和h
    w, h = input_size, input_size
    # 2、获取经过32倍下采样后的feature map
    ws, hs = w // stride, h // stride
    # 3、生成网格的y坐标和x坐标
    grid_y , grid_x = torch.meshgrid([
        torch.arange(hs),
        torch.arange(ws)
    ])
    # 4、将grid_x和grid_y进行拼接,拼接后的维度为【H, W, 2】
    grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
    # 【H, W, 2】 -> 【HW, 2】
    grid_xy = grid_xy.view(-1, 2)
    return grid_xy



if __name__ == '__main__':
    print(create_grid(input_size=32*4))
# 生成网格所有坐标的矩阵
tensor([[0., 0.],
        [1., 0.],
        [2., 0.],
        [3., 0.],
        
        [0., 1.],
        [1., 1.],
        [2., 1.],
        [3., 1.],
        
        [0., 2.],
        [1., 2.],
        [2., 2.],
        [3., 2.],
        
        [0., 3.],
        [1., 3.],
        [2., 3.],
        [3., 3.]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1329232.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何进行USB丢弃攻击?

USB丢弃攻击,类似于一场表演艺术,您需要构建一个引人入胜的故事,激发目标的好奇心,让他们忽略基本的安全意识,插入您精心准备的USB设备! 本文章仅限娱乐,请勿模仿或进行违法活动! 一、选择放置…

鸿蒙声势浩大,程序员能从中看出什么机遇?

鸿蒙声势浩大,在最近的大厂商合作消息中我们能看到什么未来机遇?? 12 月 22 日消息,据华为开发者联盟服务消息,来自政务、金融、教育等行业的 8 家企业与华为签约并官宣启动鸿蒙原生应用开发。此前,美团、…

python调用GPT API

每次让gpt给我生成一个调用api的程序时,他经常会调用以前的一些api的方法,导致我的程序运行错误,所以这期记录一下使用新的方法区调用api 参考网址 Migration Guide,这里简要地概括了一下新版本做了哪些更改 OpenAI Python API l…

c++11特新:弱引用智能指针

弱引用智能指针std::weak_ptr可以看做是shared_ptr的助手,它不管理shared_ptr内部的指针。std::weak_ptr没有重载操作符*和->,因为它不共享指针,不能操作资源,所以它的构造不会增加引用计数,析构也不会减少引用计数…

Springboot是什么?Springboot详解!入门介绍

📫作者简介:小明java问道之路,2022年度博客之星全国TOP3,专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化,文章内容兼具广度、深度、大厂技术方案,对待技术喜欢推理加验证,就职于…

SpringBoot 多环境开发配置文件

在开发过程中,往往开发环境和生产环境需要不同的配置。为了兼容两种运行环境,提高开发效率,可以使用多环境开发配置文件。 配置文件结构大概是这样: application.yml -主启动配置文件(用于控制使用哪种环境配…

【Proteus仿真】【Arduino单片机】蓝牙遥控小车

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器,使LCD1602液晶,L298电机,直流电机,HC05/06蓝牙模块等。 主要功能: 系统运行后,LCD1602…

web前端html笔记2

新增状态标签<meter><progress> <meter> 属性 值 描述 high 数值 规定高值 low 数值 规定低值 max 数值 规定最大值 min 数值 规定最小值 optimum 数值 规定最优值 value 数值 规定当前值 <body> <meter high"50" …

hive命令启动出现classnotfound

环境&#xff1a;ambari集群三个节点node104、node105和node106&#xff0c;其中node105上有hiveserver2&#xff0c;并且三个节点均有HIVE CLIENT 注意&#xff1a;“./”指hive安装目录 其中装有hiveserver2的node105节点&#xff0c;由于某种需要向lib目录下上传了某些jar包…

图数据库NebulaGraph学习

1.图空间(Space)操作 1.1创建图空间&#xff0c;指定vid_type为整形 CREATE SPACE play_space (partition_num 10, replica_factor 1, vid_type INT64) COMMENT "运动员库表空间"; 1.2创建图空间&#xff0c;指定vid_type为字符串 CREATE SPACE play_space (…

docker安装ES:7.8和Kibana:7.8

本文适用于centos7,快速入手练习es语法 前置&#xff1a;安装docker教程docker、docker-component安装-CSDN博客 1.安装es 9200为启动端口&#xff0c;9300为集群端口 docker pull elasticsearch:7.8.0mkdir -p /mydata/elasticsearch/pluginsmkdir -p /mydata/elasticsear…

LangChain入门指南:定义、功能和工作原理

LangChain入门指南&#xff1a;定义、功能和工作原理 引言LangChain是什么&#xff1f;LangChain的核心功能LangChain的工作原理LangChain实际应用案例如何开始使用LangChain 引言 在人工智能的浪潮中&#xff0c;语言模型已成为推动技术革新的重要力量。从简单的文本生成到复…

【附带大模型训练数据】大模型系统优化:怎么计算模型所需的算力、内存带宽、内存容量和通信数据量?

大模型系统优化&#xff1a;怎么计算模型所需的算力、内存带宽、内存容量和通信数据量&#xff1f; 大模型需要多少算力&#xff1f;大模型需要几块GPU&#xff1f;预训练、领域微调区别?预训练需要多少数据?怎么做数据去重&#xff1f;怎么做数据增强&#xff1f;训练多少ep…

基于ssm+jsp理发店管理系统源码和论文

随着信息化时代的到来&#xff0c;管理系统都趋向于智能化、系统化&#xff0c;理发店管理系统也不例外&#xff0c;但目前国内的市场仍都使用人工管理&#xff0c;市场规模越来越大&#xff0c;同时信息量也越来越庞大&#xff0c;人工管理显然已无法应对时代的变化&#xff0…

Appium安装及配置

一、前置说明 Appium 是一个用于自动化移动应用程序的开源测试框架&#xff0c;它支持 Android 和 iOS&#xff0c;同时支持使用多种编程语言&#xff08;如 Java、Python、JavaScript 等&#xff09;进行测试脚本的编写。 二、操作步骤 1. 安装Node.js Appium Server 由 n…

创建Github Pages 仓库

Github Pages 仓库创建 1. 在 GitHub 上创建一个新仓库2. 在仓库中创建一个分支&#xff08;可选&#xff0c;可跳过&#xff09;3. 创建您的静态网站4. 启用 GitHub Pages5. 等待构建完成6. 访问您的网站 在 GitHub 上创建一个 GitHub Pages 仓库是相对简单的。GitHub Pages 允…

Element UI导航菜单之秘:无痕迹浏览与历史记录栈的管理

前言 需求 在使用 Element UI 的 el-menu 导航栏菜单时&#xff0c;发现 history 栈&#xff08;历史记录栈&#xff09;会不断缓存之前的记录&#xff0c;而在某些场景下我们可能不希望 history 栈&#xff08;历史记录栈&#xff09;中有之前的记录&#xff0c;即实现无痕迹流…

【扩散模型】7、GLIDE | 文本指引的图像生成和编辑

论文&#xff1a;GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models 代码&#xff1a;https://link.zhihu.com/?targethttps%3A//github.com/openai/glide-text2im 出处&#xff1a;OpenAI 一、背景 在扩散模型经过了一系列…

python 用OpenCV 将图片转视频

import os import cv2 import numpy as npcv2.VideoWriter&#xff08;&#xff09;参数 cv2.VideoWriter() 是 OpenCV 中用于创建视频文件的类。它的参数如下&#xff1a; filename&#xff1a;保存视频的文件名。 fourcc&#xff1a;指定视频编解码器的 FourCC 代码&#xf…

智能优化算法应用:基于卷尾猴算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于卷尾猴算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于卷尾猴算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.卷尾猴算法4.实验参数设定5.算法结果6.参考文…