flask搭建微服务器并训练CNN水果识别模型应用于网页

news2024/9/20 3:29:30

一. 搭建flask环境

概念

  • flask:一个轻量级 Web 应用框架,被设计为简单、灵活,能够快速启动一个 Web 项目。
  • CNN:深度学习模型,用于处理具有网格状拓扑结构的数据,如图像(2D网格)和视频(3D网格)。
  • PyTorch:开源的机器学习库,应用于如计算机视觉和自然语言处理等领域的深度学习。

flask环境搭建操作步骤: 

  1. pycharm终端创建新的虚拟环境:python -m venv virtualName 。
  2. 激活虚拟环境。
  3. 在虚拟环境中安装flask。
  4. 运行第一个前端网页。
流程图例

1.

2.

3.

4.

步骤4代码:
from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello_world():
    return "<h1>hello world!</h1>"

if __name__ == '__main__':
    app.run(debug=True)



二. 训练水果模型

水果识别CNN训练操作步骤: 

  1. 准备数据集(kaggle官网可下载)。
  2. 安装pyrorch。
  3. 使用pytorch的nn模型定义参数。
  4. 训练模型。
  5. 得到训练好的pth模型。
流程图例

1.

2.

5.

步骤3代码:
import torch
from torch import nn

# 水果分类模型参数配置

class NumberNet(nn.Module):
    def __init__(self, device, classes=10):
        super().__init__()
        if device is None:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 3),  # 100x100 -> 98x98
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 98x98 -> 49x49
            nn.Conv2d(16, 32, 3, padding=1),  # 49x49 -> 49x49
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 49x49 -> 24x24
            nn.Conv2d(32, 64, 3, padding=1),  # 24x24 -> 24x24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 24x24 -> 12x12
            nn.Flatten(),
            nn.Dropout(),
            nn.Linear(64 * 12 * 12, 1024),  # 调整线性层的输入特征数量
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, classes),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, X):
        return self.cnn(X)
步骤4代码:
import torch
from torch import nn
from NumberNet import NumberNet
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split



# 水果分类训练
# 数据集配置
# 假设 NumberNet 模型期望的输入是 3 通道彩色图像
transform = transforms.Compose([
    transforms.ToTensor(),  # 这将把 PIL 图像或 NumPy 数组转换为张量,并且范围从 [0, 255] 标准化到 [0.0, 1.0]
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 可选:标准化
])

# 加载项目目录下的水果文件夹
img_dataset = ImageFolder("../fruits", transform=transform)
len_dataset = len(img_dataset)
train_size = int(len_dataset * 0.8)
valid_size = len_dataset - train_size
train_dataset, valid_dataset = random_split(img_dataset, [train_size, valid_size])

# 数据加载器
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1000, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1000)
# batch_total 应该是 dataloader 的总批次数量,这里计算方式不正确
batch_total = len(train_dataloader)  # 应该直接使用 len(dataloader)

# 使用conda或者cpu开始训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 10
model = NumberNet(device)
criterion = nn.CrossEntropyLoss()
adam = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(epochs):
    losses = []
    for batch_num, (images, labels) in enumerate(train_dataloader, start=1):  # 使用 enumerate 来获取批次编号
        adam.zero_grad()
        predict = model(images.to(device))
        loss = criterion(predict, labels.to(device))
        print(f"batch size: {batch_num} / {batch_total} -- loss: {loss.item():.4f} ")
        losses.append(loss.item())
        loss.backward()
        adam.step()
    acc_list = []
    with torch.no_grad():
        for images, labels in valid_dataloader:
            predict = model(images.to(device))
            result = torch.argmax(predict, dim=-1)
            acc = (result == labels.to(device)).float().mean()  # 使用 torch 的函数来计算准确率
            acc_list.append(acc.item())

    total_acc = sum(acc_list) / len(acc_list)
    total_loss = sum(losses) / batch_total
    print(f"epoch: {epoch + 1} / {epochs} -- loss: {total_loss:.4f} -- acc: {total_acc:.4f} ")

# 保存模型参数,而不是整个模型
torch.save(model, "../readyModel/model.pth")

 三. 将训练好的模型嵌入flask后端

实现水果识别web操作步骤: 

  1. 在虚拟化环境下创建.py后端启动文件,并且创建模型实例,同时将训练好的.pth文件放入代码对应的文件路径。
  2. 创建index.html文件,作为后续前端文件。
  3. 在前端代码和后端代码使用Jason进行路由。
  4. 启动项目,实现功能。
 步骤1代码:
from flask import Flask, render_template, request, jsonify
import time
import torch
import cv2
import numpy as np
from FruitNet import FruitNet  # 确保FruitNet定义是正确的

app = Flask(__name__)

# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 创建模型实例
model = FruitNet(device=device, classes=5)  # 确保类别数与训练时一致
model.to(device)

# 加载训练好的权重
model.load_state_dict(torch.load("static/fruit_model.pth"))  # 确保权重文件名为fruit_model.pth
model.eval()  # 设置模型为评估模式


def predict_image(image_data):
    # 通过cv2加载图片数据
    img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)

    # 将图像从BGR转换为RGB格式(因为OpenCV默认加载的是BGR格式)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 调整图片大小到100x100(与训练时的输入大小一致)
    img = cv2.resize(img, (100, 100))

    # 在第一个位置增加一个维度,形成batch大小为1
    img = np.expand_dims(img, 0)

    # 将numpy对象转化为pytorch的tensor对象
    img = torch.from_numpy(img)

    # 调整图像通道顺序
    img = torch.permute(img, [0, 3, 1, 2])  # 转换为 (batch_size, channels, height, width)

    # 测试最终的结果
    with torch.no_grad():  # 关闭梯度计算
        img = img.to(device).float()  # 确保输入是float类型,并发送到指定设备
        predict = model(img)
        predicted_class = torch.argmax(predict, dim=-1).item()

    # 定义水果类别标签
    fruit_classes = ["Apple Golden 1", "Banana", "Pear Red", "Tomato Heart", "Watermelon"]  # 根据你的数据集定义类别标签

    # 输出预测的水果种类
    predicted_fruit = fruit_classes[predicted_class]
    return predicted_fruit
 步骤2代码:
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>水果识别</title>
    <link rel="stylesheet" href="./static/css/index.css">
    <script src="./static/js/jquery-3.7.1.min.js"></script>
</head>
<body>
<div class="main">
    <div>
        <!-- 显示上传的图片 -->
        <div class="upload-img">
            <img id="upload-img" src="" alt="请上传图片"/>
        </div>

        <!-- 表单用于上传图片 -->
        <form   id="upload-btn" action="/upload" method="post" enctype="multipart/form-data">
            <input style="margin-left: 120px" type="file" name="the_file" id="selectImg"> <br/>
            <input type="submit" value="识别该水果">
        </form>
    </div>

    <!-- 显示识别结果 -->
    <div class="result">
        <h2 id="result-show"></h2>
    </div>
</div>

<script>
    // 将文件转为 Base64 用于图片预览
    function convertToBase64(file, callback) {
        const reader = new FileReader();
        reader.onload = function(e) {
            callback(e.target.result);
        };
        reader.readAsDataURL(file);
    }

    $(function(){
        // 处理图片选择后的显示
        $("#selectImg").change(function(ev){
            const file = $(this)[0].files[0];
            if (file) {
                convertToBase64(file, function(base64Img){
                    $("#upload-img").attr("src", base64Img);  // 更新图片预览
                });
            }
        });

        // 处理表单提交
        $('#upload-btn').submit(function(ev){
            ev.preventDefault();  // 阻止默认表单提交

            var formData = new FormData(this);  // 获取表单数据
            $.ajax({
                url: '/upload',  // 请求的后端地址
                type: 'POST',
                data: formData,
                contentType: false,
                processData: false,
                success: function(response){
                    console.log('文件上传成功');
                    console.log(response);

                    // 更新识别结果
                    $('#result-show').text('识别结果:' + response.result);  // 显示识别结果
                },
                error: function(error){
                    console.error('文件上传失败');
                    console.error(error);
                }
            });
        });
    });
</script>
</body>
</html>
 步骤3代码:
<script>
    // 将文件转为 Base64 用于图片预览
    function convertToBase64(file, callback) {
        const reader = new FileReader();
        reader.onload = function(e) {
            callback(e.target.result);
        };
        reader.readAsDataURL(file);
    }

    $(function(){
        // 处理图片选择后的显示
        $("#selectImg").change(function(ev){
            const file = $(this)[0].files[0];
            if (file) {
                convertToBase64(file, function(base64Img){
                    $("#upload-img").attr("src", base64Img);  // 更新图片预览
                });
            }
        });

        // 处理表单提交
        $('#upload-btn').submit(function(ev){
            ev.preventDefault();  // 阻止默认表单提交

            var formData = new FormData(this);  // 获取表单数据
            $.ajax({
                url: '/upload',  // 请求的后端地址
                type: 'POST',
                data: formData,
                contentType: false,
                processData: false,
                success: function(response){
                    console.log('文件上传成功');
                    console.log(response);

                    // 更新识别结果
                    $('#result-show').text('识别结果:' + response.result);  // 显示识别结果
                },
                error: function(error){
                    console.error('文件上传失败');
                    console.error(error);
                }
            });
        });
    });
</script>
@app.route("/")
def home():
    return render_template("index.html")


@app.route('/upload', methods=['POST'])
def upload_file():
    if request.method == 'POST':
        f = request.files['the_file']
        # 保存图片到静态目录
        timestamp = time.strftime("%Y%m%d%H%M%S")
        file_path = f'./static/uploads/{timestamp}.png'
        f.save(file_path)

        # 读取保存后的图片数据并预测
        with open(file_path, 'rb') as image_file:
            image_data = image_file.read()

        predicted_fruit = predict_image(image_data)

        # 返回JSON数据
        return jsonify({
            'file_id': timestamp,
            'result': predicted_fruit,
            'img_path': f'/static/uploads/{timestamp}.png'
        })
  步骤4实现效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2143832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

集成学习详细介绍

以下内容整理于&#xff1a; 斯图尔特.罗素, 人工智能.现代方法 第四版(张博雅等译)机器学习_温州大学_中国大学MOOC(慕课)XGBoost原理介绍------个人理解版_xgboost原理介绍 个人理解-CSDN博客 集成学习(ensemble)&#xff1a;选择一个由一系列假设h1, h2, …, hn构成的集合…

LLM大模型基础知识学习总结,零基础入门到精通 非常详细收藏我这一篇就够了

在这个已经被大模型包围的时代&#xff0c;不了解一点大模型的基础知识和相关概念&#xff0c;可能出去聊天都接不上话。刚好近期我也一直在用ChatGPT和GitHub Copilot&#xff0c;也刚好对这些基础知识很感兴趣&#xff0c;于是看了一些科普类视频和报告&#xff0c;做了如下的…

从数据到决策,无限住宅代理还可以这么用

在企业发展中&#xff0c;一个良好的决策可以起到推波助澜的作用&#xff0c;让企业飞速发展。在决策的背后离不开数据的支撑&#xff0c;数据驱动决策已成为企业成功的关键因素。然而&#xff0c;随着数据量的激增和竞争的加剧&#xff0c;企业如何有效地收集、分析和应用数据…

Python 课程14-TensorFlow

前言 TensorFlow 是由 Google 开发的一个开源深度学习框架&#xff0c;广泛应用于机器学习和人工智能领域。它具有强大的计算能力&#xff0c;能够运行在 CPU、GPU 甚至 TPU 上&#xff0c;适用于从小型模型到大规模生产系统的各种应用场景。通过 TensorFlow&#xff0c;你可以…

【云原生监控】Prometheus之Alertmanager报警

Prometheus之Alertmanager报警 文章目录 Prometheus之Alertmanager报警概述资源列表基础环境一、部署Prometheus服务1.1、解压1.2、配置systemctl启动1.3、监控端口 二、部署Node-Exporter2.1、解压2.2、配置systemctl启动2.3、监听端口 三、配置Prometheus收集Exporter采集的数…

旧衣回收小程序:开启旧衣回收新体验

随着社会的大众对环保的关注度越来越高&#xff0c;旧衣物回收市场迎来了快速发展时期。在数字化发展当下&#xff0c;旧衣回收行业也迎来了新的模式----互联网旧衣回收小程序&#xff0c;旨在为大众提供更加便捷、简单、透明的旧衣物回收方式&#xff0c;通过手机直接下单&…

关于1688跨境官方接口的接入||跨境卖家必知的1688跨境要点

1688跨境是什么&#xff1f; 1688是国内领先的货源平台&#xff0c;每年服务超过6500万B类买家&#xff0c;其中很大一部分是跨境商家。这些跨境商家采购中国高性价比的商品到海外销售。 为什么要入驻跨境专供&#xff1f; 据统计&#xff0c;2028年跨境市场规模将实现翻三番&…

RabbitMQ(高阶使用)延时任务

文章内容是学习过程中的知识总结&#xff0c;如有纰漏&#xff0c;欢迎指正 文章目录 1. 什么是延时任务&#xff1f; 1.1 和定时任务区别 2. 延时队列使用场景 3. 常见方案 3.1 数据库轮询 优点 缺点 3.2 JDK的延迟队列 优点 缺点 3.3 netty时间轮算法 优点 缺点 3.4 使用消息…

HTML5好看的水果蔬菜在线商城网站源码系列模板2

文章目录 1.设计来源1.1 主界面1.2 商品列表界面1.3 商品详情界面1.4 其他界面效果 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_43151418/article/details/142059220 HTML5好看的水果蔬菜在线商城…

MATLAB系列06:复数数据、字符数据和附加画图类

MATLAB系列06&#xff1a;复数数据、字符数据和附加画图类 6. 复数数据、字符数据和附加画图类6.1 复数数据6.1.1 复变量&#xff08; complex variables&#xff09;6.1.2 带有关系运算符的复数的应用6.1.3 复函数&#xff08; complex function&#xff09;6.1.4 复数数据的作…

通信工程学习:什么是ONU光网络单元

ONU&#xff1a;光网络单元 ONU&#xff08;Optical Network Unit&#xff0c;光网络单元&#xff09;是光纤接入网中的用户侧设备&#xff0c;它位于光分配网络&#xff08;ODN&#xff09;与用户设备之间&#xff0c;是光纤通信系统的关键组成部分。以下是关于ONU光网络单元的…

Web后端开发技术:RESTful 架构详解

RESTful 是一种基于 REST&#xff08;表述性状态转移&#xff0c;Representational State Transfer&#xff09;架构风格的 API 设计方式&#xff0c;通常用于构建分布式系统&#xff0c;特别是在 Web 应用开发中广泛应用。REST 是一种轻量级的架构模式&#xff0c;利用标准的 …

构建响应式API:FastAPI Webhooks如何改变你的应用程序

FastAPI&#xff0c;作为一个现代、快速&#xff08;高性能&#xff09;的Web框架&#xff0c;为Python开发者提供了构建API的卓越工具。特别是&#xff0c;它的app.webhooks.post装饰器为处理实时Webhooks提供了一种简洁而强大的方法。在本文中&#xff0c;我们将探讨如何使用…

Git使用教程-将idea本地文件配置到gitte上的保姆级别步骤

&#x1f939;‍♀️潜意识起点&#xff1a;个人主页 &#x1f399;座右铭&#xff1a;得之坦然&#xff0c;失之淡然。 &#x1f48e;擅长领域&#xff1a;前端 是的&#xff0c;我需要您的&#xff1a; &#x1f9e1;点赞❤️关注&#x1f499;收藏&#x1f49b; 是我持…

剖析Spark Shuffle原理(图文详解)

Spark Shuffle 1.逻辑层面 从逻辑层面来看&#xff0c;Shuffle 是指数据从一个节点重新分布到其他节点的过程&#xff0c;主要发生在需要重新组织数据以完成某些操作时。 RDD血统 Shuffle 触发条件&#xff1a; reduceByKey、groupByKey、join 等操作需要对数据进行分组…

制作OpenLinkSaas发行版

发行版配置 作为软件研发效能一站式解决方案&#xff0c;OpenLinkSaas提供了众多的功能。再不同的场景中&#xff0c;所需要的软件功能是有差异的。OpenLinkSaas提供了发行版配置功能&#xff0c;以便在不同场景下组合所有的功能。 修改代码下面的src-tauri/src/vendor_cfg.rs…

软考高级:嵌入式-嵌入式实时操作系统调度算法 AI 解读

讲解 嵌入式实时操作系统中的调度算法主要用于管理任务的执行顺序&#xff0c;以确保任务能够在规定时间内完成。针对你提到的几种调度算法&#xff0c;我会逐一进行通俗解释。 生活化例子 假设你在家里举办一个家庭聚会&#xff0c;家里人轮流使用一个游戏机玩游戏。你作为…

springboot+redis+缓存

整合 添加依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency> 连接redis&#xff0c;配置yml文件 主机 端口号 数据库是哪一个 密码 配置类 p…

嵌入式最常用的接口之一:SDIO 介绍

SDIO简介 SDIO(Secure Digital Input Output)是一种基于SD卡技术的扩展接口标准,允许外部设备通过标准的SD卡槽连接并通信。与传统的SD卡仅限于存储数据不同,SDIO设备通过该接口进行多种功能扩展,如网络连接、GPS、蓝牙、摄像头等。这使得SDIO成为一种广泛应用于移动设备…

html实现好看的多种风格手风琴折叠菜单效果合集(附源码)

文章目录 1.设计来源1.1 风格1 -图文结合手风琴1.2 风格2 - 纯图片手风琴1.3 风格3 - 导航手风琴1.4 风格4 - 双图手风琴1.5 风格5 - 综合手风琴1.6 风格6 - 简描手风琴1.7 风格7 - 功能手风琴1.8 风格8 - 全屏手风琴1.9 风格9 - 全屏灵活手风琴 2.效果和源码2.1 动态效果2.2 源…