猫狗识别大模型——基于python语言

news2024/9/26 5:20:43

目录

1.猫狗识别

2.数据集介绍

3.猫狗识别核心原理

4.程序思路

4.1数据文件框架

 4.2 训练模型

4.3 模型使用

4.4 识别结果

5.总结


1.猫狗识别

人可以直接分辨出图片里的动物是猫还是狗,但是电脑不可以,要想让电脑也分辨出图片里的动物是猫还是小狗,就要使用到深度学习,电脑学习提取图片特征,进而学习区分图片里的是猫还是狗。

2.数据集介绍

程序用到的训练数据集是猫狗图像数据集,数据格式jpg格式,猫狗数据集:

https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

 

3.猫狗识别核心原理

猫狗识别大模型是一种深度学习架构,主要用于图像分类任务,用来区分猫和狗这两种常见的宠物动物。

该模型基于卷积神经网络(CNN),它们通过学习大量的猫和狗图像数据集中的特征来进行训练,使其能够识别出输入图片中动物的种类。

训练过程中,模型会对猫的特有纹理、颜色模式、耳朵形状等特征进行学习,并形成区分猫狗的关键特征模板。一旦模型经过充分训练并优化,它可以准确地判断新的未知图片是属于猫还是狗。

应用此类模型的方式通常是将其部署到移动设备或者云端服务器上,用户上传一张照片后,模型会返回一个预测结果,指示图像中动物的类别。

4.程序思路

基于tensorflow模型框架以及卷积神经网络还有其他各种模块,划分训练集,微调集和测试机,对猫狗图片文件进行训练。

4.1数据文件框架

 4.2 训练模型

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import os

# 获取所有的GPU设备
gpus = tf.config.list_physical_devices('GPU')

# 检查是否有两个以上的GPU
if gpus and len(gpus) > 1:
    try:
        # 假设GPU1是独立GPU,设置可见设备为GPU1
        tf.config.set_visible_devices(gpus[1], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[1], True)
    except RuntimeError as e:
        print(e)
else:
    print("没有检测到多个GPU,或者系统只存在一个GPU。")

# 定义数据目录
data_dir = './pythonProject/ai_modle_win/cats vs dogs/dataset'  # 请替换为你的数据集路径
train_dir = os.path.join(data_dir, 'train')
validation_dir = os.path.join(data_dir, 'validation')
test_dir = os.path.join(data_dir, 'test')

# 图像数据生成器
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# 计算样本数量
def count_files(directory):
    total_files = 0
    for root, dirs, files in os.walk(directory):
        total_files += len(files)
    return total_files

train_samples = count_files(train_dir)
validation_samples = count_files(validation_dir)
test_samples = count_files(test_dir)

# 数据生成器
def create_generator(datagen, directory, target_size, batch_size, class_mode):
    generator = datagen.flow_from_directory(
        directory,
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode
    )
    # 包装生成器以处理损坏的图像文件
    while True:
        try:
            yield next(generator)
        except (OSError, StopIteration) as e:
            print(f"跳过无法读取的图像文件:{e}")
            continue

train_generator = create_generator(train_datagen, train_dir, (150, 150), 32, 'binary')
validation_generator = create_generator(validation_datagen, validation_dir, (150, 150), 32, 'binary')
test_generator = create_generator(test_datagen, test_dir, (150, 150), 32, 'binary')

# 定义模型
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dropout(0.5),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.001),
              metrics=['accuracy'])

# 训练模型
history = model.fit(
    train_generator,
    steps_per_epoch=train_samples // 32,  # 将结果转换为整数
    validation_data=validation_generator,
    validation_steps=validation_samples // 32,  # 将结果转换为整数
    epochs=5
)

# 保存模型
model.save('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5')

# 评估模型
test_loss, test_acc = model.evaluate(test_generator, steps=test_samples // 32)
print(f'Test accuracy: {test_acc:.2f}')

# 可视化训练结果
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.figure(figsize=(12, 9))

plt.subplot(1, 2, 1)
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

注意更改文件路径!!!

4.3 模型使用

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import os

# 加载已保存的模型
model = load_model('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5')

# 预测函数
def predict_image(img_path):
    img = image.load_img(img_path, target_size=(150, 150))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array /= 255.0

    prediction = model.predict(img_array)
    if prediction[0] > 0.5:
        print(f"The image at {img_path} is a Dog")
    else:
        print(f"The image at {img_path} is a Cat")

# 示例用法
test_image_path = './pythonProject/ai_modle_win/cats vs dogs/30.jpg'  # 替换为你的测试图片路径
predict_image(test_image_path)

使用上述训练的模型进行图片识别,注意文件路径。

4.4 识别结果

5.总结

通过构造猫狗图片数据集,然后使用深度学习训练一个猫狗识别大模型,你也快来试一试吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2137835.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【linux-Day2】linux的基本指令<上>

【linux-Day2】linux的基本指令<上> 一键查看操作系统的重要地位linux下的基本指令&#x1f4e2;ls&#xff1a;显示当前目录下所有的子目录和文件&#x1f4e2;pwd&#xff1a;显示用户当前所在的目录&#xff0c;在windows中&#xff0c;相当于显示当前目录的绝对路径。…

AI工具一键制作爆火的“汉语新解“卡片!

最近出现了一种很火的新玩法“汉语新解”。 AI把一个词汇&#xff0c;以一种特殊的视角&#xff0c;用幽默、讽刺等方式重新定义&#xff0c;然后生成一张精美的卡片。 这个玩法和之前我发的的吐槽工具玩法类似&#xff0c;主打的就是一个新颖、情绪释放。 今天教大家怎么快速…

Python 解析 JSON 数据

1、有如下 JSON 数据&#xff0c;存放在 data.json 文件&#xff1a; [{"id":1, "name": "小王", "gender": "male", "score": 96.8}, {"id":2, "name": "小婷", "gender&qu…

大模型探索式轨迹优化:基于试错的自主智能体学习新方法

人工智能咨询培训老师叶梓 转载标明出处 现有的开源LLMs在构建智能体方面的效果远不如GPT-4。标准的构建开源LLM智能体的方法涉及模仿学习&#xff0c;即基于专家轨迹对LLMs进行微调。然而&#xff0c;这些方法完全依赖于专家演示&#xff0c;由于对目标环境探索不足而可能产生…

windows11+ubuntu20.04.6双系统安装

记录win11和ubuntu20.04.6在单个硬盘上安装的主要流程 系统说明 BIOS模式&#xff1a; UEFI 硬盘&#xff1a; 1TB固态 内存&#xff1a; 32GB 步骤 1、 准备两个不小于16GB的U盘&#xff0c;一个用于装Windows&#xff0c;一个用于装ubuntu&#xff0c;注意8G的U盘虽然能够…

操作系统知识点-进程与线程,一文搞懂!

本文图片均来自王道考研 一、进程的概念、组成和特征 进程&#xff08;Process&#xff09;是计算机中的一个核心概念&#xff0c;它是对正在运行的程序的一个抽象表示。在计算机科学中&#xff0c;一个进程是系统进行资源分配和调度的一个独立单元&#xff0c;是操作系统结构…

Python数据分析 Pandas基本操作

Python数据分析 Pandas基本操作 一、Series基础操作 ​ Series是pandas的基础数据结构&#xff0c;它可以用来创建一个带索引的一维数组&#xff0c;下面开始介绍它的基础操作 1、创建Series 1&#xff09;使用数据创建Series&#xff1a; import pandas as pd pd.Series(1…

学习笔记JVM篇(三)

一、垃圾回收机制 垃圾回收&#xff08;Garbage Collection&#xff09;机制&#xff0c;是自动回收无用对象从而释放内存的一种机制。Java之所以相对简单&#xff0c;很大程度是归功于垃圾回收机制。&#xff08;例如C语言申请内存后要手动的释放&#xff09; 优点&#xff…

基于less和scss 循环生成css

效果 一、less代码 复制代码 item-count: 12; // 生成多少个 .item 类.item-loop(n) when (n > 0) {.icon{n} {background: url(../../assets/images/menu/icon{n}.png) no-repeat;background-size: 100% 100%;}.item-loop(n - 1);}.item-loop(item-count);二、scss代码 f…

在线查看 Android 系统源代码 Android Code Search

在线查看 Android 系统源代码 Android Code Search 1. Android Code Search2. Android2.1. platform/superproject2.2. build/envsetup.sh2.3. build/make/envsetup.sh References 1. Android Code Search https://cs.android.com/ Android https://cs.android.com/android An…

PCIe进阶之TL:Address Spaces, Transaction Types, and Usage

1 Transaction Layer Overview 如上图为PCIe设备的一个分层结构,从上层逻辑看,事务层的关键点是: 流水线式的完整的 split-transaction 协议事务层数据包(TLP)的排序和处理基于信用的流控制机制可选支持的数据中毒功能和端到端数据完整性检测功能事务层包含以下内容: TLP…

【C++】标准库IO查漏补缺

【C】标准库 IO 查漏补缺 文章目录 系统I/O1. 概述2. cout 与 cerr3. cerr 和 clog4. 缓冲区5. 与 printf 的比较 系统I/O 1. 概述 标准库提供的 IO 接口&#xff0c;包含在 iostream 文件中 输入流: cin输出流&#xff1a;cout / cerr / clog。 输入流只有一个 cin&#x…

MFC工控项目实例之十六输入信号验证

承接专栏《MFC工控项目实例之十五定时刷新PC6325A模拟量输入》 验证选定的输入信号实时状态 在BoardTest.cpp文件中添加代码 void CBoardTest::OnButton2() {// TODO: Add your control notification handler code hereisThreadBegin true; //运行线程执行pThre…

medium_socnet

0x00前言 靶场要安装在virtualbox &#xff08;最新版&#xff09;。否者会出现一些问题。 攻击机&#xff1a;kali2024 靶机&#xff1a;medium_socnet 0x01信息搜集 因为把靶机和虚拟机啊放在了同一网段。 所以我先使用了 arp-scan,查看有多少同一网段ipUP 。 经过推断…

OSS对象资源管理

1、登录aliyun 1.1、什么是OSS&#xff1f;有什么用&#xff1f; OSS 是“Object Storage Service”的缩写&#xff0c;中文常称为“对象存储服务”。OSS 是一种互联网云存储服务&#xff0c;主要用于海量数据的存储与管理。 相较于nginx&#xff0c;OSS更灵活&#xff0c;不…

点云深度学习系列:Sam2Point——基于提示的点云分割

文章&#xff1a;SAM2POINT:Segment Any 3D as Videos in Zero-shot and Promptable Manners 代码&#xff1a;https://github.com/ZiyuGuo99/SAM2Point Demo&#xff1a;https://huggingface.co/spaces/ZiyuG/SAM2Point 1&#xff09;摘要 文章介绍了SAM2POINT&#xff0c;这是…

跟《经济学人》学英文:2024年09月14日这期 People are splurging like never before on their pets

People are splurging like never before on their pets Would you buy your furry companion a cologne? like never before&#xff1a;从未有过&#xff1b;未曾发生过 splurge&#xff1a;挥霍&#xff1b;浪费&#xff1b;破费&#xff1b;大量花费&#xff1b;过度消…

python 读取excel数据存储到mysql

一、安装依赖 pip install mysql-connector-python 二、mysql添加表students CREATE TABLE students (ID int(11) NOT NULL AUTO_INCREMENT,Name varchar(50) DEFAULT NULL,Sex varchar(50) DEFAULT NULL,PRIMARY KEY (ID) ) ENGINEInnoDB AUTO_INCREMENT13 DEFAULT CHARSETu…

S32K3 工具篇5:如何使用lauterbach下载调试elf文件

S32K3 工具篇5&#xff1a;如何使用lauterbach下载调试elf文件 一&#xff0c;利用trace32现有flash脚本烧录elf二&#xff0c;debug 现有elf文件 之前写过如何在S32DS中使用lauterbach下载&#xff0c;但是对于RTD EB MCAL的代码&#xff0c;通常情况下是使用命令的方式去编译…

Spring Boot母婴商城:安全、便捷、高效

2 相关技术 2.1 SSM框架介绍 本课题程序开发使用到的框架技术&#xff0c;英文名称缩写是SSM&#xff0c;在JavaWeb开发中使用的流行框架有SSH、SSM、SpringMVC等&#xff0c;作为一个课题程序采用SSH框架也可以&#xff0c;SSM框架也可以&#xff0c;SpringMVC也可以。SSH框架…