基于pytoch卷积神经网络水质图像分类实战

news2024/10/8 22:12:58

具体怎么学习pytorch,看b站刘二大人的视频。

完整代码:

import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
'''https://zhuanlan.zhihu.com/p/156926543'''
# 定义图片目录
image_dir = 'images'

# 初始化图片路径列表
img_list = []

# 遍历指定目录及其子目录中的所有文件
for parent, _, filenames in os.walk(image_dir):
    for filename in filenames:
        # 拼接文件的完整路径
        filename_path = os.path.join(parent, filename)
        img_list.append(filename_path)

# 初始化图像张量列表和标签列表
image_tensors = []
y_list = []

for image_path in img_list:
    # 提取标签 (假设标签是文件名的第一个字符)
    label = int(os.path.basename(image_path)[0])
    y_list.append(label)

    # 打开图像
    img = Image.open(image_path)

    # 获取图像尺寸
    width, height = img.size

    # 定义裁剪的区域(假设要保留图像中心的 100x100 区域)
    left = (width - 100) / 2
    top = (height - 100) / 2
    right = (width + 100) / 2
    bottom = (height + 100) / 2

    # 裁剪图像
    img = img.crop((left, top, right, bottom))

    # 将图像转换为 NumPy 数组
    img_array = np.asarray(img)

    # 将 NumPy 数组转换为 PyTorch 张量
    img_tensor = torch.from_numpy(img_array).float()

    # 如果图像是 RGB,将其转换为 (C, H, W) 格式
    if img_tensor.ndimension() == 3 and img_tensor.shape[2] == 3:
        img_tensor = img_tensor.permute(2, 0, 1)  # 从 (H, W, C) 变为 (C, H, W)

    # 增加 batch 维度
    img_tensor = img_tensor.unsqueeze(0)  # 从 (C, H, W) 变为 (1, C, H, W)

    # 规范化到0-1之间
    img_tensor = img_tensor / 255.0

    # 添加到图像张量列表
    image_tensors.append(img_tensor)

    # 打印图像张量的形状
    print(f"当前图像形状: {img_tensor.shape}")

# 将图像张量列表转换为四维张量
x_data = torch.cat(image_tensors, dim=0)
# 遍历 y_list 中的每个元素,并将每个数减去 1
for i in range(len(y_list)):
    y_list[i] -= 1

# 将标签列表转换为张量
y_labels = torch.tensor(y_list).long()  # 注意这里使用 .long() 方法将标签转换为长整型

print(x_data.shape,y_labels.shape)
print(y_labels)

# 定义数据集和数据加载器
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x_data, y_labels):
        self.x_data = x_data
        self.y_labels = y_labels

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        return self.x_data[idx], self.y_labels[idx]


# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 25 * 25, 128)
        self.fc2 = nn.Linear(128, 5)  # 假设有5个类别

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 训练模型
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):  # 假设训练50个epoch
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")

    # 在每个epoch结束后,计算并打印验证集的准确率
    model.eval()  # 将模型设置为评估模式
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
    print(f'Validation Accuracy after Epoch {epoch + 1}: {val_accuracy}')

本数据集中x_data的维度是四维张量(203,3,100,100),y_labels的维度是一维张量 

代码中需要注意的点,卷积模型接受的是四维张量,因此要转变为四维张量。

全连接层中输入的特征数,需要自己计算,通过前面卷积层和池化层后,计算总的维度数。一般是最后的通道数*高度*宽度

定义模型中的forward函数中,在经过全连接层计算前,需要将四维的x转为2维

如果 x 的形状是 (64, 32, 28, 28),表示一个批次大小为64的图像张量,其中每个图像有32个通道,高度和宽度都是28像素。现在,我们希望将这个张量展平为一个二维张量,以便输入到全连接层进行进一步处理。

通过 torch.flatten(x, 1) 操作,我们将在指定维度(这里是第一个维度,也就是通道维度)上对张量进行展平。展平后的张量形状将变为 (64, 32*28*28),其中64是批次大小,而 32*28*28 是展平后的特征数量,即每个图像的特征数量。这与前面定义的全连接层的输入特征数要一致。

Dataloader中batch_size就是设置第一个维度,比如这里的batch_size是32,那么

for inputs, labels in train_loader:

 这里的inputs维度是(32,3,100,100)

新学习pytorch中的分割数据集与测试集方法。

# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

结果展现,可以看见准确率有0.82:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1807933.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot高校运动会信息管理系统设计与实现-计算机毕业设计源码92968

摘 要 本论文介绍了一个高校运动会信息管理系统的设计和实现过程。首先是高校运动会的需求分析和可行性分析,通过比较运动会的各个工作流程,确定了系统的数据流程和数据库结构,然后介绍了高校运动会信息管理系统开发所使用的软件开发工具&…

脖子以下是人机交互,脖子以上是人机融合智能

“脖子以下是人机交互,脖子以上是人机融合智能”这句话是当前人与人工智能配合技术发展的一种形象描述,一个是生理物理,一个是人脑电脑: 1、人机交互的重要性 脖子以下的人机交互在当前的人工智能系统中扮演着重要的角色。人机交互…

SpringBoot整合RabbitMQ (持续更新中)

RabbitMQ 官网地址:RabbitMQ: One broker to queue them all | RabbitMQ RabbitMQ 与 Erlang 版本兼容关系​ 3.13.0 26.0 26.2.x The 3.13 release series is compatible with Erlang 26. OpenSSL 3 support in Erlang is considered to be mature and ready for…

大数据湖一体化运营管理建设方案(49页PPT)

方案介绍: 本大数据湖一体化运营管理建设方案通过构建统一存储、高效处理、智能分析和安全管控的大数据湖平台,实现了企业数据的集中管理、快速处理和智能分析。该方案具有可扩展性、高性能、智能化、安全性和易用性等特点,能够为企业数字化…

基于51单片机俄罗斯方块小游戏

基于51单片机俄罗斯方块游戏 (仿真+程序) 功能介绍 具体功能: 1.用LCD12864显示游戏界面; 2.用四个按键控制游戏(左、右移、下移、翻转); 3.游戏规则和平时玩的俄罗斯方块一样&a…

Qt窗口与对话框

目录 Qt窗口 1.菜单栏 2.工具栏 3.状态栏 4.滑动窗口 QT对话框 1.基础对话框QDiaog 创建新的ui文件 模态对话框与非模态对话框 2.消息对话框 QMessageBox 3.QColorDialog 4.QFileDialog文件对话框 5.QFontDialog 6.QInputDialog Qt窗口 前言:之前以上…

天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。 原文链接:Stable Diffusion 解读(二):论文精读 【小小题外话】端午安康! 在上一篇文章天才程序员周弈帆 …

初阶 《函数》 4. 函数的调用

4. 函数的调用 4.1 传值调用 函数的形参和实参分别占有不同内存块,对形参的修改不会影响实参 4.2 传址调用 传址调用是把函数外部创建变量的内存地址传递给函数参数的一种调用函数的方式 这种传参方式可以让函数和函数外边的变量建立起真正的联系,也就是…

【数据分析基础】实验三 文件操作、数组与矩阵运算

一.实验目的 掌握上下文管理语句with的使用方法。掌握文本文件的操作方法。了解os、os.path模块的使用。掌握扩展库Python-docx、openpyxl的安装与操作word、Excel文件内容的方法。熟练掌握numpy数组相关运算和简单应用。熟练使用numpy创建矩阵,熟悉常用…

matlab演示地月碰撞

代码 function EarthMoonCollisionSimulation()% 初始化参数earth_radius 6371; % 地球半径,单位:公里moon_radius 1737; % 月球半径,单位:公里distance 384400; % 地月距离,单位:公里collision_tim…

【JAVASE】java语法(成员变量与局部变量的区别、赋值运算符中的易错点)

一:成员变量与局部变量的区别 区别 成员变量 局部变量 类中位置不同 …

高通Android开关机动画踩坑简单记录

1、下面报错有可能是selinux的原因 Read-only file system 2、接着push 动画 reboot之后抓取logcat出现 ,以下这个报错。看着像是压缩格式有问题。 3、于是重新压缩一下报错没有再出现 ,压缩格式默认是标准,这里必须要改成存储格式哈 4、修改…

python-数字黑洞

[题目描述] 给定一个三位数,要求各位不能相同。例如,352是符合要求的,112是不符合要求的。将这个三位数的三个数字重新排列,得到的最大的数,减去得到的最小的数,形成一个新的三位数。对这个新的三位数可以重…

解决windows11开机xbox自启动

1、同时按键盘“ctrlaltdelete”键,在弹出页面中选择任务管理器; 2、点击启动应用 3、找到软件Xbox App Services,选择“已启用”点击右键,点击禁用;

新书推荐:2.3 消息机制

Windows程序的消息机制是指在Windows操作系统下,应用程序与操作系统之间的一种通信方式。通过消息机制,应用程序可以接收来自操作系统的各种事件和请求,以便做出相应的响应和处理。 在Windows程序中,消息机制的实现是基于消息队列…

OpenGauss数据库-6.表空间管理

第1关:创建表空间 gsql -d postgres -U gaussdb -W passwd123123 CREATE TABLESPACE fastspace OWNER omm relative location tablespace/tablespace_1; 第2关:修改表空间 gsql -d postgres -U gaussdb -W passwd123123 ALTER TABLESPACE fastspace R…

Gh-ost让MySQL在线表结构变更不再是难题

Gh-ost:无缝迁移,效率与安全并行- 精选真开源,释放新价值。 概览 gh-ost是由GitHub团队精心打造的在线MySQL表结构迁移工具,它以一种无需触发器的方式,实现了对数据库表结构变更的在线操作。gh-ost的设计初衷是解决现…

matlab使用教程(95)—显示地理数据

下面的示例说明了多种表示地球地貌的方法。此示例中的数据取自美国商务部海洋及大气管理局 (NOAA) 国家地理数据中心,数据通告编号为 88-MGG-02。 1.关于地貌数据 数据文件 topo.mat 包含地貌数据。topo 是海拔数据,topomap1 是海拔的颜色图。 load t…

【HarmonyOS】鸿蒙应用模块化实现

【HarmonyOS】鸿蒙应用模块化实现 一、Module的概念 Module是HarmonyOS应用的基本功能单元,包含了源代码、资源文件、第三方库及应用清单文件,每一个Module都可以独立进行编译和运行。一个HarmonyOS应用通常会包含一个或多个Module,因此&am…

【力扣高频题】003.无重复字符的最长子串

前段时间和小米的某面试官聊天。因为我一直在做 算法文章 的更新,就多聊了几句算法方面的知识。 并且在聊天过程中获得了一个“重要情报”:只要他来面试,基本上每次的算法题,都会去考察关于 子串和子序列 的问题。 的确&#xf…