【深度学习】快速制作图像标签数据集以及训练

news2025/1/18 12:01:30

快速制作图像标签数据集以及训练

制作DataSet

  • 先从网络收集十张图片 每种十张
    在这里插入图片描述

  • 定义dataSet和dataloader

import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt


# 通过创建data.Dataset子类Mydataset来创建输入
class Mydataset(data.Dataset):
    # init() 初始化方法,传入数据文件夹路径
    def __init__(self, root):
        self.imgs_path = root

    # getitem() 切片方法,根据索引下标,获得相应的图片
    def __getitem__(self, index):
        img_path = self.imgs_path[index]

    # len() 计算长度方法,返回整个数据文件夹下所有文件的个数
    def __len__(self):
        return len(self.imgs_path)


# 使用glob方法来获取数据图片的所有路径
all_imgs_path = glob.glob(r"./Data/*/*.jpg")  # 数据文件夹路径

# 利用自定义类Mydataset创建对象brake_dataset
# 将所有的路径塞进dataset  使用每张图片的路径进行索引图片
brake_dataset = Mydataset(all_imgs_path)
# print("图片总数:{}".format(len(brake_dataset)))  # 返回文件夹中图片总个数



# 制作dataloader
brake_dataloader = torch.utils.data.DataLoader(brake_dataset, batch_size=2)  # 每次迭代时返回4个数据
# print(next(iter(break_dataloader)))

制作标签


# 为每张图片制作对应标签
species = ['sun', 'rain', 'cloud']
species_to_id = dict((c, i) for i, c in enumerate(species))
# print(species_to_id)

id_to_species = dict((v, k) for k, v in species_to_id.items())
# print(id_to_species)

# 对所有图片路径进行迭代
all_labels = []
for img in all_imgs_path:
	# 区分出每个img,应该属于什么类别
	for i, c in enumerate(species):
		if c in img:
			all_labels.append(i)
# print(all_labels)

制作数据和标签一起的dataset和dataloader

  • 上面的dataset不够完善
# 将数据转换为张量数据
# 对数据进行转换处理
transform = transforms.Compose([
	transforms.Resize((256, 256)),  # 做的第一步转换
	transforms.ToTensor()  # 第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
])


class Mydatasetpro(data.Dataset):
	def __init__(self, img_paths, labels, transform):
		self.imgs = img_paths
		self.labels = labels
		self.transforms = transform

	# 进行切片
	def __getitem__(self, index):
		img = self.imgs[index]
		label = self.labels[index]
		pil_img = Image.open(img)  # pip install pillow
		pil_img = pil_img.convert('RGB')
		data = self.transforms(pil_img)
		return data, label

	# 返回长度
	def __len__(self):
		return len(self.imgs)


BATCH_SIZE = 4
brake_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
brake_dataloader = data.DataLoader(
	brake_dataset,
	batch_size=BATCH_SIZE,
	shuffle=True
)

imgs_batch, labels_batch = next(iter(brake_dataloader))

# 4 X 3 X 256 X 256
print(imgs_batch.shape)

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:10], labels_batch[:10])):
	img = img.permute(1, 2, 0).numpy()
	plt.subplot(2, 3, i + 1)
	plt.title(id_to_species.get(label.item()))
	plt.imshow(img)
plt.show()  # 展示图片


制作训练集和测试集

# 划分数据集和测试集
index = np.random.permutation(len(all_imgs_path))

#  打乱所有图片的索引
print(index)

# 根据索引获取所有图片的路径
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

print("打乱顺序之后的所有图片路径{}".format(all_imgs_path))
print("打乱顺序之后的所有图片索引{}".format(all_labels))

# 80%做训练集
s = int(len(all_imgs_path) * 0.8)
# print(s)

train_imgs = all_imgs_path[:s]
# print(train_imgs)
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]


# 将训练集和标签 制作dataset 需要转换为张量
train_ds = Mydatasetpro(train_imgs, train_labels, transform)  # TrainSet TensorData
test_ds = Mydatasetpro(test_imgs, test_labels, transform)  # TestSet TensorData
# print(train_ds)
# print(test_ds)
print("**********")
# 制作trainLoader
train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TrainSet Labels
test_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TestSet Labels




训练代码

import torch
import torchvision.models as models
from torch import nn
from torch import optim
from DataSetMake import brake_dataloader
from DataSetMake import train_dl, test_dl


# 判断是否使用GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  使用resnet 训练
model_ft = models.resnet50(pretrained=True)  # 使用迁移学习,加载预训练权


in_features = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(in_features, 256),
							nn.ReLU(),
							# nn.Dropout(0, 4),
							nn.Linear(256, 4),
							nn.LogSoftmax(dim=1))

model_ft = model_ft.to(DEVICE)  # 将模型迁移到gpu

# 优化器
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(DEVICE)  # 将loss_fn迁移到GPU

# Adam损失函数
optimizer = optim.Adam(model_ft.fc.parameters(), lr=0.003)


epochs = 50  # 迭代次数
steps = 0
running_loss = 0
print_every = 10
train_losses, test_losses = [], []

for epoch in range(epochs):
	model_ft.train()
	# 遍历训练集数据
	for imgs, labels in brake_dataloader:
		steps += 1

		# 标签转换为 tensor
		labels = torch.tensor(labels, dtype=torch.long)

		# 将图片和标签 放到设备上
		imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

		optimizer.zero_grad()  # 梯度归零

		#  前向推理
		outputs = model_ft(imgs)

		# 计算loss
		loss = loss_fn(outputs, labels)
		loss.backward()  # 反向传播计算梯度
		optimizer.step()  # 梯度优化

		# 累加loss
		running_loss += loss.item()

		if steps % print_every == 0:
			test_loss = 0
			accuracy = 0

			# 验证模式
			model_ft.eval()

			# 测试集 不需要计算梯度
			with torch.no_grad():
				# 遍历测试集数据
				for imgs, labels in test_dl:
					#  转换为tensor
					labels = torch.tensor(labels, dtype=torch.long)

					#  数据标签 部署到gpu
					imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

					#  前向推理
					outputs = model_ft(imgs)

					#  计算损失
					loss = loss_fn(outputs, labels)

					# 累加测试机的损失
					test_loss += loss.item()

					ps = torch.exp(outputs)
					top_p, top_class = ps.topk(1, dim=1)

					equals = top_class == labels.view(*top_class.shape)
					accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

			train_losses.append(running_loss / len(train_dl))
			test_losses.append(test_loss / len(test_dl))

			print(f"Epoch {epoch + 1}/{epochs}.. "
				  f"Train loss: {running_loss / print_every:.3f}.. "
				  f"Test loss: {test_loss / len(test_dl):.3f}.. "
				  f"Test accuracy: {accuracy / len(test_dl):.3f}")

			#  回到训练模式 训练误差清0
			running_loss = 0
			model_ft.train()
torch.save(model_ft, "aerialmodel.pth")




在这里插入图片描述

预测代码

import os
import torch
from PIL import Image
from torch import nn
from torchvision import transforms, models

i = 0  # 识别图片计数
# 这里最好新建一个test_data文件随机放一些上面整理好的图片进去
root_path = r"D:\CODE\ImageClassify\Test"  # 待测试文件夹
names = os.listdir(root_path)

for name in names:
	print(name)
	i = i + 1
	data_class = ['sun', 'rain', 'cloud']  # 按文件索引顺序排列


	#  找出文件夹中的所有图片
	image_path = os.path.join(root_path, name)
	image = Image.open(image_path)
	print(image)

	#  张量定义格式
	transform = transforms.Compose([transforms.Resize((256, 256)),
									transforms.ToTensor()])
	# 图片转换为张量
	image = transform(image)
	print(image.shape)

	#  定义resnet模型
	model_ft = models.resnet50()

	# 模型结构
	in_features = model_ft.fc.in_features
	model_ft.fc = nn.Sequential(nn.Linear(in_features, 256),
								nn.ReLU(),
								# nn.Dropout(0, 4),
								nn.Linear(256, 4),
								nn.LogSoftmax(dim=1))


	# 加载已经训练好的模型参数
	model = torch.load("aerialmodel.pth", map_location=torch.device("cpu"))

	# 将每张图片 调整维度
	image = torch.reshape(image, (1, 3, 256, 256))  # 修改待预测图片尺寸,需要与训练时一致
	model.eval()

	#  速出预测结果
	with torch.no_grad():
		output = model(image)
	print(output)  # 输出预测结果
	# print(int(output.argmax(1)))
	# 对结果进行处理,使直接显示出预测的种类  根据索引判别是哪一类
	print("第{}张图片预测为:{}".format(i, data_class[int(output.argmax(1))]))


工程目录结构

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1166058.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这才是当今生成式人工智能的根本性问题!

原创 | 文 BFT机器人 01 引言 近年来,生成式人工智能产品层出不穷,ChatGPT火爆出圈后,百度、谷歌等科技大佬争相研究生成式人工智能产品,将该技术的普及程度提升到了一个新的水平。然而,生成式人工智能的运营需要高昂…

谷歌浏览器解决跨域问题配置记录

在访问时出现has been blocked by CORS policy: Responspreflight request doesn’t pass access control checlAccess-Control-A1low-Origin" header is present onrequested resource. 出现跨域问题 1.先关闭浏览器 2.创建一个目录,文件夹记住路径 3.点击谷…

高德地图撒点组件

一、引入amap地图库 - public/index.html <script type"text/javascript">window._AMapSecurityConfig {securityJsCode: 地图密钥 }</script><scripttype"text/javascript"src"https://webapi.amap.com/maps?v1.4.8&key111111…

rpm 软件包管理工具

RPM&#xff08;RedHat Package Manager&#xff09;&#xff0c;RedHat软件包管理工具。 rpm 查询 rpm -qa #查询所有包(query all)rpm -qa |grep firefox #firefox-102.15.0-1.el7.centos.x86_64rpm -qi | grep firefox #(query information) #Name : firefox #…

Flink日志采集-ELK可视化实现

一、各组件版本 组件版本Flink1.16.1kafka2.0.0Logstash6.5.4Elasticseach6.3.1Kibana6.3.1 针对按照⽇志⽂件⼤⼩滚动⽣成⽂件的⽅式&#xff0c;可能因为某个错误的问题&#xff0c;需要看好多个⽇志⽂件&#xff0c;还有Flink on Yarn模式提交Flink任务&#xff0c;在任务执…

嵌入式学习的两大误区

误区一、全身投入学习桌面或服务器版本Linux系统很多想学嵌入式Linux 的同学经常问我&#xff0c;我不会Linux系统&#xff0c;怎么学习嵌入式Linux开发&#xff0c;于是他们就花费了大量的精力和时间去研究学习桌面版本Linux系统的使用&#xff0c;什么redhat 、federo&#x…

IDEA启动报端口占用

方法一 netstat -ano | findstr :1099 这将列出正在使用1099端口的进程的相关信息&#xff0c;包括进程ID&#xff08;PID&#xff09;。查找使用1099端口的进程ID&#xff0c;并记下该进程的ID号。输入以下命令并按Enter键执行&#xff0c;其中PID是你在上一步中找到的进程ID…

Openssl生成证书-nginx使用ssl

Openssl生成证书并用nginx使用 安装openssl yum install openssl -y创库目录存放证书 mkdir /etc/nginx/cert cd /etc/nginx/cert配置本地解析 cat >>/etc/hosts << EOF 10.10.10.21 kubernetes-master.com EOF10.10.10.21 主机ip、 kubernetes-master.com 本…

【Unity实战】最全面的库存系统(三)

文章目录 先来看看最终效果前言新增脚本获取唯一ID保存和加载保存地面物品将玩家快捷栏和背包合并快捷栏物品显示完结先来看看最终效果 前言 本期紧跟着上期,继续来完善我们的库存系统,实现物品背包仓库数据的存储和加载功能 新增脚本获取唯一ID 新增脚本,自定义控制只读…

超详细Linux搭建Hadoop集群

一、给计算机集群起别名——互通 总纲&#xff1a; 1、准备3台客户机&#xff08;关闭防火墙、静态IP、主机名称都设置好&#xff09; 2、安装JDK&#xff08;可点击&#xff09; 3、配置环境变量 4、安装Hadoop 5、配置hadoop的环境变量 6、配置集群 7、群起测试 1.1、环境准备…

素材搜罗利器!产品设计必须知道的13款最佳网站!

灵感素材类 1.即时设计 在网页中搜索“即时设计”&#xff0c;进入官网后登录账号&#xff0c;之后进入「资源广场」版块便能看到即时设计提供的上万条设计素材。在搜索框内根据需要进行搜索&#xff0c;比如输入“网页设计”&#xff0c;便会看到即时设计提供的网页设计素材…

代码训练营第59天:动态规划part17|leetcode647回文子串|leetcode516最长回文子序列

leetcode647&#xff1a;回文子串 文章讲解&#xff1a;leetcode647 leetcode516&#xff1a;最长回文子序列 文章讲解&#xff1a;leetcode516 DP总结&#xff1a;动态规划总结 目录 1&#xff0c;leeetcode647 回文子串。 2&#xff0c;leetcode516 最长回文子串&#xff1…

实验室装修公司的线上推广成功案例_上海添力网络科技

2018年7月&#xff0c;也是我的书《快速见效的企业网络营销方法 B2B 大宗B2C》出版后两个月&#xff0c;某装修公司的市场部总监在阅读完这本书后&#xff0c;找到了我&#xff0c;希望能帮到他们公司提升线上获客能力。 当时他们已经成立了线上推广团队&#xff0c;配置了SEM岗…

echarts中 对seriesLayoutBy的理解

https://echarts.apache.org/handbook/zh/concepts/dataset/ ‘row’: 系列被安放到 dataset 的行上面。 这里x轴是目录轴&#xff0c;那么一列就是一个系列 ‘column’: 默认值。系列被安放到 dataset 的列上面。 用自己的话总结就是&#xff1a; 当 seriesLayoutBy 为行时&…

xxx cannot be resolved to a variable之解决方法

错误原因&#xff1a; 大致意识是&#xff1a;无法解析为变量 可能是没有声明、变量名识别不了、要么拼写错误&#xff0c;如果是int类型要考虑是否赋初值 解决方法&#xff1a; 声明为Stirng类型即可

JUL 日志

JUL日志级别 日志分为7个级别&#xff0c;详细信息我们可以在Level类中查看&#xff1a; SEVERE&#xff08;最高值&#xff09;- 一般用于代表严重错误WARNING - 一般用于表示某些警告&#xff0c;但是不足以判断为错误INFO &#xff08;默认级别&#xff09; - 常规消息CON…

STM32F4X SDIO(六) 例程讲解-SD_PowerON

STM32F4X SDIO&#xff08;六&#xff09; 例程讲解-SD_PowerON 例程讲解-SD_PowerONSDIO引脚初始化和时钟初始化SDIO初始化(单线模式)CMD0:GO_IDLE_STATE命令发送程序命令响应程序 CMD8:SEND_IF_CONDCMD8参数命令发送程序命令响应程序 CMD55:APP_CMDCMD55命令参数命令发送命令…

XCTF-RSA-2:baigeiRSA2、 cr4-poor-rsa

baigeiRSA2 题目描述 import libnum from Crypto.Util import number from functools import reduce from secret import flagn 5 size 64 while True:ps [number.getPrime(size) for _ in range(n)]if len(set(ps)) n:breake 65537 n reduce(lambda x, y: x*y, ps) m …

docker网络管理-网络模式

2.4 网络管理 需要安装sudo apt install bridge-utils Docker 网络很重要&#xff0c;重要的&#xff0c;我们在上面学到的所有东西都依赖于网络才能工作。我们从两个方面来学习网络&#xff1a; 端口映射和网络模式 为什么先学端口映射呢&#xff1f; 在一台主机上学习网络&…

strongswan:configure: error: OpenSSL Crypto library not found

引子 在配置strongswan时&#xff0c;有时会遇到以下错误&#xff08;其实所有需要openssl的软件configure时都有可能遇到该问题&#xff09;&#xff1a; configure: error: OpenSSL Crypto library not found 解决方法 crypto是什么呢? 是OpenSSL 加密库(lib), 这个库需要op…