Kaggle:树叶分类(使用Jupyter)

news2024/11/26 5:26:50

竞赛网址:https://www.kaggle.com/c/classify-leaves

# 首先导入包
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import torchvision.models as models
# This is for the progress bar.
from tqdm import tqdm
import seaborn as sns
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
from torchvision.models import resnet50, ResNet50_Weights
# 看看label文件长啥样
labels_dataframe = pd.read_csv('../input/classify-leaves/train.csv')
labels_dataframe.head(5)

在这里插入图片描述

labels_dataframe.describe()

在这里插入图片描述

#function to show bar length

def barw(ax): 
    
    for p in ax.patches:
        val = p.get_width() #height of the bar
        x = p.get_x()+ p.get_width() # x- position 
        y = p.get_y() + p.get_height()/2 #y-position
        ax.annotate(round(val,2),(x,y))
        
#finding top leaves

plt.figure(figsize = (15,30))
ax0 =sns.countplot(y=labels_dataframe['label'],order=labels_dataframe['label'].value_counts().index)
barw(ax0)
plt.show()

在这里插入图片描述

# 把label文件排个序
leaves_labels = sorted(list(set(labels_dataframe['label'])))
n_classes = len(leaves_labels)
print(n_classes)
leaves_labels[:10]

在这里插入图片描述

# 把label转成对应的数字
class_to_num = dict(zip(leaves_labels, range(n_classes)))
class_to_num

在这里插入图片描述

# 再转换回来,方便最后预测的时候使用
num_to_class = {v : k for k, v in class_to_num.items()}
# 继承pytorch的dataset,创建自己的
class LeavesData(Dataset):
    def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2, resize_height=256, resize_width=256):
        """
        Args:
            csv_path (string): csv 文件路径
            img_path (string): 图像文件所在路径
            mode (string): 训练模式还是测试模式
            valid_ratio (float): 验证集比例
        """
        
        # 需要调整后的照片尺寸,我这里每张图片的大小尺寸不一致#
        self.resize_height = resize_height
        self.resize_width = resize_width

        self.file_path = file_path
        self.mode = mode

        # 读取 csv 文件
        # 利用pandas读取csv文件
        self.data_info = pd.read_csv(csv_path, header=None)  #header=None是去掉表头部分
        # 计算 length
        self.data_len = len(self.data_info.index) - 1
        self.train_len = int(self.data_len * (1 - valid_ratio))
        
        if mode == 'train':
            # 第一列包含图像文件的名称
            self.train_image = np.asarray(self.data_info.iloc[1:self.train_len, 0])  #self.data_info.iloc[1:,0]表示读取第一列,从第二行开始到train_len
            # 第二列是图像的 label
            self.train_label = np.asarray(self.data_info.iloc[1:self.train_len, 1])
            self.image_arr = self.train_image 
            self.label_arr = self.train_label
        elif mode == 'valid':
            self.valid_image = np.asarray(self.data_info.iloc[self.train_len:, 0])  
            self.valid_label = np.asarray(self.data_info.iloc[self.train_len:, 1])
            self.image_arr = self.valid_image
            self.label_arr = self.valid_label
        elif mode == 'test':
            self.test_image = np.asarray(self.data_info.iloc[1:, 0])
            self.image_arr = self.test_image
            
        self.real_len = len(self.image_arr)

        print('Finished reading the {} set of Leaves Dataset ({} samples found)'
              .format(mode, self.real_len))

    def __getitem__(self, index):
        # 从 image_arr中得到索引对应的文件名
        single_image_name = self.image_arr[index]

        # 读取图像文件
        img_as_img = Image.open(self.file_path + single_image_name)

        #如果需要将RGB三通道的图片转换成灰度图片可参考下面两行
#         if img_as_img.mode != 'L':
#             img_as_img = img_as_img.convert('L')

        #设置好需要转换的变量,还可以包括一系列的nomarlize等等操作
        if self.mode == 'train':
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=0.5),   #随机水平翻转 选择一个概率
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            # valid和test不做数据增强
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        img_as_img = transform(img_as_img)
        
        if self.mode == 'test':
            return img_as_img
        else:
            # 得到图像的 string label
            label = self.label_arr[index]
            # number label
            number_label = class_to_num[label]

            return img_as_img, number_label  #返回每一个index对应的图片数据和对应的label

    def __len__(self):
        return self.real_len
train_path = '../input/classify-leaves/train.csv'
test_path = '../input/classify-leaves/test.csv'
# csv文件中已经images的路径了,因此这里只到上一级目录
img_path = '../input/classify-leaves/'

train_dataset = LeavesData(train_path, img_path, mode='train')
val_dataset = LeavesData(train_path, img_path, mode='valid')
test_dataset = LeavesData(test_path, img_path, mode='test')
print(train_dataset)
print(val_dataset)
print(test_dataset)

在这里插入图片描述

# 定义data loader
train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=64, 
        shuffle=True,
        num_workers=4
    )

val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=64, 
        shuffle=True,
        num_workers=4
    )
test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=64, 
        shuffle=False,
        num_workers=4
    )
# 看一下是在cpu还是GPU上
def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

device = get_device()
print(device)
# 是否要冻住模型的前面一些层
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        model = model
        for param in model.parameters():
            param.requires_grad = False
# resnet34模型
def res_model(num_classes, feature_extract = False, use_pretrained=True):

    model_ft = models.resnet34(pretrained=True, progress=True)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes))

    return model_ft
# 超参数
learning_rate = 1e-4
weight_decay = 1e-3
num_epoch = 20
model_path = './pre_res_model.ckpt'
# Initialize a model, and put it on the device specified.
model = res_model(176)
model = model.to(device)
model.device = device
# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss()

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate,weight_decay=weight_decay)

# The number of training epochs.
n_epochs = num_epoch

best_acc = 0.0
for epoch in range(n_epochs):
    # ---------- Training ----------
    # Make sure the model is in train mode before training.
    model.train() 
    # These are used to record information in training.
    train_loss = []
    train_accs = []
    # Iterate the training set by batches.
    for batch in tqdm(train_loader):
        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Forward the data. (Make sure data and model are on the same device.)
        logits = model(imgs)
        # Calculate the cross-entropy loss.
        # We don't need to apply softmax before computing cross-entropy as it is done automatically.
        loss = criterion(logits, labels)
        
        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()
        # Compute the gradients for parameters.
        loss.backward()
        # Update the parameters with computed gradients.
        optimizer.step()
        
        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        # Record the loss and accuracy.
        train_loss.append(loss.item())
        train_accs.append(acc)
        
    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
    
    
    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model.eval()
    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []
    
    # Iterate the validation set by batches.
    for batch in tqdm(val_loader):
        imgs, labels = batch
        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs.to(device))
            
        # We can still compute the loss (but not the gradient).
        loss = criterion(logits, labels.to(device))

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc)
        
    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    
    # if the model improves, save a checkpoint at this epoch
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), model_path)
        print('saving model with acc {:.3f}'.format(best_acc))

在这里插入图片描述

## predict
saveFileName = './submission.csv'

model = res_model(176)

# create model and load weights from checkpoint
model = model.to(device)
model.load_state_dict(torch.load(model_path))

# Make sure the model is in eval mode.
# Some modules like Dropout or BatchNorm affect if the model is in training mode.
model.eval()

# Initialize a list to store the predictions.
predictions = []
# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    
    imgs = batch
    with torch.no_grad():
        logits = model(imgs.to(device))
    
    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

preds = []
for i in predictions:
    preds.append(num_to_class[i])

test_data = pd.read_csv(test_path)
test_data['label'] = pd.Series(preds)
submission = pd.concat([test_data['image'], test_data['label']], axis=1)
submission.to_csv(saveFileName, index=False)
print("Done!!!!!!!!!!!!!!!!!!!!!!!!!!!")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/771431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp 小程序 实时拍照(仅拍照)限制上传5张 可预览 可删除

效果图: common.js /*** 预览图片*/ const previewImage (current,list)>{// 预览图片uni.previewImage({current: current,urls: list}); } /*** 删除图片*/ const removeImage (current,list)>{var photoFilesList list;photoFilesList.splice(curren…

AJAX:宏任务与微任务

异步任务划分为了 宏任务:由浏览器环境执行的异步代码 微任务:由 JS 引擎环境执行的异步代码 宏任务和微任务具体划分: 左边表格是宏任务,右边是微任务 事件循环模型 /*** 目标:阅读并回答打印的执行顺序 */ console…

国内软件外包公司开发流程

当企业发展到一定阶段后,现有市场上通用型的软件往往无法满足自身的业务需求,这就需要企业定制化开发软件系统来满足自身独特的需求。而传统企业往往没有自己的软件研发队伍,在开发软件系统时快速新建团队风险比较高,可以采用外包…

Docker 网络模型:多角度分析容器网络的原理与应用

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

FPGA adrv9002 4收4发板卡,支持NVME SATA EMMC 光口 FMC

板卡采用ADI 射频直采芯片ADRV9002 ,支持4收4发支持外部本振 跳频 同时支持4X 10G光口对外传输,FMC扩展 。同时支持4X NVME接口,可以实时流盘,备份一路SAT A接口,板卡同时预留了EMMC,可以PS PL选通访问&…

【Java基础教程】Java学习路线攻略导图 · 上篇 ~

Java学习路线攻略导图 上篇 前言1、入门介绍篇2、程序基础概念篇3、包及访问权限篇4、异常处理篇5、特别篇6、面向对象篇 前言 🍺🍺 各位读者朋友大家好!得益于各位朋友的支持和关注,我的专栏《Java基础教程》 至今已经更新近半&…

浅谈无人机遥感图像拼接与处理方法

遥感(RS-Remote Sensing)——不接触物体本身,用传感器收集目标物的电磁波信息,经处理、分析后,识别目标物,揭示其几何、物理性质和相互关系及其变化规律的现代科学技术。 换言之,即是“遥远的感…

mysal数据库的日志恢复

目录 一 物理冷备份 二 mysqldump 备份与恢复(温备份) 三 mgsql中的增量备份需要借助mysql日志的二进制来恢复 小结 一 物理冷备份 systemctl stop mysqld yum -y install xz 压缩备份 tar Jcvf /opt/mysql_all_$(date %F).tar.xz /usr/local/mysql/…

文件上传前前端通过魔数(magic number)去限制上传文件类型

问题 最近项目需求文件上传前判断文件类型,有的同学会说用文件后缀判断不就好啦。其一,文件后缀可以修改,正确性待考究;其二,有些文件并没有文件后缀。这就需要我们动动脑筋啦,其实我们可以根据文件的头信…

4个顶级WooCommerce商城多站点库存同步WordPress插件

经营几家网上商店是令人兴奋的。但是,这也是一项艰巨的工作,尤其是当您意识到需要同步这些商店的库存时。好消息是,有 WooCommerce 多站点库存同步插件和选项可以加快速度。 WooCommerce 多站点网络可让您将所有在线商店无缝地安置在一个屋檐…

回归预测 | MATLAB实现基于KELM-Adaboost核极限学习机结合AdaBoost多输入单输出回归预测

回归预测 | MATLAB实现基于KELM-Adaboost核极限学习机结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于KELM-Adaboost核极限学习机结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现基于KELM-Adaboo…

三菱q以太网简单cpu通讯

产品概述 捷米特JM-ETH-QnA是一款经济型的以太网通讯处理器,是为满足日益增多的工厂设备信息化需求(设备网络监控和生产管理)而设计,用于三菱Q2A/Q2AS1/Q3A/Q4A等多个QnA系列PLC的以太网数据采集,非常方便构建生产管理…

Mysql+ETLCloud CDC+StarRocks实时数仓同步实战

一、业务需求及其痛点 大型企业需要对各种业务系统中的销售及营销数据进行实时同步分析,例如库存信息、对帐信号、会员信息、广告投放信息,生产进度信息等等,这些统计分析信息可以实时同步到StarRocks中进行分析和统计,StarRocks…

性能测试学习阶段性总结

目录 1.前言 2.概念部分 2.1不同角度看软件性能 2.2关键词 2.3测试的方法 2.4应用领域 3.性能测试过程模型(PTGM) 2.1测试前期准备 2.2测试工具引入 2.3测试计划 2.4测试设计与开发 2.5测试执行和管理 2.6测试分析 总结: 1.前言…

干货!3个技巧让你轻松增强客户实时聊天的体验感

在当今竞争激烈的商业环境中,提供出色的客户服务成为企业成功的关键要素之一。尤其是在实时聊天平台上,为客户提供优质的体验感,对于建立良好的客户关系和提高销售转化率至关重要。如果你还在苦恼如何增强用户体验感,苦恼如何增加…

zabbix-server监控mysql数据库及httpd服务、监控apache、监控ftp

目录 一、监控mysql数据库及httpd服务 1、为server.Zabbix.com添加服务模板 2、server.zabbix.com服务端 操作 3、编辑chk_mysql.sh脚本 4、server.zabbix.com测试 二、监控apache 1、获取键值 2、服务器操作 3、zabbix监控web端导入监控模板 4、server.zabbix.com添加…

B072-项目实战-用户模块--前台登录 三方登录

目录 前台登录-账号登录前端完成左上角显示用户信息配置前置拦截器、后置拦截器和不受限资源拦截器 三方登录-微信登录概述流程图用法代码实现步骤分析:实现准备代码前端login.htmlcallback.html 后端LoginController-微信登录LoginServiceImpl-微信登录解决回调域名不能跨域绑…

【教程】VSCode配置C++环境踩坑记录

时隔一年终于在VSCode配置好了C环境[MinGW] 基础部分踩坑坑0坑1坑2坑3 基础部分 就是安装VSCode,然后再安装C插件之类的,咱这就不罗嗦了,如果不清楚可以参考这篇文章:VSCode配置C/C环境 毕竟解决后面一些棘手的问题更重要。 踩坑…

算法竞赛字符串常用操作大全

算法竞赛字符串常用操作总结来啦~ 👊 大家好 我是寸铁💪 考前需要刷大量真题,大家一起相互监督,每日做N题,一起上岸吧✌️ ~ 冲刺蓝桥杯省一模板大全来啦 💥 ~ 蓝桥杯4月8号就要开始了 🙏 ~ 还没背熟模…

一文带你了解Spring中存入Bean和获取Bean的方式

0. Spring中的五大注解 上图中就是五大类注解对应的层,通过源码可以看到其他四个注解都基于Conponent 1. 存入 Bean Spring既然是一个包含众多工具方法的IoC容器,它是一个控制反转的容器,所以就需要将Bean对象存入到容器中,需要…