【pytorch深度学习】使用张量表征真实数据

news2025/1/15 16:38:54

使用张量表征真实数据

本文为书pytorch深度学习实战的一些学习笔记和扩展知识,涉及到的csv文件等在这里不会给出,但是我会尽量脱离这一些文件将书本想要表达的内容给展示出来。

文章目录

  • 使用张量表征真实数据
    • 1. 加载图像文件
    • 2. 改变布局
    • 3. 加载目录下图像
    • 4. 正规化数据
    • 5. 三维图像:体数据
    • 6. 表示表格数据
    • 7. 独热编码
    • 8. 分类与阈值
    • 9. 处理时间序列

学习目标:我们如何获取一段数据,并以一种适合训练深度学习模型的方式用张量表示?这就是这篇博文的主要内容,我们将介绍不同类型的数据,并展示如何将这些数据表示为张量。

1. 加载图像文件

使用imageio模块加载PNG格式图像:

import imageio
img_arr = imageio.v2.imread('./1.png')
print(img_arr)
print(img_arr.shape)
[[[45 46 48]
  [44 45 47]
  [44 45 49]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]

 [[44 45 47]
  [44 45 47]
  [43 44 48]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]

 [[44 45 47]
  [44 45 47]
  [44 45 49]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]

 ...

 [[44 45 49]
  [44 45 49]
  [44 45 49]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]

 [[44 45 49]
  [44 45 49]
  [44 45 49]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]

 [[44 45 49]
  [44 45 49]
  [44 45 49]
  ...
  [30 31 35]
  [30 31 35]
  [30 31 35]]]
(256, 252, 3)

这段输出显示的是一个RGB图像的像素数据,以及该图像的尺寸和颜色通道信息。

  1. 像素数据: 输出的大部分是一个三维数组,代表了图像中每个像素的颜色信息。

    • 数组的每个元素是一个长度为3的小数组,分别对应RGB颜色模型中的红色(R)、绿色(G)和蓝色(B)通道。
    • 每个通道的值范围通常是0-255,表示该颜色通道的强度。例如,[45, 46, 48] 代表一个像素点,其中红色通道的强度为45,绿色为46,蓝色为48。
  2. 图像尺寸: 输出的最后部分 (256, 252, 3) 描述了图像的尺寸和颜色通道。

    • 256: 图像的高度(像素行数)。
    • 252: 图像的宽度(像素列数)。
    • 3: 颜色通道数,这里是RGB的3个通道。

唯一要注意的是维度布局,处理图像数据的pytorch模块要求张量排列为C H W,分别表示通道高度宽度。

2. 改变布局

我们可以用张量的permute()方法将每个新的维度,利用旧维度得到一个合适的布局。

img = torch.from_numpy(img_arr)
out = img.permute(2,0,1)  #permute()函数对通道进行重排CxHxW—>HxWxC

注意点是,这个操作没有复制张量数据,而是让out使用与img相同的底层储存,所以img中的像素改变也会影响out中的数据。

3. 加载目录下图像

从一个输入目录中加载所有的PNG图像,并将它们储存在张量中:

batch_size = 3
batch = torch.zeros(batch_size, 3, 256, 256, dtype=torch.uint8)
import os
data_dir = '......'
filenames = [name for name in os.listdir(data_dir)
            if os.path.splitext(name)[-1]=='.png']
for i ,filename in enumerate(filenames):
    img_arr = imageio.imread(os.path.join(data_dir,filename))
    img_t = torch.from_numpy(img_arr)
    img_t = img_t.permute(2, 0, 1)
    img_t = img_t[:3]
    batch[i] = img_t
print(batch)

4. 正规化数据

神经网络通常使用浮点数张量作为输入,当输入数据的范围在0~1或-1~1时,神经网络表现出最佳的训练性能。

  • 归一化:
batch = batch.float()
batch /= 255.0  # 将数据归一到[0,1]
print(batch)

在数字图像处理中,每个像素的颜色通道通常用一个8位整数来表示,这意味着每个通道的值都是0到255,所以最大值就是255,这里的除以255就是这样来的。

  • 标准化:
batch_size = 3
batch = torch.zeros(batch_size, 3, 256, 256, dtype=torch.uint8)
import os
data_dir = '...'
filenames = [name for name in os.listdir(data_dir)
            if os.path.splitext(name)[-1]=='.png']
for i ,filename in enumerate(filenames):
    img_arr = imageio.cv2.imread(os.path.join(data_dir,filename))
    img_t = torch.from_numpy(img_arr)
    img_t = img_t.permute(2, 0, 1)
    img_t = img_t[:3]
    batch[i] = img_t
batch = batch.float()

n_channels = batch.shape[1]
for c in range(n_channels):
    mean = torch.mean(batch[:,c])
    std = torch.std(batch[:,c])
    batch[:,c] = (batch[:,c] - mean)/std
print(batch)
print(batch.shape)

在pytorch中,处理单张图片时候用到的是CHW格式,而在处理一个批次时,用到的是NCHW格式,N是批次大小。所以这里用到的是n_channels=batch.shape[1]

5. 三维图像:体数据

以上的内容都是对二维图像进行操作,而在某些情况下,例如CT的医学成像应用程序,这通常需要大量的从头到脚的大量图像序列,而这一些序列代表的就是人体的一份份切片,也就是二维图像。CT图像通常是灰度图像,意味着每个像素点不是彩色的,而是只有一个强度值,这个值表示在那个特定位置X射线通过身体组织后剩余的强度。这些强度值通常被编码为12位或更高,这意味着它们有比8位灰度图像(0到255的范围)更高的强度范围。这个高强度范围使得CT图像能够展现出更细致的组织密度差异,对于检测各种体内结构,如骨骼、器官和异常组织等非常有用。由于CT扫描是三维的,它包含了多个连续的切片,这些切片堆叠在一起构成了整个扫描区域的三维体数据。在处理CT数据时,通常会处理这个三维数组,而不仅仅是单个二维图像。三维体数据使得医生和算法能够更好地理解和分析身体内部的结构。

dir_path = '...'
vol_arr = imageio.volread(dir_path,'DICOM')
print(vol_arr.shape)
vol = torch.from_numpy(vol_arr).float()
vol = torch.unsqueeze(vol,0)
print(vol.shape)

6. 表示表格数据

加载csv文件最常用的三种是:python自带的csv模块、numpy、pandas

import torch
import numpy as np
wine_path = '.\\deeply\\data\\p1ch4\\tabular-wine\\winequality-white.csv'
wine_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=";",skiprows=1)
print(wine_numpy)
wine_tensor = torch.from_numpy(wine_numpy)
print(wine_tensor.shape,wine_tensor.dtype)

col_list = next(csv.reader(open(wine_path), delimiter= ';'))
print(col_list)
[[ 7.    0.27  0.36 ...  0.45  8.8   6.  ]
 [ 6.3   0.3   0.34 ...  0.49  9.5   6.  ]
 [ 8.1   0.28  0.4  ...  0.44 10.1   6.  ]
 ...
 [ 6.5   0.24  0.19 ...  0.46  9.4   6.  ]
 [ 5.5   0.29  0.3  ...  0.38 12.8   7.  ]
 [ 6.    0.21  0.38 ...  0.32 11.8   6.  ]]
torch.Size([4898, 12]) torch.float32
['fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar', 'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density', 'pH', 'sulphates', 'alcohol', 'quality']

7. 独热编码

独热编码比较陌生,但它很容易理解:

独热编码(One-Hot Encoding)是一种处理分类数据的方法,常用于机器学习和数据分析领域。在独热编码中,每个类别都被表示为一个二进制向量,这个向量中只有一个元素是1,其余都是0。这种表示法使得不同的类别之间在数值上彼此独立,没有任何数值上的大小关系。

特点:

  1. 二进制向量:每个类别由一个长度等于类别总数的向量表示。
  2. 单一激活位:在每个类别的向量中,只有一个位置被置为1(表示当前类别),其余位置都是0。
  3. 去除数值关系:由于每个类别都由独立的位表示,因此避免了不同类别间可能存在的数值上的比较或排序。

示例:

假设有一个类别变量,包含三个可能的类别:猫、狗和鸟。在独热编码中,这些类别将被编码为:

  • 猫:[1, 0, 0]
  • 狗:[0, 1, 0]
  • 鸟:[0, 0, 1]

应用:

  • 机器学习模型:许多机器学习算法和模型要求输入数据是数值形式。独热编码能够将非数值类别特征转换为数值形式,使其能够被这些模型处理。
  • 去除类别间的偏序关系:在一些情况下,类别数据可能会被错误地解释为有序数据。例如,如果直接使用数字1、2、3来表示猫、狗、鸟,模型可能会错误地假设狗大于猫,鸟大于狗。独热编码消除了这种假设。
###独热编码
###使用scatter_()方法获得独热编码,该方法将沿着参数提供的索引方向将源张量的值填充进输入张量中
target_onehot = torch.zeros(target.shape[0],10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
print(target_onehot,target_onehot.shape)

8. 分类与阈值

在这里插入图片描述

我们用第六点:表示表格数据,来作为例子,为了更好衔接并且脱离csv文件,我先把上面代码放下来:

import torch
import numpy as np

# 葡萄酒评分
wine_path = '.\\deeply\\data\\p1ch4\\tabular-wine\\winequality-white.csv'
wine_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=";",skiprows=1)  ###skiprows=1表示不读第一行,因为其中包含列名
wine_tensor = torch.from_numpy(wine_numpy)

# 表示分数
data = wine_tensor[:,:-1]
target = wine_tensor[:,-1].long()
"""
data = wine_tensor[:,:-1]:这行代码获取了除了最后一列之外的所有列,即葡萄酒的特征。
target = wine_tensor[:,-1].long():这行代码获取了最后一列,并将其转换为长整型(long),这一列是葡萄酒的评分(目标标签)。
"""

# 计算均值和方差
data_mean = torch.mean(data, dim=0)
data_var = torch.var(data, dim=0)

# 标准化数据
data_normalized = (data - data_mean)/torch.sqrt(data_var)
bad_index = target <= 3  # 在这一列中筛选
print(target)
print(bad_index)
print(target.shape, bad_index.shape, bad_index.sum())
# 输出:
tensor([6, 6, 6,  ..., 6, 7, 6]) 
tensor([False, False, False,  ..., False, False, False])
torch.Size([4898]) torch.Size([4898]) tensor(20)

知道了我们不需要的索引,所以现在可以剔除不需要的数据了:

bad_data = data[bad_index]
print(bad_data.shape)
# 输出: torch.Size([20, 11])

下面就是一系列常见的操作了,我会在我觉得需要的地方加以补充和注释:

# 分为好中劣三等
bad_data = data[target <= 3]
mid_data = data[(target > 3) & (target < 7)]
good_data = data[target >= 7]
# 对每一列取均值
bad_mean = torch.mean(bad_data,dim=0)
mid_mean = torch.mean(mid_data,dim=0)
good_mean = torch.mean(good_data,dim=0)
# 使用二氧化硫总量的阈值来区分酒的好劣
total_sulfur_threshold = 141.83  # 设定阈值
total_sulfur_data = data[:,6]  # 从 data 张量中提取第7列(索引为6),这列代表总硫化物含量

# 比较筛选操作,这里不要和上面代码块意义一样,也是筛选,只是没有使用符号而已
predicted_indexes = torch.lt(total_sulfur_data,total_sulfur_threshold)
# 得到真正好酒的索引
actual_indexes = target > 5
print(actual_indexes.shape,actual_indexes.dtype,actual_indexes.sum(),actual_indexes)
# 检验预测与实际结果是否相符
# item()方法 	将数组元素复制到标准Python标量并返回它。
n_matches = torch.sum(actual_indexes & predicted_indexes).item()
n_predicted = torch.sum(predicted_indexes).item()
n_actual = torch.sum(actual_indexes).item()
print(f'预测得到的优质酒数量:',n_matches)
print(f'2700瓶中的预测准确率:',n_matches/n_predicted)
print(f'整个数据集中的预测准确率:',n_matches/n_actual)

9. 处理时间序列

在表示平面表中组织的数据时,我们发现表中的每一行都是独立于其他行的,它们的顺序毫无关系,换句话说没有那一列对那些出现较早或者较晚的信息进行编码。那如果有一组数据是要我们观察它是如何一年一年变化的或者有一组数据是共享单车的使用情况的,调查共享单车的使用情况有时间,季节等,这时就需要将一维多通道数据转化为二维多通道数据。如下图所示:
在这里插入图片描述

下面就来给出代码,以及我学习过程中有疑惑地方的解释:

bikes_numpy = np.loadtxt(
              '.\\deeply\\data\\p1ch4\\bike-sharing-dataset\\hour-fixed.csv',
              dtype = np.float32,
              delimiter = ",",
              skiprows =1,
              converters={1: lambda x: float(x[8:10])})
bikes = torch.from_numpy(bikes_numpy)

有疑惑的地方就是converters,这个的意思就是假设在csv文件里面有一列为日期XXXX-XX-XX(年月日)当我们困惑如何将它们存起来时,converters可以解决,上面代码的意思就是他将日期中的日XX提取出来并分开为X和X,两个单独为一列。当然你也可以根据自己需求选择。

# 使用view()按时间段调整数据
daily_bikes = bikes.view(-1, 24, bikes.shape[1])

view(): 这个方法用于重新塑形张量,而不改变其数据。这里的 -1, 24, bikes.shape[1] 意味着将数据重塑为一个三维张量。-1 表示该维度的大小由其他维度确定,24 表示每天有24小时,bikes.shape[1] 保持原始数据的特征数不变。

原始形状:一个二维数组,形状为 (17520, 17)

变换后的形状:一个三维数组,形状为 (730, 24, 17)

# 调整为 NxCxL 次序
daily_bikes = daily_bikes.transpose(1, 2)

transpose(1, 2): 这个方法用于交换张量中的两个维度。这里,它交换了第二和第三维度(维度索引从0开始计数)调整时候变为(730, 17, 24)

全部代码:

bikes_numpy = np.loadtxt(
              '.\\deeply\\data\\p1ch4\\bike-sharing-dataset\\hour-fixed.csv',
              dtype = np.float32,
              delimiter = ",",
              skiprows =1,
              converters={1: lambda x: float(x[8:10])})
bikes = torch.from_numpy(bikes_numpy)
# 使用view()按时间段调整数据
daily_bikes = bikes.view(-1, 24, bikes.shape[1])
# 调整为NxCxL次序
daily_bikes = daily_bikes.transpose(1,2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1197827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nacos入门到运行-超详细~windwos

&#x1f4da;目录 ⚙️简介:⚡️Nacos下载⌛解压到文件⚙️配置信息☘️修改 application.properties ⛵运行程序☘️安全问题☄️程序出现问题查看方式 ⛳Nacos开启鉴权⚡️跳过Token获取数据⚓接口请求&#xff1a; ✍️结束&#xff1a; ⚙️简介: Nacos:正如官网说的,一个…

【JAVA学习笔记】 68 - 网络——TCP编程、UDP编程

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter21/src 网络 一、网络相关概念 1.网络通讯 1.概念:两台设备之间通过网络实现数据传输 2.网络通信:将数据通过网络从一台设备传输到另一台设备 3. java.net包下提供了一系列的类或接口&a…

Ansible命令使用

ansible ansible的命令 ansible命令模块Pingcommand 模块shell 模块copy 模块file 模块fetch 模块cron 模块yum 模块service 模块user 模块group 模块script 模块setup 模块get_url模块stat模块unarchive模块unarchive模块 ansible的命令 /usr/bin/ansible  Ansibe AD-Hoc 临…

Xilinx DDR3 MIG系列——Xiinx DDR3官方手册ds176_7series_MIS

本节目录 一、官方手册ds176_7series_MIS 1、DDR3功能支持 2、MIG官方手册资源 3、Vivado DDR3 MIG IP资源表的导出与查看本节内容 Xilinx官方提供了手册&#xff0c;以便硬件开发者设计DDR3的硬件电路&#xff0c;和FPGA开发者使用MIG官方ip核完成项目的逻辑开发。 针对Xilin…

类和对象(2):构造函数,析构函数

一、构造函数 1.1 概念 构造函数是一种特殊的成员函数&#xff0c;名字与类名相同&#xff0c;创建类类型对象时编译器自动调用——初始化对象&#xff0c;在对象整个生命周期内只调用一次。 PS: 1. 构造函数无返回值&#xff1b;2. 构造函数支持重载。 class Date { public:…

【沐风老师】3DMAX克隆修改器插件教程

3DMAX克隆修改器插件&#xff0c;它通过增量平移、旋转和缩放输入几何体来创建对象的副本。在某些方面&#xff0c;它类似于 3ds Max 的内置阵列工具&#xff0c;但有一个主要优点 -克隆是完全参数化的&#xff0c;因此您可以随时更改重复项的数量及其分布。其他功能包括随机变…

Yum配置、相关命令和常见问题

搭建光盘源 将系统盘读取出来&#xff0c;找到系统盘下存放软件包的目录 2.配置yun仓库 输入命令进入仓库编辑 #必须以.repo结尾 :wq 回车保存退出 3.命令行输入yum repolist 查看yum仓库 配置硬盘源 1.将硬盘源拷贝到目录&#xff0c;或者挂载到目录 2.指定repo文件baseu…

Vue3-组合式API生命周期函数

一进入页面的请求一律放在setup中执行 如果有些代码需要在mounted生命周期中执行&#xff0c;并且写成函数的调用方式可以调用多次&#xff0c;并不会冲突&#xff0c;而是按照顺序依次执行 <script setup>onMounted(()>{console.log("mounted生命周期函数-逻辑…

SQL必知会(二)-SQL查询篇(7)-使用函数处理数据

第8课、使用函数处理数据 表8-1 DBMS 函数的差异 函数语法提取字符串的组成DB2、Oracle、PostgreSQL 和 SQLite 使用 SUBSTR()&#xff1b;MariaDB、Mysql 和 SQL Server 使用 SUBSTRING()数据类型转换Oracle 使用多个函数&#xff0c;每种类型的转换有一个函数&#xff1b;D…

指针传 1

1. 内存 在计算机中内存划分为⼀个个的内存单元&#xff0c;每个内存单元的⼤⼩取1个字节。每个内存单元放了八个bite位&#xff0c;就像我们在高中时住的八人间&#xff0c;那么每个人就代表了一个bite位。 每个内存单元也都有⼀个编号&#xff08;这个编号就相当 于我们所住…

聊天机器人框架Rasa资源整理

Rasa是一个主流的构建对话机器人的开源框架&#xff0c;它的优点是几乎覆盖了对话系统的所有功能&#xff0c;并且每个模块都有很好的可扩展性。参考文献收集了一些Rasa相关的开源项目和优质文章。 一.Rasa介绍 1.Rasa本地安装 直接Rasa本地安装一个不好的地方就是容易把本地…

Django框架FAQ

文章目录 问题1:Django数据库恢复问题2:null和blank的区别问题3:Django创建超级用户报错问题4:Django同源策略 问题1:Django数据库恢复 问题: 从仓库拉下来的Django项目,没有sqlite数据库和migrations记录,如何通过model恢复数据库 解决方法: # 步骤1:导出数据 # 不指定 ap…

如何配置《动手学强化学习》的环境

如何配置《动手学强化学习》的环境 网站&#xff1a;https://hrl.boyuai.com/chapter/intro github仓库&#xff1a;https://github.com/boyu-ai/Hands-on-RL/tree/main 可以看到该教程要求使用gym0.18.3版本的gym库&#xff0c;本教程可以用于解决绝大多数需要使用Pendulum-…

阿里云从公网IP转为弹性公网IP,同时绑定多个IP教程

先将云服务器ECS 转为弹性IP 购买新的弹性辅助网卡 购买弹性公网iP 购买之后选择绑定资源选择第二步购买的网卡 进入ECS 终端 ,输入 ip address可以查看到eth1 的对应mac 地址 终端输入 vi /etc/sysconfig/network-scripts/ifcfg-eth1保存一下信息 DEVICEeth1 #表示新配置…

【MySQL基本功系列】第二篇 InnoDB事务提交过程深度解析

通过上一篇博文&#xff0c;我们简要了解了MySQL的运行逻辑&#xff0c;从用户请求到最终将数据写入磁盘的整个过程。 当数据写入磁盘时&#xff0c;存储引擎扮演着关键的角色&#xff0c;它负责实际的数据存储和检索。 在MySQL中&#xff0c;有多个存储引擎可供选择&#xf…

免费博客搭建笔记

title: 免费博客搭建笔记 tags: 博客搭建 本次是对自己在网上学习github搭建一个 &#x1f447;个人免费静态网站的总结当然不是很完美&#x1f447; Bow to the new king iYANG (yangsongl1n.github.io) 接着我会从我的写笔记的个人习惯来逐步介绍如何搭建这个网站 1.写笔…

【解决】conda-script.py: error: argument COMMAND: invalid choice: ‘activate‘

运行conda activate base报错&#xff1a; 试了网上找到的解决方法都不行&#xff1a; 最后切换了一下terminal&#xff1a; 从powershell改回cmd&#xff08;不知道为什么一开始手贱换成powershell&#xff09; 就可以了

XML解析文档解析

1.首先是我的项目结构以及我所引入的依赖&#xff1a; 2.引入的依赖&#xff1a;jdk用的是17 <properties><maven.compiler.source>17</maven.compiler.source><maven.compiler.target>17</maven.compiler.target> </properties> <dep…

黑客技术(网络安全)-自学

前言 一、什么是网络安全 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域&#xff0c;都有攻与防…