05.cv PIL torch

news2024/12/22 14:31:31

文章目录

  • opencv
    • rectangle
    • circle
    • 保存图像
  • PIL
    • resize
    • 与Opencv的互操作
  • torch
    • tensor的创建和定义
    • Torch的自动梯度计算
    • Torch的模块
    • torch的训练流程

opencv

  1. plt.imshow 以RGB形式显示
  2. cv2.imread 读取的是BGR
import cv2
image = cv2.imread('image.png') #加载图像
print(image.shape, image.dtype, type(image))
# (528, 604, 3), dtype('uint8'), numpy.ndarray

plt.imshow(image[:, :, [2, 1, 0]]) # 利用了numpy的选择功能实现BGR和RGB对调

hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) #
plt.imshow(hsv)

plt.show()

rectangle

在这里插入图片描述

circle

circle(img, center, radius, color[, thickness[, lineType[, shift ] ] ])

x, y = 230, 270
cv2.circle(image, (x, y), 16, (0, 0, 255), -1, lineType=16)

保存图像

cv2.imwrite('image_change.png', image)

PIL

Python标准图像处理库
from PIL import Image

img = Image.open('image.png')
plt.imshow(img)
print(type(img)) # PIL.PngImagePlugin.PngImageFile

resize

resized = img.resize((600, 300))

与Opencv的互操作

cv_img = cv2.imread('image.png')
pil_img = Image.fromarray(cv_img)
plt.imshow(pil_img)

torch

numpy重点在cpu,而torch重点在gpu
import torch

tensor的创建和定义

tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)

a = torch.tensor([1, 2, 3], dtype=torch.float32)
print(a, a.shape, a.dtype)
# tensor([1., 2., 3.]) torch.Size([3]) torch.float32

a = torch.zeros((3, 3))[.long(), float(), ...]

a = torch.tensor([[2, 3, 2], [2, 2, 1]]).float()

a = torch.rand((3, 3))

a = torch.ones((3, 3))

a = torch.eye(3, 3) #单位阵

k = a[None, :, :, None]

u = k.squeeze() # 去掉为1的维度
l = u.unsqueeze(2) # 增加维度

Torch的自动梯度计算

a = torch.tensor(10.,requires_grad=True)
b = torch.tensor(5.,requires_grad=True)
c = a * (b * 1.5)
c.backward()
print(a.grad, b.grad)

Torch的模块

import torch.nn as nn
class Model(nn.Module):
	def __init__(self):
		super(Model, self).__init__()
		self.conv1 = nn.Conv2d(3, 64, 3)
		self.relu1 = nn.ReLU(True)
	def forward(self, x):
		self.conv1(x)
		self.relu1(x)
		return x
model = Model()
dir(model)
['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_version', '_wrapped_call_impl', 'add_module', 'apply', 'bfloat16', 'buffers', 'call_super_init', 'children', 'compile', 'conv1', 'cpu', 'cuda', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'ipu', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_module', 'register_parameter', 'register_state_dict_pre_hook', 'relu1', 'requires_grad_', 'set_extra_state', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']

for name, layer in model._modules.items():
	print(name, layer)
model.state_dict() # 模型中的参数状态
conv1_weight = model.state_dict()['conv1.weight']
print(conv1_weight.shape) # torch.Size([64, 3, 3, 3])

torch的训练流程

  1. 定义数据集Dataset
import torchvision.transforms.functional as T
class MyDataset:
	def __init__(self, directory):
		self.directory = directory
		self.files = load_files(directory)
	def __len__(self):
		return len(self.files)
	def __getitem__(self, index):
		return T.to_tensor(image), label
dataset = MyDataset()
dataset_length = len(dataset)
dataset_item = dataset[10]
  1. 定义模型结构
import torch.nn as nn
class ResNet18(nn.Module):
	def __init__(self, num_classes):
		super(ResNet18, self).__init__()

		# 定义每个层的信息
		self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
		self.bn1 = norm_layer(self.inplanes)
		# ...省略
		self.fc = nn.Linear(512 * block.expansion, num_classes)
	def forward():
		x = self.conv1(x)
		x = self.bn1(x)
		...
		x = x.reshape(x.size(0), -1)
		x = self.fc(x)
		return x
	def __call__(self, x):
		return forward(x) #model(image)类似这样执行
  1. 定义Dataset实例和DataLoader实例
train_dataset = MyDataset("./train")
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True,num_workers=24)

# DataLoader实现了多进程数据并行加载
# window下num_workers=0最好给0,支持不好 linux正常
  1. 定义模型实例
model = ResNet18(1000)
model.cuda() # 转到cuda
  1. 定义loss函数
loss_function = nn.CrossEntropyLoss() #交叉熵损失
---------
predict = torch.tensor([0.1, 0.1, 0.5]).requires_grad_(True)
ground_thruth_one_hot = torch.tensor([0, 0, 1])
loss2 = -torch.sum(torch.log(torch.softmax(predict, 1)[ground_thruth_one_hot > 0]))
loss2.backward()
predict.grad
  1. 定义优化器optimizer
op = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-5)
  1. 循环执行优化过程
lr_schedule = {
	10: 1e-4
	20: 1e-5
}
total_epoch = 30
for epoch in range(total_epoch):
	if epoch in lr_schedule:
		new_lr = lr_schedule[epoch]
		for param_group in op.param_groups:
			param_group['lr'] = new_lr
	for index_batch, (images, labels) in enumerate(train_dataloader):
		# images.shape 64, 3, 300, 300
		# labels.shape 64, 1

		# 数据转到cuda
		images = images.cuda()
		labels = labels.cuda()
	
		output = model(images)
		loss = loss_function(output,labels)

		# 优化三部曲
		# 1、梯度清空
		op.zero_grad()

		# 2、loss函数反向传播
		loss.backward()

		# 3、应用梯度并进行更新
		op.step()

		print('epoch: {}, loss:{}'.format(epoch, loss.item()))
  1. 模型的测试
model = ResNet18(1000)
model.cuda()
model.eval() # 进入评估模式

trans = torchvision.transforms.Compose([
	torchvision.transforms.Resize(256),
	torchvision.transforms.ToTensor(),
	torchvision.Normalize([0.4850.4560.406][0.2290.2240.225])
])

# 禁止模型执行过程中为计算梯度而使用更多显存存储中间过程
with torch.no_grad():
	for file in files:
		image = cv2.imread(file)
		image = trans(image) # 归一化
		outputs = model(image)
		predict_label = outputs.argmax(dim=1).cpu()
  1. 模型的保存和加载
torch.save(model.state_dict(), 'my_model.pth')

checkpoint = torch.load('my_model.pth')
model.load_state_dict(checkpoint)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1690389.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux--软硬链接

目录 0.文件系统 1.软硬链接 1.1见一下软硬链接 1.2软硬链接的特征 1.3软硬链接是什么,有什么作用(场景) 0.文件系统 Linux--文件系统-CSDN博客 1.软硬链接 1.1见一下软硬链接 1.这是软链接 这个命令在Unix和Linux系统中用于创建一个符号…

什么是DNS缓存投毒攻击,有什么防护措施

随着企业组织数字化步伐的加快,域名系统(DNS)作为互联网基础设施的关键组成部分,其安全性愈发受到重视。然而,近年来频繁发生的针对DNS的攻击事件,已经成为企业组织数字化发展中的一个严重问题。而在目前各…

vue学习3:开发者调试工具的下载安装

极简插件官网_Chrome插件下载_Chrome浏览器应用商店 (zzzmh.cn) 测试运行程序 网页中右键检查

【Unity3D美术】URP渲染管线学习01

扫盲简介 URP渲染管线是Unity3d提供的一种视觉效果更好的渲染模式,类似的还有Built RP(默认最普通的渲染模式)\ HDRP(超高清,对设备要求高),视觉效果好,而且占用资源少!成为主流渲染管线模式&a…

十进制同步计数器

十进制同步计数器 使用最多的十进制计数器是按照 8421 BCD 码进行计数的电路 十进制同步加法计数器 【例1】设计一个十进制同步加法计数器,要求电路按 8421 BCD 码进行加法计数 Step1:建立原始状态转换图 Step2:选触发器,求方…

粉丝问,有没有UI的统计页面,安排!

移动应用的数据统计页面具有以下几个重要作用: 监控业务指标:数据统计页面可以帮助用户监控关键业务指标和数据,例如用户活跃度、销售额、转化率等。通过实时更新和可视化呈现数据,用户可以及时了解业务的整体状况和趋势。分析用…

LeetCode 128 最长连续序列(hot100) 解题思路分享

题干: 思路: 如果对时间复杂度没有要求的话,可以先排序,再一段一段地找,这样的好处是空间占用小。 如果希望On的话,那就采取设置一个Set的方法,这样空间复杂度是On,但是时间复杂度…

CPP Con 2020:Type Traits I

先谈谈Meta Programming 啥是元编程呢?很简单,就是那些将其他程序当作数据来进行处理和传递的编程(私人感觉有点类似于函数式?)这个其他程序可以是自己也可以是其他程序。元编程可以发生在编译时也可以发生在运行时。…

Python实现将LabelMe生成的JSON格式转换成YOLOv8支持的TXT格式

标注工具 LabelMe 生成的标注文件为JSON格式,而YOLOv8中支持的为TXT文件格式。以下Python代码实现3个功能: 1.将JSON格式转换成TXT格式; 2.将数据集进行随机拆分,生成YOLOv8支持的目录结构; 3.生成YOLOv8支持的YAML文件…

探索亚马逊云科技技术课程:大模型平台与提示工程的应用与优化

上方图片源自亚马逊云科技【生成式 AI 精英速成计划】技术开发技能课程 前言 学习了亚马逊云科技–技术开发技能课程 本课程分为三个部分,了解如何使用大模型平台、如何训练与部署大模型及生成式AI产品应用与开发,了解各类服务的优势、功能、典型使用案…

【QT八股文】系列之篇章2 | QT的信号与槽机制及通讯流程

【QT八股文】系列之篇章2 | QT的信号与槽机制及通讯流程 前言2. 信号与槽信号与槽机制介绍/本质/原理,什么是Qt信号与槽机制?如何在Qt中使用?信号与槽机制原理,解析流程Qt信号槽的调用流程信号与槽机制的优缺点信号与槽机制需要注…

【软考中级 软件设计师】数据结构

数据结构是计算机科学中一个基础且重要的概念,它研究数据的存储结构以及在此结构上执行的各种操作。在准备软考中级-软件设计师考试时,掌握好数据结构部分对于通过考试至关重要。下面是一些核心知识点概览: 基本概念: 数据结构定义…

在NVIDIA Jetson Nano上部署YOLOv5算法,并使用TensorRT和DeepStream进行加速

部署YOLOv5算法在NVIDIA Jetson Nano上并使用TensorRT和DeepStream进行加速涉及几个关键步骤。下面是一个详细的指南: 步骤 1: 准备YOLOv5模型 训练或下载预训练模型:首先,你需要有一个YOLOv5模型。你可以自己训练一个模型,或者…

响应式处理-一篇打尽

纯pc端响应式 pc端平常用到的响应式布局 大致就如下三种,当然也会有其他方法,欢迎评论区补充 将div height、width设置成100% flex布局 flex布局主要是将flex-wrap: wrap, 最后,你可以通过给子元素设置 flex 属性来控制它们的…

构建全面的无障碍学习环境:科技之光,照亮学习之旅

在信息与科技日益发展的当下,为所有人群提供一个包容和平等的学习环境显得尤为重要,特别是对于盲人朋友而言,无障碍学习环境的构建成为了一项亟待关注与深化的课题。一款名为“蝙蝠避障”的辅助软件,以其创新的设计理念与实用功能…

Excel 按顺序去重再编号

Excel的A有重复数据: A1Cow2Chicken3Horse4Butterfly5Cow 现在要去除重复,用自然数按顺序进行编号,结果写在相邻列: AB1Cow12Chicken23Horse34Butterfly45Cow1 使用 SPL XLL,输入公式并向下拖: spl(&q…

云平台的安全能力提升解决方案

提升云平台的安全能力是确保数据和服务安全的关键步骤。针对大型云平台所面临的云上安全建设问题,安全狗提供完整的一站式云安全解决方案,充分匹配云平台安全管理方的需求和云租户的安全需求。协助大型云平台建设全网安全态势感知、统一风险管理、统一资…

Zabbix-agents (windows环境)安装及配置

目录 一. 简介 Zabbix 服务端 1. Zabbix 服务器(Server) 2. Zabbix 数据库 3. Zabbix Web 前端 Zabbix 客户端 1. Zabbix 代理(Agent) 2. 安装和配置 二. 下载 三. 安装 四. 检查是否启动 五. 手动启动方式 六 .创建…

Python面向对象数据库之ZODB使用详解

概要 ZODB(Zope Object Database)是一个纯Python的面向对象数据库。它允许程序员将Python对象以透明的方式存储在数据库中,无需将对象模型转换为关系模型,极大地简化了Python应用的数据持久化工作。 安装 安装ZODB非常简单,可以通过Python的包管理器pip进行安装: pip …

leecode热题100---994:腐烂的橘子

题目: 在给定的 m x n 网格 grid 中,每个单元格可以有以下三个值之一: 值 0 代表空单元格; 值 1 代表新鲜橘子; 值 2 代表腐烂的橘子。 每分钟,腐烂的橘子 周围 4 个方向上相邻 的新鲜橘子都会腐烂。 返回…