pytorch中的transform用法

news2024/12/24 11:30:44

在 PyTorch 中,transform 主要用于数据预处理和数据增强,尤其在计算机视觉任务中,通过 torchvision.transforms 模块进行图像的变换。transforms 可以对图像进行一系列操作,如裁剪、旋转、缩放、归一化等,以增强数据集的多样性,并提高模型的泛化能力。

1. torchvision.transforms 模块概述

torchvision.transforms 是 PyTorch 提供的一个图像转换工具,它包含一系列的变换操作。常见的转换操作包括:

  • 图像大小调整(Resize)
  • 裁剪(Crop)
  • 图像翻转(Flip)
  • 颜色调整(Color Jitter)
  • 图像归一化(Normalization)
  • 转换为张量(ToTensor)

2. 常用的 transforms 操作

from torchvision import transforms
1) transforms.ToTensor()

将图像转换为 PyTorch 张量(Tensor),并且自动将图像的像素值缩放到 [0, 1] 的范围内。

transform = transforms.ToTensor()
image_tensor = transform(image)
2) transforms.Resize()

调整图像的大小,可以指定一个单一的大小或宽度/高度。

transform = transforms.Resize((224, 224))  # 调整为 224x224 的尺寸
image_resized = transform(image)
3) transforms.CenterCrop()transforms.RandomCrop()

CenterCrop 会从图像的中心裁剪出指定大小的区域;RandomCrop 会随机裁剪出一个指定大小的区域。

transform = transforms.CenterCrop(224)  # 从中心裁剪出 224x224 的区域
image_cropped = transform(image)

# 或者使用随机裁剪
transform = transforms.RandomCrop(224)
image_random_cropped = transform(image)
4) transforms.RandomHorizontalFlip()transforms.RandomVerticalFlip()

进行水平或垂直的随机翻转。

transform = transforms.RandomHorizontalFlip(p=0.5)  # 50% 的概率进行水平翻转
image_flipped = transform(image)
5) transforms.Normalize()

对图像的每个通道进行归一化。通常用来调整图像的颜色通道,使其符合模型训练时的要求。

transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_normalized = transform(image_tensor)  # 对每个通道进行归一化
6) transforms.ColorJitter()

随机调整图像的亮度、对比度、饱和度和色相。适用于增强数据集的多样性。

transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
image_jittered = transform(image)
7) transforms.RandomRotation()

对图像进行随机旋转。

transform = transforms.RandomRotation(30)  # 随机旋转 -30 到 30 度之间
image_rotated = transform(image)

3. 多种 transforms 组合使用

通常,我们会将多个变换操作组合成一个 Compose,使得一个图像依次经过多个变换步骤。

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image_transformed = transform(image)

上面的代码会将图像:

  1. 调整为 256x256
  2. 随机裁剪为 224x224
  3. 进行水平翻转
  4. 转换为张量
  5. 归一化图像

4. 结合 Dataset 使用 transforms

通常,我们会将 transformstorch.utils.data.Datasettorch.utils.data.DataLoader 结合使用,用于训练过程中的数据预处理。

from torchvision import datasets
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root='path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

在上面的代码中,ImageFolder 是一个 PyTorch 提供的通用图像数据集类,用于加载目录结构为类标签的图像数据。transform 用于对数据集中的每个图像进行预处理。

5. 自定义 transform

如果 torchvision.transforms 中的预定义操作不能满足需求,我们还可以自定义一个转换类。例如,如果你想为每张图片添加噪声:

from PIL import Image
import numpy as np

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std
    
    def __call__(self, image):
        image = np.array(image)
        noise = np.random.normal(self.mean, self.std, image.shape)
        noisy_image = image + noise
        noisy_image = np.clip(noisy_image, 0, 255)  # 保证像素值在 [0, 255] 范围内
        return Image.fromarray(noisy_image.astype(np.uint8))

# 使用自定义转换
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    AddGaussianNoise(mean=0, std=0.1),  # 添加高斯噪声
    transforms.ToTensor(),
])

image = Image.open('path_to_image.jpg')
transformed_image = transform(image)

总结

  • transforms 是 PyTorch 中处理图像数据的一组强大工具,适用于图像预处理和数据增强。
  • 通过 transforms.Compose() 可以组合多个转换操作。
  • ToTensor()Resize()RandomCrop()Normalize() 等是常用的转换。
  • 通过 DataLoader 可以高效地加载批量数据,并在训练过程中对每个样本应用转换。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Taro React-Native IOS 打包发布

http网络请求不到 配置 fix react-native facebook::flipper::SocketCertificateProvider‘ (aka ‘int‘) is not a function or func_rn运行debug提示flipper-CSDN博客 Xcode 15(iOS17)编译适配报错_no template named function in namespace std-CS…

Chrome使用IE内核

Chrome使用IE内核 1.下载扩展程序IE Tab 2.将下载好的IE Tab扩展程序拖拽到扩展程序界面,之后重启chrome浏览器即可

C++基础:Pimpl设计模式的实现

2024/11/14: 在实现C17的Any类时偶然接触到了嵌套类的实现方法以及Pimpl设计模式,遂记录。 PIMPL ( Private Implementation 或 Pointer to Implementation )是通过一个私有的成员指针,将指针所指向的类的内部实现数据进行隐藏。 …

深入理解AIGC背后的核心算法:GAN、Transformer与Diffusion Models

深入理解AIGC背后的核心算法:GAN、Transformer与Diffusion Models 前言 随着人工智能技术的发展,AIGC(AI Generated Content,人工智能生成内容)已经不再是科幻电影中的幻想,而成为了现实生活中的一种新兴力…

LeetCode面试经典150题C++实现,更新中

用C实现下面网址的题目 https://leetcode.cn/problems/merge-sorted-array/?envTypestudy-plan-v2&envIdtop-interview-150 1、数组\字符串 88合并两个有序数组 以下是使用 C 实现合并两个有序数组的代码及测试用例 C代码实现 #include <iostream> #include &l…

python怎么安装numpy

1、在python官网https://pypi.python.org/pypi/numpy中找到安装的python版本对应的numpy版本。 例如&#xff1a; python版本是&#xff1a; 下载的对应numpy版本是&#xff1a; 2、将numpy下载到python的安装目录下的scripts文件夹中&#xff1b; 3、然后在cmd中执行以下命…

js中typeOf无法区分数组对象

[TOC]&#xff08;js中typeOf无法区分数组对象) 前提&#xff1a;很多时候我们在JS中用typeOf来判断值类型&#xff0c;如&#xff1a;typeOf ‘abc’//string ,typeOf 123 //number; 但当判断对象为数组时返回的仍是’object’ 这时候我们可以使用Object.prototype.toString.c…

JavaScript方法修改 input type=file 样式

html中的<input type "file">的样式很难修改&#xff0c;又跟页面风格很不匹配。我就尝试了几种方法&#xff0c;但是不管是用label还是用opacity:0都很麻烦&#xff0c;还老是出问题&#xff0c;所以最后还是用JavaScript来解决。 下面附上代码&#xff1a;…

JS爬虫实战之TikTok_Shop验证码

TikTok_Shop验证码逆向 逆向前准备思路1- 确认接口2- 参数确认3- 获取轨迹参数4- 构建请求5- 结果展示 结语 逆向前准备 首先我们得有TK Shop账号&#xff0c;否则是无法抓取到数据的。拥有账号后&#xff0c;我们直接进入登录。 TikTok Shop 登录页面 思路 逆向步骤一般分为…

MDBook 使用指南

MDBook 是一个灵感来自 Gitbook 的强大工具&#xff0c;专门用于创建电子书和文档。它能够将 Markdown 编写的内容编译成静态网站&#xff0c;非常适合项目文档、教程和书籍的发布。 个人实践过许多文档方案&#xff0c;如 hexo、hugo、WordPress、docsify 和 mdbook 等&#…

力扣 LeetCode 28. 找出字符串中第一个匹配项的下标(Day4:字符串)

解题思路&#xff1a; KMP算法 需要先求得最长相等前后缀&#xff0c;并记录在next数组中&#xff0c;也就是前缀表&#xff0c;前缀表是用来回退的&#xff0c;它记录了模式串与主串(文本串)不匹配的时候&#xff0c;模式串应该从哪里开始重新匹配。 next[ j - 1 ] 记录了 …

海思3403对RTSP进行目标检测

1.概述 主要功能是调过live555 testRTSPClient 简单封装的rtsp客户端库&#xff0c;拉取RTSP流&#xff0c;然后调过3403的VDEC模块进行解码&#xff0c;送个NPU进行目标检测&#xff0c;输出到hdmi&#xff0c;这样保证了开发没有sensor的时候可以识别其它摄像头的视频流&…

Python学习26天

集合 # 定义集合 num {1, 2, 3, 4, 5} print(f"num&#xff1a;{num}\nnum数据类型为&#xff1a;{type(num)}") # 求集合中元素个数 print(f"num中元素个数为&#xff1a;{len(num)}") # 增加集合中的元素 num.add(6) print(num) # {1,2,3,4,5,6} # 删除…

图论-代码随想录刷题记录[JAVA]

文章目录 前言深度优先搜索理论基础所有可达路径岛屿数量岛屿最大面积孤岛的总面积沉默孤岛Floyd 算法dijkstra&#xff08;朴素版&#xff09;最小生成树之primkruskal算法 前言 新手小白记录第一次刷代码随想录 1.自用 抽取精简的解题思路 方便复盘 2.代码尽量多加注释 3.记录…

测试自动化如何和业务流程结合?

测试自动化框架固然重要&#xff0c;但是最终自动化的目的都是为了业务服务的。 那测试自动化如何对业务流程产生积极影响&#xff1f; 业务流程的重要性 测试自动化项目并非孤立存在&#xff0c;其生命周期与被测试的应用程序紧密相关。项目的价值在于被整个开发团队所使用&a…

大模型基础BERT——Transformers的双向编码器表示

大模型基础BERT——Transformers的双向编码器表示 整体概况 BERT&#xff1a;用于语言理解的深度双向Transform的预训练 论文题目&#xff1a;BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding Bidirectional Encoder Representations from…

Leetcode 56-合并区间

以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好覆盖输入中的所有区间 。 //按左边界排序 //startintervals[i][0],endintervals…

【golang-技巧】-线上死锁问题排查-by pprof

1.背景 由于目前项目使用 cgo golang 本地不能debug, 发生死锁问题&#xff0c;程序运行和期待不一致&#xff0c;通过日志排查可以大概率找到 阻塞范围&#xff0c;但是不能找到具体问题在哪里&#xff0c;同时服务器 通过k8s daemonset 部署没有更好的方式暴露端口 获取ppr…

7天用Go从零实现分布式缓存GeeCache(总结)

1. Lru包 1.1 lru算法简要概述 &#xff08;作者&#xff1a;豆豉辣椒炒腊肉/链接&#xff1a;https://juejin.cn/post/6844904049263771662&#xff09; LRU算法全称是最近最少使用算法&#xff08;Least Recently Use&#xff09;&#xff0c;广泛的应用于缓存机制中。当缓…

oracle查询字段类型长度等字段信息

1.查询oracle数据库的字符集 SELECT * FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER NLS_CHARACTERSET; 2.查询字段长度类型 SELECT * FROM user_tab_columns WHERE table_name user AND COLUMN_NAME SNAME 请确保将user替换为您想要查询的表名。sname为字段名 这里的字…