深度学习实战(1):树叶分类pytorch

news2025/1/22 16:02:07

Classify Leaves | Kaggle

上面是数据集

数据导入与数据处理

%matplotlib inline
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
import torchvision.models as models
from IPython import display
import os
import pandas as pd
import random
import PIL
import numpy as np
将标签转成类别
imgpath = os.getcwd()
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):
    name2num[num2name[i]] = i
自定义数据集
class Leaf_data(Data.Dataset):
    def __init__(self,path,train,transform=lambda x:x) -> None:
        super().__init__()
        self.path = path
        self.transform = transform
        self.train = train
        if train:
            self.datalist = pd.read_csv(f"{path}/train.csv")
        else:
            self.datalist = pd.read_csv(f"{path}/test.csv")
    def __getitem__(self, index):
        res = ()
        tmplist = self.datalist.iloc[index,:]
        for i in tmplist.index:
            if(i=="image"):
                res += self.transform(PIL.Image.open(f"{self.path}/{tmplist[i]}")),
            else:
                res += name2num[tmplist[i]],
        if(len(res)<2):
            res+= tmplist[i],
        return res
    def __len__(self)->int:
        return len(self.datalist)

准备工作

画图、计算loss、累加器函数等,再之前文章中已经介绍过的,不必一句一句弄明白

def try_gpu():
    if(torch.cuda.device_count()>0):
        return torch.device('cuda')
    return torch.device('cpu')

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1) #找出输入张量(tensor)中最大值的索引
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
 
def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n
 
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
 
    def reset(self):
        self.data = [0.0] * len(self.data)
 
    def __getitem__(self, idx):
        return self.data[idx]
    
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
 
def use_svg_display(): 
    """使⽤svg格式在Jupyter中显⽰绘图"""
    backend_inline.set_matplotlib_formats('svg')
 
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
     """设置matplotlib的轴"""
     axes.set_xlabel(xlabel)
     axes.set_ylabel(ylabel)
     axes.set_xscale(xscale)
     axes.set_yscale(yscale)
     axes.set_xlim(xlim)
     axes.set_ylim(ylim)
     if legend:
         axes.legend(legend)
     axes.grid()
 
class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts
 
    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

import time
class Timer:  #@save
    """记录多次运行时间"""
    def __init__(self):
        self.times = []
        self.start()
 
    def start(self):
        """启动计时器"""
        self.tik = time.time()
 
    def stop(self):
        """停止计时器并将时间记录在列表中"""
        self.times.append(time.time() - self.tik)
        return self.times[-1]
 
    def avg(self):
        """返回平均时间"""
        return sum(self.times) / len(self.times)
 
    def sum(self):
        """返回时间总和"""
        return sum(self.times)
 
    def cumsum(self):
        """返回累计时间"""
        return np.array(self.times).cumsum().tolist()
模型构建或载入

如果是第一次训练则可以下载再ImageNet上预训练好的resnet18或者更大的模型,如果之前已经训练有保存好的模型则可以接着训练

model_path = 'pre_res_model.ckpt'
def save_model(net):
    torch.save(net.state_dict(),model_path)

def init_weight(m):
    if type(m) in [nn.Linear,nn.Conv2d]:
        nn.init.xavier_normal_(m.weight)

model_path = 'pre_res_model.ckpt'
first_train = False
if first_train:
    net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
    net.fc.apply(init_weight)
else:
    net = models.resnet18()
    net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
    net.fc.apply(init_weight)
    model_weights = torch.load(model_path)
    net.load_state_dict(model_weights)


net.to(try_gpu())
lr = 1e-4
parames = [parame for name,parame in net.named_parameters() if name not in ["fc.weight","fc.bias"]]
trainer = torch.optim.Adam([{"params":parames},
                            {"params":net.fc.parameters(),"lr":lr*10}],lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer,1,0)
loss = nn.CrossEntropyLoss(reduction='none')

模型训练

控制一批的训练
def train_batch(X,y,net,loss,trainer,devices):
    if isinstance(X,list):
        X = [x.to(devices) for x in X]
    else:
        X = X.to(devices)
    y = y.to(devices)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred,y)
    l.sum().backward()
    trainer.step()
    LR_con.step()
    return l.sum(),accuracy(pred,y)
多个epoch
def train(train_data,test_data,net,loss,trainer,num_epochs,device = try_gpu()):
    best_acc = 0
    timer = Timer()
    plot = Animator(xlabel="epoch",xlim=[1,num_epochs],legend=['train loss','train acc','test loss'],ylim=[0,1])
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = Accumulator(4)
        for i, (features, labels) in enumerate(train_data):
            timer.start()
            l, acc = train_batch(
                features, labels, net, loss, trainer, device)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
        test_acc = evaluate_accuracy_gpu(net, test_data,device=device)
        if(test_acc>best_acc):
            save_model(net)
            best_acc = test_acc
        plot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))
        print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print()
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(device)}')
    print(f"best acc {best_acc}")
    return metric[0] / metric[2],metric[1] / metric[3],test_acc
transfroms和dataloader
batch = 128
num_epochs = 4
norm = torchvision.transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),norm
])
train_data,valid_data = Data.random_split(
    dataset=Leaf_data(imgpath,True,augs),
    lengths=[0.8,0.2]
)
train_dataloder = Data.DataLoader(train_data,batch,True)
valid_dataloder = Data.DataLoader(valid_data,batch,True)
训练
train(train_dataloder,valid_dataloder,net,loss,trainer,num_epochs)

1-4轮: 

这个就是接着训练的,每次训练四轮 :

 继续接着训练,这里到了第12轮:

 接着训练,现在到了20轮,基本上再训练个10轮应该还是能把精度再更进一步提一提的。

这张图片是早上训练10个epoch后的四个epoch,可以看到结果相当不错。

 提交

net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.ToTensor(),norm
])
test_data = Leaf_data(imgpath,False,augs)
test_dataloader = Data.DataLoader(test_data,batch_size=64,shuffle=False)
res = pd.DataFrame(columns = ["image","label"],index=range(len(test_data)))
net = net.cpu()
count = 0
for X,y in test_dataloader:
    preds = net(X).detach().argmax(dim=-1).numpy()
    preds = pd.DataFrame(y,index=map(lambda x:num2name[x],preds))
    preds.loc[:,1] = preds.index
    preds.index = range(count,count+len(y))
    res.iloc[preds.index] = preds
    count+=len(y)
    print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('mysubmission.csv', index=False)

小结

  • resnet18作为一个40M的模型,训练这个200M的数据集是没有问题的,基本没有过拟合或者欠拟合
  • 在ImageNet上预训练好的resnet18,将最后一层改为176个类别输出,这样的迁移学习效果是非常好的
  • 要学会模型载入与保存,这样可以不断训练出更好的模型
  • 数据预处理对于不熟悉python的人来说可能是最耗时的一部分

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1991011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode69. x 的平方根,二分法

leetcode69. x 的平方根 给你一个非负整数 x &#xff0c;计算并返回 x 的 算术平方根 。 由于返回类型是整数&#xff0c;结果只保留 整数部分 &#xff0c;小数部分将被 舍去 。 注意&#xff1a;不允许使用任何内置指数函数和算符&#xff0c;例如 pow(x, 0.5) 或者 x ** 0…

培训第二十三天(mysql主从脚本与mysql详细语句介绍)

上午 在同步时&#xff0c;对删除和修改都比较慎重&#xff08;监控同步时&#xff09; mysql主从搭建 前提软件libaio&#xff0c;rsync 1、主2、从3、同步4、测试 注意&#xff1a;先执行从服务器的脚本&#xff0c;再执行主服务器脚本 master-mysql配置脚本 先要在主服务…

企元数智小程序合规分销系统赠送:迎接数字化时代商机

当今时代&#xff0c;随着科技的高速发展和数字化的普及&#xff0c;企业如何抓住数字化时代带来的商机&#xff0c;成为了业界关注的焦点。在这样一个竞争激烈的市场环境下&#xff0c;企业需要不断提高自身的竞争力和应变能力&#xff0c;以应对激烈的市场竞争&#xff0c;开…

Phpstorm实现本地SSH开发远程机器(或虚拟机)项目

适用场景&#xff1a; 1、windows系统想要运行仅支持linux、mac系统的项目&#xff0c;可将项目运行在本地虚拟机&#xff0c;但是在虚拟机里使用vim编辑很麻烦&#xff0c;如何实现在本地用Phpstorm来编辑虚拟机中的代码&#xff1f; 下面的说明都是以本地虚拟机为例&#xff…

java之拼图小游戏(开源)

public class LoginJFrame extends JFrame {//表示登录界面&#xff0c;以后所有跟登录相关的都写在这里public LoginJFrame() {//设置界面的长和宽this.setSize(603,680);//设置界面的标题this.setTitle("拼图登陆界面");//设置界面置顶this.setAlwaysOnTop(true);/…

科研绘图系列:R语言圆形条形图(circular barplot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 介绍 圆形条形图(circular barplot)是一种条形图,其中的条形沿着圆形而不是线性排列展示。这种图表的输入数据集与普通条形图相同:每个组(一个组即一个条形)需要一个数值。(更多解释请参…

linux文件查找--locate和find命令详解

在文件系统上查找符合条件的文件 文件查找:1.非实时查找(数据库查找):locate2.实时查找: find应用&#xff1a;生产环境中查找到系统中占用磁盘空间较大且时间比较久的大日志文件&#xff0c;对这个较大的日志文件做处理&#xff08;删除移走等)&#xff0c;防止它占用更多的磁…

gps 轨迹点如何绘制路径

作为用户&#xff0c;我们进行户外运动后&#xff0c;有的人喜欢分享自己的运动记录。这个时候就比较关注自己的运动轨迹路线了。 一.将经纬度转化为轨迹方法1 1.将gps 打点文件导出。 2.将经纬度点转换成如下格式。 3.将转换后的经纬度填入如下地址&#xff1a; https://ww…

必了解的 20 个 AI 术语解析(下)

AI 领域的基础概念和相关技术有很多&#xff0c;这篇文章里&#xff0c;作者就深入浅出地介绍了相应的内容&#xff0c;感兴趣的同学们&#xff0c;不妨来看一下。 必了解的 20 个 AI 术语解析&#xff08;下&#xff09;© 由 ZAKER科技 提供 本文专为非技术背景的 AI 爱…

如何修改360免费wifi热点的频带为2.4G或者5G

有的时候使用电脑广播出热点给嵌入式设备用进而进行抓包&#xff0c;但是他默认广播的是5G Hz的&#xff0c;嵌入式设备扫不到热点。那么如何让他广播2.4G H在呢&#xff1f; CMD控制台使用命令netsh wlan show drivers查看设备驱动&#xff1a; 802.11g 和 802.11n 意味着你的…

Python酷库之旅-第三方库Pandas(071)

目录 一、用法精讲 286、pandas.Series.dt.to_pydatetime方法 286-1、语法 286-2、参数 286-3、功能 286-4、返回值 286-5、说明 286-6、用法 286-6-1、数据准备 286-6-2、代码示例 286-6-3、结果输出 287、pandas.Series.dt.tz_localize方法 287-1、语法 287-2、…

Selenium 自动化测试最佳实践

1 编码前的准备工作与基本指导思想 测试一个网站就是针对该网站测试场景的一次项目开发&#xff0c;所以项目开发中的理念与思想可以借鉴过来。接到测试需求后&#xff0c;不要一开始就陷入按钮、字段、下拉框等页面元素怎么操作的技术细节当中&#xff0c;而要站在最终用户的…

《MySQL数据库》 可视化工具的使用—/—<3>

一、如何使用可视化工具navicat 1、点击左上角的连接中的MySQL 输入主机地址连接虚拟机&#xff0c;找到自己虚拟机中的ip地址输入即可&#xff0c;连接名随意修改 然后点击测试连接&#xff0c;连接成功即可点击确定 2、新建库 直接鼠标右击连接名称ahao001&#xff0c;点击…

react学习笔记:7

预览&#xff1a;&#xff08;fetch发送请求、SPA、连续解构赋值、消息订阅、react router路由第三方库&#xff09; 1、连续解构赋值 总结&#xff1a; 1、连续解构赋值的写法&#xff1a;对象包对象&#xff0c;第二个解构的value一定也是在{}内部的写法 2、消息订阅发布 …

SwiftUI 中 TabView 视图导航栏上按钮丢失问题的解决

问题现象 在某些情况下,SwiftUI 中 TabView 子视图中导航栏上的 ToolbarItem 会消失不见。 如上图所示:在子视图的 Kick Off 导航栏按钮被按下并回退到 TabView 中的主视图之后,其右上角的按钮竟然“神奇”的消失了!该如何解决它呢? 在本篇博文中,您将学到以下内容 问题…

【二分查找】3143. 正方形中的最多点数

本文涉及的基础知识点 C二分查找 LeetCode3143. 正方形中的最多点数 给你一个二维数组 points 和一个字符串 s &#xff0c;其中 points[i] 表示第 i 个点的坐标&#xff0c;s[i] 表示第 i 个点的 标签 。 如果一个正方形的中心在 (0, 0) &#xff0c;所有边都平行于坐标轴&…

大数据-Big Data

GPT-4o (OpenAI) 大数据&#xff08;Big Data&#xff09;指的是无法使用传统方法和工具在合理的时间内处理和分析的大规模数据集。大数据通常具有以下几种特征&#xff0c;也称为5V特征&#xff1a; 1. Volume&#xff08;数据量&#xff09;&#xff1a;大数据涉及到大量的信…

深度学习常用语句for param in params问题:为什么修改param之后,params对应元素也随之改变?

def sgd(params, lr, batch_size): #save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param - lr * param.grad / batch_sizeparam.grad.zero_()sgd([w, b], lr, batch_size) 上述代码中&#xff0c;param遍历params的…

深度学习--------------Kaggle房价预测

目录 下载和缓存数据集访问和读取数据集总代码 数据预处理训练K折交叉验证模型选择总代码提交你的Kaggle预测提交Kaggle 下载和缓存数据集 import hashlib import os import tarfile import zipfile import requests# download传递的参数分别是数据集的名称、缓存文件夹的路径…

LabVIEW液压传动系统

开发了一种高效的液压传动系统&#xff0c;其特点在于采用LabVIEW软件与先进的硬件配合&#xff0c;实现能量的有效回收。此系统主要应用于工业机械中&#xff0c;如工程机械和船机械等&#xff0c;通过优化液压泵和马达的测试台设计&#xff0c;显著提高系统的能效和操作性能。…