第十二章 迁移学习-实战宝可梦精灵

news2024/11/28 9:35:06

文章目录

  • 一、Pokemon数据集
    • 1.1 数据集收集
    • 1.2 数据集划分
    • 1.3 数据集加载
    • 1.4 数据预处理
    • 1.5 pytorch自定义数据库实现
  • 二、ResNet网络搭建
  • 三、训练与测试
  • 四、迁移学习
    • 4.1 pytorch实现迁移学习

一、Pokemon数据集

1.1 数据集收集

在这里插入图片描述

# git下载
git lfs install
git clone https://www.modelscope.cn/datasets/ModelBulider/pokemon.git

1.2 数据集划分

在这里插入图片描述


1.3 数据集加载

在这里插入图片描述

  • 加载数据
    ① 继承 torch.utils.data.Dataset
    ② 实现 __len__ 函数,其返回数据集的数量(整型数字)
    ③ 实现 __getitem__函数,根据索引值返回一个数据
    在这里插入图片描述

举例:
在这里插入图片描述


1.4 数据预处理

将尺寸大小不一致的数据(图片)预处理为大小一致的1数据
② 数据增强(旋转、裁剪等)
③ 归一化(均值、方差)
④ 转换为 Tensor 数据类型
在这里插入图片描述


1.5 pytorch自定义数据库实现

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: code - pokemon.py
@author: yonghao
@Description: 
@since 2021/03/01 19:41
'''
from visdom import Visdom
import time
import torch
import os, glob
import random, csv
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

root = 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon'


class Pokemon(Dataset):
def __init__(self, root, resize, mode='train'):
'''
初始化数据集
:paramroot: 图片存储的位置
:paramresize: 重新编辑图片的尺寸
:parammode: 初始化图片的类型(可以是数据集中各中分类)
'''
super(Pokemon, self).__init__()
self.root = root
 self.resize = resize
 self.mode = mode
 self.name2label = {}
# 创建 类名-> label 的映射字典
# os.listdir()每次顺序都不一样,故使用sorted()排序,使 类名-> label 的映射字典固定
for name in sorted(os.listdir(os.path.join(root))):
# 只读取文件夹名
if not os.path.isdir(os.path.join(root, name)):
continue
self.name2label[name] = len(self.name2label)
self.images, self.labels = self.load_csv('images.csv')
# 根据mode设定数据集的比例
if mode == 'train': # 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20%
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # 20%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]

def __len__(self):
return len(self.images)

def __getitem__(self, item) -> tuple:
# item ~ [0,len(images)-1]
# self.images , self.labels
# image , label
img, label = self.images[item], self.labels[item]
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path => image data
transforms.Resize((int(1.25 * self.resize), int(1.25 * self.resize))), # 调整尺寸
transforms.RandomRotation(15), # 旋转
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
# 注意transforms.Normalize() 应该在transforms.ToTensor() 后面
# 数据在通道层上归一化,会使变化图片的像素
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 返回由img,label 组成的Tensor 元组
img = tf(img)
label = torch.tensor(label)

return img, label

 def denormalize(self, x_het):
'''
图像逆正则化显示
:paramx_het: 正则化后的数据
:return:
'''

# x_het = (x - mean) / std
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
# x = x_het * std + mean
# x:[channel , h , w] , mean:[3] -> [3,1,1] , std:[3] -> [3,1,1]
mean = mean.unsqueeze(dim=-1).unsqueeze(dim=-1)

std = std.unsqueeze(dim=-1).unsqueeze(dim=-1)

x = x_het * std + mean

 return x

 def load_csv(self, filename):
'''
加载图片数据 与 其label数据
:paramfilename: 加载数据的文件名
:return:
'''
# 仅在第一次调用时创建csv文件,保存 图片路径——>label 的映射关系
if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.name2label.keys():
'''
python在模块glob中定义了glob()函数,实现了对目录内容进行匹配的功能,
glob.glob()函数接受通配模式作为输入,并返回所有匹配的文件名和路径名列表
与os.listdir类似
'''
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 1167 , 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon\\bulbasaur\\00000000.png'

# 打乱的是图片的存储路径
random.shuffle(images)

# 使用上下文管理,对文件进行操作
'''
with是从Python2.5引入的一个新的语法,它是一种上下文管理协议,目的在于从流程图中把try,except 和finally 关键字和

资源分配释放相关代码统统去掉,简化try….except….finlally的处理流程。

with通过__enter__方法初始化,然后在__exit__中做善后以及处理异常。

所以使用with处理的对象必须有__enter__()和__exit__()这两个方法。

其中__enter__()方法在语句体(with语句包裹起来的代码块)执行之前进入运行,__exit__()方法在语句体执行完毕退出后运行。

with 语句适用于对资源进行访问的场合,确保不管使用过程中是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。

紧跟with后面的语句会被求值,返回对象的__enter__()方法被调用,这个方法的返回值将被赋值给as关键字后面的变量,当with后面的代码块全部被执行完之后,将调用前面返回对象的__exit__()方法
'''
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
# os.sep 为系统自动识别的文件路径分隔符
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img, label])

images, labels = [], []
with open(os.path.join(root, filename), mode='r') as f:
reader = csv.reader(f)
for row in reader:
img, label = row
                images.append(img)
labels.append(int(label))

assert len(images) == len(labels)

return images, labels


def main():
vis = Visdom()
# 获取数据集(单个数据做返回)
db = Pokemon(root, 64, mode='train')
img, label = next(iter(db))
print('sample:', img.shape, label.shape)
vis.image(img, win='img_win_het', opts=dict(title='norm_img_show'))
vis.image(db.denormalize(img), win='img_win', opts=dict(title='img_show'))

# 批量导出数据
loader = DataLoader(db, batch_size=32, shuffle=True)
for x, y in loader:
vis.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
vis.text(str(y.numpy()), win='label', opts=dict(title='bacth-y'))
time.sleep(10)


if __name__ == '__main__':
main()

二、ResNet网络搭建

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: 实战代码- resnet.py
@author: yonghao
@Description: 创建残差网络结构
@since 2021/03/01 17:51
'''
import torch
import torch.nn.functional as F
from torch import nn
import utils


class ResBlk(nn.Module):
'''
创建ResBlock
'''

def __init__(self, ch_in, ch_out, stride=1):
'''
创建ResBlock模块
:paramch_in: 输入的通道数
:paramch_out: 输出的通道数
:paramstride: 卷积步长
'''
super(ResBlk, self).__init__()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
if ch_in == ch_out:
self.extra = nn.Sequential()
else:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
nn.BatchNorm2d(ch_out)
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))

out = out + self.extra(x)
out = F.relu(out)
return out


class ResNet18(nn.Module):

def __init__(self, num_class):
'''
创建18层的ResNet
:paramnum_class:分类数量
'''
super(ResNet18, self).__init__()

self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=2),
nn.BatchNorm2d(16)
)

# followed 4 blocks
# [b , 16 , h , w] => [b , 32 , h , w]
self.blk1 = ResBlk(16, 32, stride=3)
# [b , 32 , h , w] => [b , 64 , h , w]
self.blk2 = ResBlk(32, 64, stride=3)
# [b , 64 , h , w] => [b , 128 , h , w]
self.blk3 = ResBlk(64, 128, stride=2)
# [b , 128 , h , w] => [b , 256 , h , w]
self.blk4 = ResBlk(128, 256, stride=2)
# [b , 256 , h , 2] => [b , 256*h*w]
self.flat = utils.Flatten()
# [b , 256*h*w] => [b , num_class]
self.out_layer = nn.Linear(256 * 3 * 3, num_class)

def forward(self, x):
x = F.relu(self.conv1(x), inplace=True)
x = self.blk1(x)
x = self.blk2(x)
x = self.blk3(x)
x = self.blk4(x)
# print(x.shape)
x = self.flat(x)
out = self.out_layer(x)
return out


def mian():
# 测试ResBlk,当ch_in==ch_out时正确
# 当ch_in==ch_out时报异常
blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 64, 64)
out = blk(tmp)
print('block:', out.shape)

model = ResNet18(5)
tmp = torch.randn(2, 3, 224, 224)
out = model(tmp)
print("resnet:", out.shape)
p = sum([i.numel() for i in model.parameters()])
print('parameters size:', p)


if __name__ == '__main__':
mian()


三、训练与测试

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: project_code - process.py
@author: yonghao
@Description: 实现训练过程 与 测试过程
@since 2021/03/02 18:54
'''
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from model.resnet import ResNet18
from pokemon import Pokemon

# 批量数量
bacthsz = 32

# 学习率
lr = 1e-3

# 迭代次数
epochs = 10

# device = torch.device('cpu')
# if torch.cuda.is_available():
#     device = torch.device('cuda')

# 设置固定随机初始值
torch.manual_seed(1234)

# 训练集
train_db = Pokemon('pokemon', 224, mode='train')
train_loader = DataLoader(train_db, batch_size=bacthsz, shuffle=True, num_workers=4)

# 验证集
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=bacthsz, num_workers=2)

# 测试集
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=bacthsz, num_workers=2)


def evaluate(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
# x, y = x.to(device), y.to(device)
# x:[b , c , h , w] , y:[b]
# out:[b,class_num]
with torch.no_grad():
out = model(x)
pred = out.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()

return correct / total


def main():
# model = ResNet18(5).to(device)
model = ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()

# 用于保存最高精度
best_acc = 0
best_epoch = 0
# 训练过程
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
# [b , c , h , w] , y[b]
# x, y = x.to(device), y.to(device)
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# validation
if epoch % 2 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
print('best acc:', best_acc, "best epoch:", best_epoch)

# 测试过程
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')

test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)


if __name__ == '__main__':
'''
best acc: 0.8969957081545065 best epoch: 8
loaded from ckpt!
test acc: 0.8931623931623932
'''
main()

四、迁移学习

将处理相类似信号(特别是数据量较大)的神经网络嫁接过来,应用到本实验中
在这里插入图片描述

  • 具体的嫁接过程
    ① 尽量保留网络前、中部分
    ② 去除最后一层,根据自己的分类任务定制最后一层
    在这里插入图片描述

4.1 pytorch实现迁移学习

from torchvision.models import resnet18

model = resnet18(pretrained=True)
# 17 layer out:[32, 512, 1, 1]
model = nn.Sequential(*list(model.children())[:-1],
utils.Flatten(),# 降维度
nn.Linear(512, 5))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2062537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【大数据】什么是数据中台?

随着企业规模不断扩大、业务多元化——中台服务架构的应运而生。“中台”早期是由美军的作战体系演化而来的,技术上说的“中台”主要是指学习这种高效、灵活和强大的指挥作战体系。阿里在今年发布“双中台ET”数字化转型方法论,“双中台”指的是数字中台…

学生和青年研究人员该如何撰写高质量论文?教授支招:3个关键点解锁一篇成功发表的稿件

我是娜姐 迪娜学姐 ,一个SCI医学期刊编辑,探索用AI工具提效论文写作和发表。 本文通讯作者为中国科学院大学唐智勇教授,博士生导师,科技部纳米重大研究计划首席科学家。 对于学生和青年研究人员该如何撰写高质量稿件并成功发表…

揭秘Semantic Kernel:用AI自动规划和执行用户请求

在我们日益高效的开发世界中,将任务自动化并智能规划变得越来越必要。今天,我要给大家介绍一个强大的概念——Semantic Kernel中的planner功能。通过这篇文章,我们会学习到planner的工作原理以及如何实现智能任务规划。 什么是planner&#x…

vue3项目中无法实现cpolar内网穿透解决方案

运行vue3,打开cpolar启动内网穿透,结果却发现 在vue.config.js中修改为如下代码: const { defineConfig } require(vue/cli-service);module.exports defineConfig({transpileDependencies: true,devServer: {allowedHosts: all,host: 0.0…

day06-MySQL学习笔记01

2024.08.17 day06-MySQL学习笔记 前言 前面说过,三层架构,其中dao层用于操作数据。在上面的项目中,数据放在了xml文件中。在企业开发中,数据一般存储在数据库中,我们直接对数据库操作。今天就学习如何操作数据库。 首…

赛氪网技术支持第八届集创赛全国总决赛:共绘集成电路创新蓝图

赛氪网技术支持第八届集创赛全国总决赛:共绘集成电路创新蓝图 山东,2024年8月19日至21日 —— 全国瞩目的第八届全国大学生集成电路创新创业大赛(以下简称“集创赛”)全国总决赛在美丽的海滨城市山东省烟台市隆重举行。本次大赛由…

架桥机液压站比例阀放大器

架桥机液压站是专为公路桥梁建设而设计的一种重要设备,它通过先进的液压系统来实现桥梁的快速、安全架设。液压系统包括三套独立的子系统,分别服务于1号柱、2号柱以及0号柱和3号柱。每套系统均由液压泵站、液压缸、比例电磁控制阀等核心部件构成。液压泵…

IaaS,PaaS,aPaaS,SaaS,FaaS,如何区分?

​IaaS, PaaS,SaaS,aPaaS 还有一种 FaaS ,这几个都是云服务中常见的 5 大类型: IaaS:基础架构即服务,Infrastructure as a Service PaaS:平台即服务,Platform as a Service aPaaS&…

Linux_rwx权限,修改权限,修改所有者和所在组

目录 权限的基本介绍 rwx作用到文件 rwx作用到目录 权限说明案例 修改权限 修改文件所有者-chown 修改文件/目录所在组-chgrp 权限的基本介绍 第0位是文件类型,然后是所有者的权限,所属组的权限,其他用户的权限。 -代表它是一个普通…

使用VS Code开发.NET 8 环境搭建

1. sdk环境确认 -- 查看.net 版本 PS C:\Users\a> dotnet --version 8.0.303 -- 查看已安装的.net sdk 列表 PS C:\Users\a> dotnet --list-sdks 3.0.100 [C:\Program Files\dotnet\sdk] 5.0.301 [C:\Program Files\dotnet\sdk] 6.0.417 [C:\Program Files\dotnet\sdk] …

cadence617版本,如何做一个参数可调的反相器

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

Linux VSFTP 部署与配置

一、VSFTP 简介与应用 VSFTP(Very Secure FTP Daemon)是一款功能强大、安全可靠的FTP服务器软件,广泛应用于Linux/Unix系统中。它提供了高效的文件传输服务,并具备诸多安全特性,如用户认证、权限控制、SSL/TLS加密等。…

AI大模型太TM牛逼了!

如果你问:2024年,程序员必须掌握哪项技术?AI一定是榜首! 从去年起,AI大模型已是程序员的必备工具——‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍ 编程提效: 编写更快,程序更稳定;代码更优&am…

【全开源】php在线客服系统源码 (搭建教程+全新UI)

PHP在线客服系统是一种基于PHP编程语言开发的在线客服系统,它可以为网站提供实时的在线客服支持,方便用户与客服人员进行即时的沟通和交流。作为一种开源的系统,它的源码可以供开发者进行二次开发和定制,以满足不同网站的需求。 …

老古董Lisp实用主义入门教程(5):好奇先生用Lisp探索Lisp

鲁莽先生什么都不管 鲁莽先生打开电脑,安装一堆东西,噼里啪啦敲了一堆代码,叽里呱啦说了一堆话,然后累了就回家睡觉了。 这可把好奇先生的兴趣勾起来,他怎么也睡不着。好奇先生打开电脑,看了看鲁莽先生留…

Figma 替代品 Penpot 安装和使用教程

在设计领域,Figma 无疑是一个巨人。它彻底改变了设计流程,将协作带到了一个全新的高度。然而,随着 Adobe 收购 Figma 的消息传出,许多设计师和开发者开始担心:Figma 未来会如何演变?那些好用的特性会不会被…

【python】深入探讨python中的抽象类,创建、实现方法以及应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

【SpringBoot】11 多数据源(MyBatis:dynamic-datasource)

介绍 多数据源:指的是一个单一应用程序中涉及了两个及以上的数据库,这种配置允许应用程序根据业务需求灵活地管理和操作不同的数据库。 需求 一个应用服务中,连接多个数据库,有本地的也有远程的,有MysQL、Oracle、P…

代码随想录算法训练营day51:图论02:99. 岛屿数量;100. 岛屿的最大面积

99. 岛屿数量 卡码网题目链接(ACM模式)(opens new window) 题目描述: 给定一个由 1(陆地)和 0(水)组成的矩阵,你需要计算岛屿的数量。岛屿由水平方向或垂直方向上相邻的陆地连接而…

PHP农场扶农系统智慧认养智慧乡村系统农场系统小程序源码

🌱科技赋能田园梦 —— 探索“农场扶农系统”与“智慧认养智慧乡村”新篇章🚀 🌈【开篇:田园新风尚,科技引领未来】 在快节奏的都市生活中,你是否曾梦想过拥有一片属于自己的绿色天地?现在&am…