数据集003:猫类识别-12种猫分类数据集 (含数据集下载链接)

news2024/9/23 17:14:30

数据集简介:

训练集共有2160张猫的图片, 分为12类. train_list.txt是其标注文件

测试集共有240张猫的图片. 不含标注信息.

训练集图像(部分)

验证集图像(部分)

标签

部分代码:

# 定义训练数据集
class TrainData(Dataset):
    def __init__(self):
        super().__init__()
        self.color_jitter = T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
        self.normalize = T.Normalize(mean=0, std=1)
        self.random_crop = T.RandomCrop(224, pad_if_needed=True)
    
    def __getitem__(self, index):
        # 读取图片
        image_path = train_paths[index]

        image = np.array(Image.open(image_path))    # H, W, C
        try:
            image = image.transpose([2, 0, 1])[:3]  # C, H, W
        except:
            image = np.array([image, image, image]) # C, H, W
        
        # 图像增广
        features = self.color_jitter(image.transpose([1, 2, 0]))
        features = self.random_crop(features)
        features = self.normalize(features.transpose([2, 0, 1])).astype(np.float32)

        # 读取标签
        labels = train_labels[index]

        return features, labels
    
    def __len__(self):
        return len(train_paths)

    
# 定义验证数据集
class ValidData(Dataset):
    def __init__(self):
        super().__init__()
        self.normalize = T.Normalize(mean=0, std=1)
    
    def __getitem__(self, index):
        # 读取图片
        image_path = valid_paths[index]

        image = np.array(Image.open(image_path))    # H, W, C
        try:
            image = image.transpose([2, 0, 1])[:3]  # C, H, W
        except:
            image = np.array([image, image, image]) # C, H, W
        
        # 图像变换
        features = cv2.resize(image.transpose([1, 2, 0]), (256, 256)).transpose([2, 0, 1]).astype(np.float32)
        features = self.normalize(features)

        # 读取标签
        labels = valid_labels[index]

        return features, labels
    
    def __len__(self):
        return len(valid_paths)
# 调用resnet50模型
paddle.vision.set_image_backend('cv2')
model = paddle.vision.models.resnet50(pretrained=True, num_classes=12)

# 定义数据迭代器
train_dataloader = DataLoader(train_data, batch_size=256, shuffle=True, drop_last=False)

# 定义优化器
opt = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters(), weight_decay=paddle.regularizer.L2Decay(1e-4))

# 定义损失函数
loss_fn = paddle.nn.CrossEntropyLoss()

# 设置gpu环境
paddle.set_device('gpu:0')

# 整体训练流程
for epoch_id in range(15):
    model.train()
    for batch_id, data in enumerate(train_dataloader()):
        # 读取数据
        features, labels = data
        features = paddle.to_tensor(features)
        labels = paddle.to_tensor(labels)

        # 前向传播
        predicts = model(features)

        # 损失计算
        loss = loss_fn(predicts, labels)

        # 反向传播
        avg_loss = paddle.mean(loss)
        avg_loss.backward()

        # 更新
        opt.step()

        # 清零梯度
        opt.clear_grad()

        # 打印损失
        if batch_id % 2 == 0:
            print('epoch_id:{}, batch_id:{}, loss:{}'.format(epoch_id, batch_id, avg_loss.numpy()))
    model.eval()
    print('开始评估')
    i = 0
    acc = 0
    for image, label in valid_data:
        image = paddle.to_tensor([image])

        pre = list(np.array(model(image)[0]))
        max_item = max(pre)
        pre = pre.index(max_item)

        i += 1
        if pre == label:
            acc += 1
        if i % 10 == 0:
            print('精度:', acc / i)
    
    paddle.save(model.state_dict(), 'acc{}.model'.format(acc / i))
# 进行预测和提交
# 首先拿到预测文件的路径列表

def listdir(path, list_name):
    for file in os.listdir(path):
        file_path = os.path.join(path, file)
        if os.path.isdir(file_path):
            listdir(file_path, list_name)
        else:
            list_name.append(file_path)
test_path = []
listdir('cat_12_test', test_path)

# 加载训练好的模型
pre_model = paddle.vision.models.resnet50(pretrained=True, num_classes=12)
pre_model.set_state_dict(paddle.load('acc0.9285714285714286.model'))
pre_model.eval()

pre_classes = []
normalize = T.Normalize(mean=0, std=1)
# 生成预测结果
for path in test_path:
    image_path = path

    image = np.array(Image.open(image_path))    # H, W, C
    try:
        image = image.transpose([2, 0, 1])[:3]  # C, H, W
    except:
        image = np.array([image, image, image]) # C, H, W
    
    # 图像变换
    features = cv2.resize(image.transpose([1, 2, 0]), (256, 256)).transpose([2, 0, 1]).astype(np.float32)
    features = normalize(features)

    features = paddle.to_tensor([features])
    pre = list(np.array(pre_model(features)[0]))
    # print(pre)
    max_item = max(pre)
    pre = pre.index(max_item)
    print("图片:", path, "预测结果:", pre)
    pre_classes.append(pre)

print(pre_classes)

数据集链接:猫类识别-12种猫分类数据集(2400张)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1701720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Devexpress中GridControl控件中的表格遍历逻辑问题

当我们在执行其他事件时,常常需要对GridControl控件里的表内容进行一个遍历逻辑判断,该文以确认2列中的值是否为空为例;首先在遍历模块当然是使用foreach作为遍历的基础,在这其中在此例中存在具体业务细节,需要对选中行…

设置 border 边框单侧样式 - HarmonyOSNext

设置 border 边框单侧样式,通过 api 中查看 border(value: BorderOptions): T; BorderOptions 又包含了若干个子属性 1.width?: EdgeWidths | Length; 2.color?: EdgeColors | ResourceColor; 3.radius?: BorderRadiuses | Length; 4.style?: EdgeStyles | BorderStyle; 其…

OrangePi Kunpeng Pro开发板初体验——家庭小型服务器

引言 在开源硬件的浪潮中,开发板作为创新的基石,正吸引着全球开发者的目光。它们不仅为技术爱好者提供了实验的平台,更为专业开发者带来了实现复杂项目的可能性。本文将深入剖析OrangePi Kunpeng Pro开发板,从开箱到实际应用&…

2024年【G2电站锅炉司炉】免费试题及G2电站锅炉司炉复审考试

题库来源:安全生产模拟考试一点通公众号小程序 2024年【G2电站锅炉司炉】免费试题及G2电站锅炉司炉复审考试,包含G2电站锅炉司炉免费试题答案和解析及G2电站锅炉司炉复审考试练习。安全生产模拟考试一点通结合国家G2电站锅炉司炉考试最新大纲及G2电站锅…

蓝桥杯第十四届国赛B组刷题笔记

A-0子2023: 题目: 小蓝在黑板上连续写下从 11 到 20232023 之间所有的整数,得到了一个数字序列: 𝑆12345678910111213...20222023S12345678910111213...20222023。 小蓝想知道 𝑆S 中有多少种子序列恰好等…

豆包模型最新数据评测!性能究竟如何?

豆包模型最新数据评测!性能究竟如何? 前言 就在5月27日,字节跳动旗下的豆包大模型在火山引擎原动力大会上正式发布,本次大会中豆包的模型能力也引发行业关注。 介绍豆包 豆包是一个多功能 AI 助手,为你的生活、学习、工…

免费 OSS 资源 Backblaze B2 使用最新指南

免费的对象存储资源日渐枯竭,Backblaze 是为数不多仍提供免费 OSS 的良心厂商。另外一个则是大名鼎鼎的 Cloudflare R2。虽然免费,但 Backblaze 也修改了政策:如果不验证信用卡的话是不能打开 Public 选项的,或者支付一美金。估计…

爬山算法教程(个人总结版)

背景与简介 爬山算法(Hill Climbing Algorithm)是一种用于解决优化问题的启发式搜索方法。它是一种局部搜索算法,通过不断尝试从当前解出发,在其邻域内寻找更优的解,直到无法找到更优解为止。该算法得名于其类似于登山…

青蛙跳台阶问题

本期介绍🍖 主要介绍:青蛙跳台阶问题,青蛙跳台阶与斐波那契数列的关系👀。 文章目录 1. 题目2. 递归解题思路3. 迭代解题思路 1. 题目 从前有一只青蛙他想跳台阶,有n级台阶,青蛙一次可以跳1级台阶&#xff…

MYSQL之安装

一,下载仓库包 wget -i -c https://dev.mysql.com/get/mysql80-community-release-el7-3.noarch.rpm二,安装仓库 yum -y install mysql80-community-release-el7-3.noarch.rpmsed -i s/gpgcheck1/gpgcheck0/g mysql-community.repo三,安装MY…

Python代码:十七、生成列表

1、题目 描述: 一串连续的数据用什么记录最合适,牛牛认为在Python中非列表(list)莫属了。现输入牛牛朋友们的名字,请使用list函数与split函数将它们封装成列表,再整个输出列表。 输入描述: …

lua 计算第几周

需求 计算当前赛季的开始和结束日期,2024年1月1日周一是第1周的开始,每两周是一个赛季。 lua代码 没有处理时区问题 local const 24 * 60 * 60 --一整天的时间戳 local server_time 1716595200--todo:修改服务器时间 local date os.date("*t…

Redis 事件机制 - AE 抽象层

Redis 服务器是一个事件驱动程序,它主要处理如下两种事件: 文件事件:利用 I/O 复用机制,监听 Socket 等文件描述符上发生的事件。这类事件主要由客户端(或其他Redis 服务器)发送网络请求触发。时间事件&am…

苗情灾情监控系统—提高农业生产效率

TH-MQ2苗情灾情监控系统是一种用于监测农作物生长状况和灾情的设备,通过实时监测和数据分析,帮助农民及时了解作物生长情况,采取相应的管理措施,提高农业生产效率和降低生产成本。 该系统通常由多种传感器、摄像头、数据传输模块等…

前端命令行部署

最近接了一个项目,发版本需要把dist包给后端部署服务,再加上产品那边需求不稳定,改了又改,一天要发好几个,不仅跟我配合的后端不胜其烦,本人也是很烦。最近在网上看到一个npm自主部署的包–deploy cli工具&…

QT C++ 模型视图结构 QTableView 简单例子

在Qt中,MVC模式被广泛使用于各种用户界面框架中,包括Qt的模型视图结构。Qt的模型视图结构是基于MVC模式设计的,其中包括了Model、View和Delegate三个部分。 QTableView是Qt模型视图结构中的一种视图,它用于以表格形式显示数据。 …

红队项目PinkysPalace格式字符串缓冲区溢出详解

简介 渗透测试-地基篇 该篇章目的是重新牢固地基,加强每日训练操作的笔记,在记录地基笔记中会有很多跳跃性思维的操作和方式方法,望大家能共同加油学到东西。 请注意: 本文仅用于技术讨论与研究,对于所有笔记中复现的…

如何使用OutputStream类实现文件的读写操作?

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

Tensors张量操作

定义Tensor 下面是一个常见的tensor,包含了里面的数值,属性,以及存储位置 tensor([[0.3565,0.1826,0.6719],[0.6695,0.5364,0.7057]],dtypetorch.float32,devicecuda:0)Tensor的属…

Vue2 Element-UI 分页组件el-pagination 修改 自带的total、跳转等默认文字

场景需求: Vue2 Element-UI 分页组件el-pagination 修改 自带的total、跳转等默认文字。如下图:默认提示字变成了英文,如何将其 变成 汉字提示呢? 解决方案: 1.方案1:修改DOM内容 不提倡此方案&#xf…