PyTorch+AlexNet代码实训

news2024/9/24 15:22:09

参考文章:https://blog.csdn.net/red_stone1/article/details/122974771

数据集:
在这里插入图片描述

打标签:

import os
 
# os.path.join: 每个参数都是一个路径段,将它们连接起来形成有效的路径名。
train_txt_path = os.path.join("data", "catVSdog", "train.txt")
train_dir = os.path.join("data", "catVSdog", "train_data")
valid_txt_path = os.path.join("data", "catVSdog", "test.txt")
valid_dir = os.path.join("data", "catVSdog", "test_data")
 
def gen_txt(txt_path, img_dir): # 标签,图像
    f = open(txt_path, 'w')  # 打开一个文件,创建一个file对象

    # os.walk: 遍历一个目录树,返回目录中的每个目录和文件
    # os.walk每次迭代都会返回一个元组:
    #(当前目录的路径字符串,当前目录中所有子目录名称,当前目录所有文件名称)
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        # topdown用于决定遍历目录树的顺序
        # 以猫狗大战数据集为例,这里的s_dirs是cat和dog文件夹
        for sub_dir in s_dirs: # 对于猫或狗文件夹里的每个文件(每张图片)遍历
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径 ?
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径 ? 应该是jpg
            # os.listdir: 用于返回指定目录中的所有文件和目录的名称列表
            for i in range(len(img_list)): # 遍历一个类别中的所有图片
                if not img_list[i].endswith('jpg'):         # 若不是png文件,跳过 ? 应该是jpg
                    continue
                #label = (img_list[i].split('.')[0] == 'cat')? 0 : 1 
                label = img_list[i].split('.')[0] # 按.分割,并取点后的**第一个部分**
                # 将字符类别转为整型类型表示
                if label == 'cat':
                    label = '0'
                else: # label == 'dog'
                    label = '1'
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line) # 把打好的标签写在.txt里
    f.close()
 
if __name__ == '__main__':
    # 共生成两个图片索引文件:train.txt和test.txt
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)

构建数据集:

from PIL import Image
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, txt_path, transform = None, target_transform = None):
        fh = open(txt_path, 'r') # 打开图片索引文件
        imgs = [] # 存储元组:(图片路径,类别(0或1))
        for line in fh: # 迭代读取文件的行
            line = line.rstrip() # 使用rstrip方法去除行末的空白符(包括\n)
            words = line.split() # 将字符串按空白符(空格、制表符等)进行分割
            imgs.append((words[0], int(words[1]))) # 类别转为整型int
            self.imgs = imgs 
            # self.transform和self.target_transform:根据读入的参数赋值
            self.transform = transform
            self.target_transform = target_transform
    # __getitem__方法和__len__方法均继承自父类Dataset
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB') 
        #img = Image.open(fn)
        if self.transform is not None:
            img = self.transform(img) # self.transform对图片进行处理,推测传入的是一个函数名
        return img, label
    def __len__(self):
        return len(self.imgs)

加载数据集&数据预处理:

from torchvision import transforms
# transforms.Compose接受一个列表或元组作为参数,列表中的每个元素都是一个数据转换操作
# transforms.Compose返回一个串行操作序列
pipline_train = transforms.Compose([
    #随机旋转图片
    transforms.RandomHorizontalFlip(),
    #将图片尺寸resize到227x227(这是AlexNet的要求)
    transforms.Resize((227,227)),
    #将图片转化为Tensor格式
    transforms.ToTensor(),
    #正则化(当模型出现过拟合的情况时,用来降低模型的复杂度,加快模型收敛速度)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 均值为0.5,标准差为0.5
    #transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
pipline_test = transforms.Compose([
    #将图片尺寸resize到227x227
    transforms.Resize((227,227)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
train_data = MyDataset('./data/catVSdog/train.txt', transform=pipline_train)
test_data = MyDataset('./data/catVSdog/test.txt', transform=pipline_test)
 
# train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
# batch_size: 每个小批量(batch)包含的样本数量。
# 在训练过程中,模型不会一次性处理整个数据集,而是分成多个小批量逐一输入模型进行训练
# shuffle: 数据洗牌-随机打乱数据集的顺序-使模型在训练时不会对数据顺序敏感
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=32, shuffle=False)
# 类别信息也是需要我们给定的
classes = ('cat', 'dog') # 对应label=0,label=1

查看最终制作的数据集(图片&标签):

import numpy as np
examples = enumerate(trainloader) # 方便迭代trainloader中的每个批量数据并同时获取它们的索引
batch_idx, (example_data, example_label) = next(examples) # next: 获取枚举对象的下一个元素
# 批量展示图片
for i in range(4):
    plt.subplot(1, 4, i + 1) #
    plt.tight_layout()  #自动调整子图参数,使之填充整个图像区域
    img = example_data[i]
    img = img.numpy() # FloatTensor转为ndarray
    img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
    img = img * [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5]
    plt.imshow(img)
    plt.title("label:{}".format(example_label[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

搭建AlexNet神经网络结构:

class AlexNet(nn.Module):
    """
    Neural network model consisting of layers propsed by AlexNet paper.
    """
    def __init__(self, num_classes=2):
        """
        Define and allocate layers for this neural net.
        Args:
            num_classes (int): number of classes to predict with this model
        """
        super().__init__() # 继承父类的__init__方法
        # input size should be : (b x 3 x 227 x 227)
        # The image in the original paper states that width and height are 224 pixels, but
        # the dimensions after first convolution layer do not lead to 55 x 55.
        
        # nn.Sequential是一个用于构建神经网络的容器,它按顺序将各个模块(层)组合在一起,形成一个神经网络模型
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),  # (b x 96 x 55 x 55) 
            # nn.Conv2d还可以继续添加参数:padding 表示边缘填充空白像素的宽度
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),  # 局部响应归一化
            # 目前更多采用的是Batch Normalization
            nn.MaxPool2d(kernel_size=3, stride=2),  # (b x 96 x 27 x 27)
            nn.Conv2d(96, 256, 5, padding=2),  # (b x 256 x 27 x 27)
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (b x 256 x 13 x 13)
            nn.Conv2d(256, 384, 3, padding=1),  # (b x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),  # (b x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),  # (b x 256 x 13 x 13)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (b x 256 x 6 x 6)
        )
        # classifier is just a name for linear layers
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True), # 
            nn.Linear(in_features=(256 * 6 * 6), out_features=500),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=500, out_features=20),
            nn.ReLU(),
            nn.Linear(in_features=20, out_features=num_classes),
        )
 
    def forward(self, x):
        """
        Pass the input through the net.
        Args:
            x (Tensor): input tensor
        Returns:
            output (Tensor): output tensor
        """
        x = self.net(x)
        x = x.view(-1, 256 * 6 * 6)  # reduce the dimensions for linear layer input
        # x.view: 改变张量形状。
        # -1表示自动计算,后面的256*6*6表示将第二个维度变成这个尺寸,以便作为全连接层的输入。
        return self.classifier(x)

将模型部署到GPU/CPU:

#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet().to(device) # 这里的AlexNet是类名,通过.to(device)方法将模型移动到指定的设备
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001) # PyTorch提供的Adam优化器
# model.parameters()返回模型中所有需要训练的参数迭代器
# lr是Adam优化器的学习率,控制每次参数更新的步长大小

定义训练过程:

def train_runner(model, device, trainloader, optimizer, epoch):
    #训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为True
    model.train()
    total = 0
    correct =0.0
 
    #enumerate迭代已加载的数据集,同时获取数据和数据下标
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        #把模型部署到device上
        inputs, labels = inputs.to(device), labels.to(device)
        #初始化梯度
        optimizer.zero_grad()
        #保存训练结果
        outputs = model(inputs)
        #计算损失和
        #多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmod
        loss = F.cross_entropy(outputs, labels)
        #获取最大概率的预测结果
        #dim=1表示返回每一行的最大值对应的列下标
        predict = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()
        #反向传播
        loss.backward()
        #更新参数
        optimizer.step()
        if i % 100 == 0:
            #loss.item()表示当前loss的数值
            print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))
            Loss.append(loss.item())
            Accuracy.append(correct/total)
    return loss.item(), correct/total

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1950090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Bazaar v1.4.3 任意文件读取漏洞复现(CVE-2024-40348)

0x01 产品简介 Bazarr是Sonarr和Radarr的配套应用程序,可根据您的要求管理和下载字幕。 0x02 漏洞概述 Bazarr存在任意文件读取漏洞,该漏洞是由于Bazaar v1.4.3的组件/api/swaggerui/static中存在一个问题,允许未经身份验证的攻击者可利用…

硅纪元AI应用推荐 | 豆包整容成了浏览器,让你的电脑秒变AI PC

“硅纪元AI应用推荐”栏目,为您精选最新、最实用的人工智能应用,无论您是AI发烧友还是新手,都能在这里找到提升生活和工作的利器。与我们一起探索AI的无限可能,开启智慧新时代! 亲爱的技术宅们、办公高手们&#xff0c…

Tomcat项目本地部署

今天来分享一下如何于本机上在不适用idea等辅助工具的前提下,部署多个tomcat的web项目 我这里以我最近写的SSM项目哈米音乐为例,简单介绍一下项目的大致组成: 首先,项目分为4个模块,如下图所示: 其中&…

SQL 语句中的字符串有单引号导致报错的解决

1.问题 SQL 语句执行对象中,本内容的字符串内含有单引号导致查询或插入数据库报错, 例如 str 关键字 AND 附近有语法错误 2.解决 字符串中的 ’ → 替换 ”,则查询语句成功,故程式中要备注替换 单引号。

无法解析插件 org.apache.maven.plugins:maven-war-plugin:3.2.3(已解决)

文章目录 1、问题出现的背景2、解决方法 1、问题出现的背景 最开始我想把springboot项目转为javaweb项目,然后我点击下面这个插件 就转为javaweb项目了,但是我后悔了,想要还原成springboot项目,点开项目结构关于web的都移除了&am…

运放-增益带宽积-datasheet参数

在运放开环增益频率曲线中,在一定频率范围内,运放的开环增益与对应的频率乘积为常数:增益带宽积(Gain Bandwidth Product, GBP 或者 GBW),即开环增益*频率增益带宽积。 这里有一个误区&#xf…

CompletableFuture异步线程不执行,卡死问题

1、生产上突然发现大量业务数据没执行,通过日志分析有段代码没执行。 2、分析原因可能是异步线程没执行导致,直接上代码场景 3、异步线程调用远程外部接口 超时或多次异常,导致服务无法再开启异步线程,同时代码中其他用到异步线程…

人人可学的AI与高科技普及视频课,零基础,通俗易懂,深入浅出

课程内容: 1 第0课:开课词,欢迎词 ev.mp4 2 第1课:我们为什么要学习Al ev.mp4 3 第2课:AI算法模型的基本概念MOVev,mp4 4 第3课:什么是生成性Al ev,mp4 5 第4课:人工智能的三驾马车 ev.mp4 6 加餐附加课1-谷歌双子座Gemini ev,mp4 7 第5课:关于Al…

为什么idea建议使用“+”拼接字符串

今天在敲代码的时候,无意间看到这样一个提示: 英文不太好,先问问ChatGPT,这个啥意思? IDEA 提示你,可以将代码中的 StringBuilder 替换为简单的字符串连接方式。 提示信息中说明了使用 StringBuilder 进行…

分布式相关理论详解

目录 1.绪论 2.什么是分布式系统,和集群的区别 3.CAP理论 3.1 什么是CAP理论 3.2 一致性 3.2.1 计算机的一致性说明 1.事务中的一致性 2.并发场景下的一致性 3.分布式场景下的一致性 3.2.2 一致性分类 3.2.3 强一致性 1.线性一致性 a) 定义 a) Raft算法…

数据危机!4大硬盘数据恢复工具,教你如何正确挽回珍贵记忆!

在这个数字化的时代,硬盘里的数据对我们来说简直太重要了。但糟糕的是,数据丢失这种事时不时就会发生,可能是因为不小心删了,硬盘坏了,或者中了病毒。遇到这种情况,很多人可能就慌了,不知道怎么…

鸿蒙(HarmonyOS)下拉选择控件

一、操作环境 操作系统: Windows 11 专业版、IDE:DevEco Studio 3.1.1 Release、SDK:HarmonyOS 3.1.0(API 9) 二、效果图 三、代码 SelectPVComponent.ets Component export default struct SelectPVComponent {Link selection: SelectOption[]priva…

模拟信号介绍

定义: 模拟信号是指用连续变化的物理量表示的信息,其信号的幅度、频率或相位随时间作连续变化,或在一段连续的时间间隔内,其代表信息的特征量可以在任意瞬间呈现为任意数值的信号。我们通常又把模拟信号称为连续信号,它…

挑战房市预测领头羊:KNN vs. 决策树 vs. 线性回归

挑战房市预测领头羊(KNN,决策树,线性回归) 1. 介绍1.1 K最近邻(KNN):与邻居的友谊1.1.1 KNN的基础1.1.2 KNN的运作机制1.1.3 KNN的优缺点 1.2 决策树:解码房价的逻辑树1.2.1 决策树的…

AttributeError: ‘list‘ object has no attribute ‘text‘

AttributeError: ‘list‘ object has no attribute ‘text‘ 目录 AttributeError: ‘list‘ object has no attribute ‘text‘ 【常见模块错误】 【解决方案】 示例代码 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英…

前端三大主流框架Vue React Angular有何不同?

前端主流框架,Vue React Angular,大家可能都经常在使用,Vue React,国内用的较多,Angualr相对用的少一点。但是大家有思考过这三大框架的不同吗? 一、项目的选型上 中小型项目:Vue2、React居多…

人工智能AI合集:Ollama部署对话语言大模型-网页访问

目录 🍅点击这里查看所有博文 随着人工智能技术的飞速发展,AI已经不再是遥不可及的高科技概念,而是逐渐融入到我们的日常生活中。从智能手机的语音助手到家庭中的智能音箱,再到工业自动化和医疗诊断,AI的应用无处不在…

gitee设置ssh公钥密码避免频繁密码验证

gitee中可以创建私有项目,但是在clone或者push都需要输入密码, 比较繁琐。 公钥则可以解决该问题,将私钥放在本地,公钥放在gitee上,当对项目进行操作时带有的私钥会在gitee和公钥进行验证,避免了手动输入密…

港科夜闻 | 香港科大与阿里巴巴合作,计划成立大数据与人工智能联合实验室

关注并星标 每周阅读港科夜闻 建立新视野 开启新思维 1、香港科大与阿里巴巴合作,计划成立大数据与人工智能联合实验室。香港科大7月19日与阿里巴巴集团签署合作备忘录,计划成立「香港科技大学–阿里巴巴大数据与人工智能联合实验室」,就生成…

STM32-寄存器DMA配置指南

配置步骤 在STM32F0xx中文参考手册中的DMA部分在开头给出了配置步骤 每个通道都可以在外设寄存器固定地址和存储器地址之间执行 DMA 传输。DMA 传输的数据 量是可编程的,最大达到 65535。每次传输之后相应的计数寄存器都做一次递减操作,直到 计数为&am…