基于flask的猫狗图像预测案例

news2024/9/21 14:32:09

📚博客主页:knighthood2001
公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

假设,你有模型,有训练好的模型文件,有模型推理代码,就可以把他放到flask上进行展示。

项目架构

在这里插入图片描述

  • index.html是模板文件
  • app.py是项目运行的入口
  • best_model.pth是训练好的模型参数
  • model.py是神经网络模型,这里采用的是GoogleNet网络。
  • model_reasoning.py是模型推理,通过这里面的代码,我们可以在本地进行猫狗图片的预测。

运行图

在这里插入图片描述

点击选择文件
在这里插入图片描述
图片下面就显示预测结果了。
在这里插入图片描述

项目完整代码与讲解

index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>图像分类</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #result {
            margin-top: 10px;
        }
        #preview-image {
            max-width: 400px;
            margin-top: 20px;
        }
    </style>
</head>
<body>
    <h1>图像分类</h1>
    <form id="upload-form" action="/predict" method="post" enctype="multipart/form-data">
        <input type="file" name="file" accept="image/*" onchange="previewImage(event)">
        <input type="submit" value="预测">
    </form>

    <img id="preview-image" src="" alt="">
    <br>
    <div id="result"></div>

    <script>
        document.getElementById('upload-form').addEventListener('submit', async (e) => {
            e.preventDefault();  // 阻止默认的表单提交行为
            const formData = new FormData(); // 创建一个新的FormData对象,用于封装表单数据
            formData.append('file', document.querySelector('input[type=file]').files[0]);  // 添加表单数据
            // 使用fetch API发送POST请求到'/predict'路径,并将formData作为请求体
            const response = await fetch('/predict', {
                method: 'POST',
                body: formData
            });
            // 获取响应的JSON数据
            const result = await response.json();
            // 将预测结果显示在页面上ID为'result'的元素中
            document.getElementById('result').innerText = `预测结果: ${result.prediction}`;
        });

        function previewImage(event) {
            const file = event.target.files[0];  // 获取上传的文件对象
            const reader = new FileReader();  // 创建一个FileReader对象,用于读取文件内容

            // 清空上一次的预测结果
            document.getElementById('result').innerText = '';
            
            // 当文件读取完成后,将文件内容显示在页面上ID为'preview-image'的元素中
            reader.onload = function(event) {
                document.getElementById('preview-image').setAttribute('src', event.target.result);
            }
            // 如果用户选择了文件,则开始读取文件内容
            if (file) {
                reader.readAsDataURL(file); // 将文件读取为DataURL格式,这样可以直接用作img元素的src属性
            }
        }
    </script>
</body>
</html>

前端我练的不多,很多解释已经在代码中讲了。

model.py

这是GoogleNet的网络架构

import torch
from torch import nn
from torchsummary import summary
# 定义一个Inception模块
class Inception(nn.Module):
    def __init__(self, in_channels, c1, c2, c3, c4):  # 这些参数,所在的位置都会发送变化,所有需要这个参数
        super(Inception, self).__init__()
        self.ReLU = nn.ReLU()

        # 路线1,单1×1卷积层
        self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)

        # 路线2,1×1卷积层, 3×3的卷积
        self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)

        # 路线3,1×1卷积层, 5×5的卷积
        self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)

        # 路线4,3×3的最大池化, 1×1的卷积
        self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)
        self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)


    def forward(self, x):
        p1 = self.ReLU(self.p1_1(x))
        p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))
        p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))
        p4 = self.ReLU(self.p4_2(self.p4_1(x)))
        return torch.cat((p1, p2, p3, p4), dim=1)

class GoogLeNet(nn.Module):
    def __init__(self, Inception, in_channels, out_channels):
        super(GoogLeNet, self).__init__()
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b3 = nn.Sequential(
            Inception(192, 64, (96, 128), (16, 32), 32),
            Inception(256, 128, (128, 192), (32, 96), 64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b4 = nn.Sequential(
            Inception(480, 192, (96, 208), (16, 48), 64),
            Inception(512, 160, (112, 224), (24, 64), 64),
            Inception(512, 128, (128, 256), (24, 64), 64),
            Inception(512, 112, (128, 288), (32, 64), 64),
            Inception(528, 256, (160, 320), (32, 128), 128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b5 = nn.Sequential(
            Inception(832, 256, (160, 320), (32, 128), 128),
            Inception(832, 384, (192, 384), (48, 128), 128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, out_channels))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GoogLeNet(Inception, 1, 10).to(device)
    print(summary(model, (1, 224, 224)))

model_reasoning.py

import torch
from torchvision import transforms
from model import GoogLeNet, Inception
from PIL import Image

def test_model(model, test_file):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    classes = ['猫', '狗']
    print(classes)
    image = Image.open(test_file)

    # normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])
    # # 定义数据集处理方法变量
    # test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])

    # 定义数据集处理方法变量
    test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    image = test_transform(image)

    # 添加批次维度,变成[1,3,224,224]
    image = image.unsqueeze(0)

    with torch.no_grad():
        model.eval()
        image = image.to(device)  # 图片也要放到设备当中
        output = model(image)
        print(output.tolist())
        pre_lab = torch.argmax(output, dim=1)
        result = pre_lab.item()
    print("预测值:", classes[result])
    return classes[result]

def test_special_model(best_model_file, test_file):
    # 加载模型
    model = GoogLeNet(Inception, in_channels=3, out_channels=2)
    model.load_state_dict(torch.load(best_model_file))

    # 模型的推理判断
    return test_model(model, test_file)

if __name__ == "__main__":
    # # 加载模型
    # model = GoogLeNet(Inception, in_channels=3, out_channels=2)
    # model.load_state_dict(torch.load('best_model.pth'))
    # # 模型的推理判断
    # test_model(model, "test_data/images.jfif")

    test_special_model("best_model.pth", "static/1.jpg")

这段代码与之前的模型推理代码不同的是,我添加了test_special_model函数,方便后续app.py中可以直接调用这个函数进行模型推理。

app.py

import os
from flask import Flask, request, jsonify, render_template

from model_reasoning import test_special_model
from model_reasoning import test_model
app = Flask(__name__)

# 定义路由
@app.route('/')
def index():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # 获取上传的文件
        file = request.files['file']
        if file:
            # 调用模型进行预测
            # # 加载模型
            # model = GoogLeNet(Inception, in_channels=3, out_channels=2)
            # basedir = os.path.abspath(os.path.dirname(__file__))
            #
            # model.load_state_dict(torch.load(basedir + '/best_model.pth'))
            # result = test_model(model, file)

            basedir = os.path.abspath(os.path.dirname(__file__))
            best_model_file = basedir + '/best_model.pth'
            result = test_special_model(best_model_file, file)
            return jsonify({'prediction': result})
        else:
            return jsonify({'error': 'No file found'})


if __name__ == '__main__':
    app.run(debug=True)

如果没有上文中的test_special_model函数,那么这里你就需要

   # 加载模型
   model = GoogLeNet(Inception, in_channels=3, out_channels=2)
   basedir = os.path.abspath(os.path.dirname(__file__))
   
   model.load_state_dict(torch.load(basedir + '/best_model.pth'))
   result = test_model(model, file)

并且还需要导入相应的库。

best_model.pth

最重要的是,你需要训练好的一个模型。

有需要的,可以联系我,我直接把这个项目代码发你。省得你还需要配置项目架构。

小插曲

我为什么会使用绝对路径,因为我在使用相对路径后,代码提示找不到这个路径。

    basedir = os.path.abspath(os.path.dirname(__file__))
    best_model_file = basedir + '/best_model.pth'

然后,我刚刚又试了一下,发现使用相对路径,又可以运行成功了。

真是不可思议(这个小插曲花了我大半个小时)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1908248.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

遥感分类产品精度验证之TIF验证TIF

KKB_2020.tif KKB_2020_JRC.tif kkb.geojson 所用到的包&#xff1a;&#xff08;我嫌geopandas安装太麻烦colab做的。。 import rasterio import geopandas as gpd import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.metrics import c…

【Threejs进阶教程-着色器篇】3. Uniform的基本用法2与基本地球昼夜效果

Uniform的基本用法2 关于本Shader教程前两篇地址&#xff0c;请按顺序学习本篇使用到的资源用uniform传递纹理代码分析texture类型的uniform在shader中接收uniformtexture2D()处理图片压缩修改wrapS和wrapT 切换成夜景效果切换Mix() 昼夜切换升级改动代码效果分析解决球体分界线…

Linux dig命令常见用法

Linux dig命令常见用法 一、dig安装二、dig用法 DIG命令(Domain Information Groper命令)是常用的域名查询工具&#xff0c;通过此命令&#xff0c;你可以实现域名查询和域名问题的定位&#xff0c;对于网络管理员和在域名系统(DNS)领域工作的小伙伴来说&#xff0c;它是一个非…

【大模型LLM面试合集】大语言模型架构_attention

1.attention 1.Attention 1.1 讲讲对Attention的理解&#xff1f; Attention机制是一种在处理时序相关问题的时候常用的技术&#xff0c;主要用于处理序列数据。 核心思想是在处理序列数据时&#xff0c;网络应该更关注输入中的重要部分&#xff0c;而忽略不重要的部分&…

YOLOv8改进 | 注意力机制| 引入多尺度分支来增强特征表征的注意力机制 【CVPR2021】

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录 &#xff1a;《YOLOv8改进有效…

顶会FAST24最佳论文|阿里云块存储架构演进的得与失-5.其他话题分享

4.1 可用性威胁与解决方案 挑战1&#xff1a;BlockServer故障影响众多VD 问题描述&#xff1a;单个BlockServer的故障可能会影响到多个虚拟磁盘&#xff08;VDs&#xff09;的正常运作&#xff0c;这是由于传统架构中BlockServer承担了过多的职责&#xff0c;其稳定性直接关系…

Eyes Wide Shut Exploring the Visual Shortcomings of Multimodal LLMs

Eyes Wide Shut? Exploring the Visual Shortcomings of Multimodal LLMs 近两年多模态大模型&#xff08;Multimodal LLM&#xff0c;MLLM&#xff09;取得了巨大的进展&#xff0c;能够基于图片与人类对话&#xff0c;展现出强大的识别甚至推理能力。然而&#xff0c;在某些…

字符串操作(CC++)

字符串操作 1. C语言基本使用字符串操作函数 2. C3. 对比 C语言和C在字符串操作上有很大的不同&#xff0c;这主要是因为C标准库提供了更强大、更易于使用的字符串类&#xff08;std::string&#xff09;&#xff0c;而C语言主要依赖字符数组和一系列标准库函数&#xff08;如s…

Halcon Ean13 一维码读取

一 EAN码介绍 1 EAN码定义: EAN码是国际物品编码协会制定的一种商品用条码&#xff0c;通用于全世界。EAN码符号有标准版&#xff08;EAN-13&#xff09;和缩短版&#xff08;EAN-8&#xff09;两种。标准版表示13位数字&#xff0c;又称为EAN13码&#xff0c;缩短版表示8位数…

SSM慢性病患者健康管理系统-计算机毕业设计源码04877

目 录 摘要 1 绪论 1.1 研究意义 1.2研究目的 1.3论文结构与章节安排 2 慢性病患者健康管理系统系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分…

day02_员工管理

文章目录 新增员工需求分析和设计代码开发功能测试代码完善录入的用户名已存在&#xff0c;抛出异常后没有处理新增员工的时候&#xff0c;创建人id和修改人id设置为了固定值ThreadLocal&#xff08;面试题&#xff09; 分页查询问题解决 启用禁用员工账号需求和分析代码设计 编…

分享外贸工作中常用英文标准表达和英文语句

常用英文表达 报拉格斯最低到岸价 quote the lowest price CIF Lagos经营纺织品多年 be in the line of textiles for many years货物受欢迎 the goods are very popular with customers / have met with a warm reception /be well received/accepted/ enjoy a wide populari…

2024年7月2日~2024年7月8日周报

目录 一、前言 二、完成情况 2.1 吴恩达机器学习系列课程 2.1.1 分类问题 2.1.2 假说表示 2.1.3 判定边界 2.2 学习数学表达式 2.3 论文写作情况 2.3.1 题目选取 2.3.2 摘要 2.3.3 关键词 2.3.4 引言部分 2.3.4 文献综述部分 三、下周计划 3.1 存在的问题 3.2 …

Nacos注册中心相关错误记录

文章目录 1&#xff0c;com.alibaba.cloud:spring-cloud-starter-alibaba-nacos-discovery:jar:unknown was not found1.1 定位及解决方案1.2&#xff0c;简要说明dependencyManagement的作用 2&#xff0c;nacos启动失败2.1 解决方案 1&#xff0c;com.alibaba.cloud:spring-c…

七大AI绘画软件大比拼!高效且免费!

在当今数字时代&#xff0c;人工智能技术广泛应用于各个行业&#xff0c;包括艺术创作。人工智能绘画软件可以帮助艺术家更快、更有效地创作。然而&#xff0c;市场上人工智能绘画软件的选择也令人眼花缭乱。那么&#xff0c;哪种人工智能绘画软件更好呢&#xff1f;需要明确的…

《UDS协议从入门到精通》系列——图解0x84:安全数据传输

《UDS协议从入门到精通》系列——图解0x84&#xff1a;安全数据传输 一、简介二、数据包格式2.1 服务请求格式2.2 服务响应格式2.2.1 肯定响应2.2.2 否定响应 Tip&#x1f4cc;&#xff1a;本文描述中但凡涉及到其他UDS服务的&#xff0c;均提供专栏内文章链接跳转方式以便快速…

平安消保在行动 | 守护每一个舒心笑容 不负每一场双向奔赴

“要时刻记得以消费者为中心&#xff0c;把他们当做自己的朋友&#xff0c;站在他们的角度去思考才能更好地解决问题。” 谈及如何成为一名合格的消费者权益维护工作人员&#xff0c;平安养老险深圳分公司负责咨诉工作的庞宏霄认为&#xff0c;除了要具备扎实的专业技能和沟通…

大舍传媒:如何在海外新闻媒体发稿报道摩洛哥?

引言 作为媒体行业的专家&#xff0c;我将分享一些关于在海外新闻媒体发稿报道摩洛哥的干货教程。本教程将带您深入了解三个重要的新闻媒体平台&#xff1a;Mediterranean News、Morocco News和North African News。 地中海Mediterranean News Mediterranean News是一个知名…

景芯SoC训练营DFT debug

景芯训练营VIP学员在实践课上遇到个DFT C1 violation&#xff0c;导致check_design_rule无法通过&#xff0c;具体报错如下&#xff1a; 遇到这个问题第一反映一定是确认时钟&#xff0c;于是小编让学员去排查add_clock是否指定了时钟&#xff0c;指定的时钟位置是否正确。 景芯…

Spring AOP源码篇四之 数据库事务

了解了Spring AOP执行过程&#xff0c;再看Spring事务源码其实非常简单。 首先从简单使用开始, 演示Spring事务使用过程 Xml配置&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema…