基于卷积神经网络的书法字体识别系统,resnet50,mobilenet模型【pytorch框架+python】

news2024/10/9 16:21:06

   更多目标检测和图像分类识别项目可看我主页其他文章

功能演示:

基于卷积神经网络的书法字体识别系统,resnet50,mobilenet【pytorch框架,python,tkinter】_哔哩哔哩_bilibili

(一)简介

基于卷积神经网络的书法字体识别系统是在pytorch框架下实现的,这是一个完整的项目,包括代码,数据集,训练好的模型权重,模型训练记录,ui界面和各种模型指标图表等。

该项目有两个可选模型:resnet50和mobilenet,两个模型都在项目中;GUI界面由tkinter设计和实现。此项目可在windowns、linux(ubuntu, centos)、mac系统下运行。

该项目是在pycharm和anaconda搭建的虚拟环境执行,pycharm和anaconda安装和配置可观看教程:

windows保姆级的pycharm+anaconda搭建python虚拟环境_windows启动python虚拟环境-CSDN博客

在Linux系统(Ubuntn, Centos)用pycharm+anaconda搭建python虚拟环境_linux pycharm-CSDN博客

(二)项目介绍

1. 项目结构

​​​​

该项目可以使用已经训练好的模型权重,也可以自己重新训练,自己训练也比较简单

以训练resnet50模型为例:

第一步:修改model_resnet50.py的数据集路径,模型名称、模型训练的轮数

​ 

第二步:模型训练和验证,即直接运行model_resnet50.py文件

第三步:使用模型,即运行gui_chinese.py文件即可通过GUI界面来展示模型效果

2. 数据结构

​​​​​

部分数据展示: 

​​​​

3.GUI界面(技术栈:tkinter+python) 

​​​​

4.模型训练和验证的一些指标及效果
​​​​​1)模型训练和验证的准确率曲线,损失曲线

​​​​​2)热力图

​​3)准确率、精确率、召回率、F1值

4)模型训练和验证记录

​​

(三)代码

由于篇幅有限,只展示核心代码

    def main(self, epochs):
        # 记录训练过程
        log_file_name = './results/resnet50训练和验证过程.txt'
        # 记录正常的 print 信息
        sys.stdout = Logger(log_file_name)
 
        print("using {} device.".format(self.device))
        # 开始训练,记录开始时间
        begin_time = time()
        # 加载数据
        train_loader, validate_loader, class_names, train_num, val_num = self.data_load()
        print("class_names: ", class_names)
        train_steps = len(train_loader)
        val_steps = len(validate_loader)
        # 加载模型
        model = self.model_load()  # 创建模型
        # 修改全连接层的输出维度
        in_channel = model.fc.in_features
        model.fc = nn.Linear(in_channel, len(class_names))
 
        # 模型结构可视化
        x = torch.randn(16, 3, 224, 224)  # 随机生成一个输入
        # 模型结构保存路径
        model_visual_path = 'results/resnet50_visual.onnx'
        # 将 pytorch 模型以 onnx 格式导出并保存
        torch.onnx.export(model, x, model_visual_path)  
        # netron.start(model_visual_path)  # 浏览器会自动打开网络结构
 
 
        # 将模型放入GPU中
        model.to(self.device)
        # 定义损失函数
        loss_function = nn.CrossEntropyLoss()
        # 定义优化器
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.Adam(params=params, lr=0.0001)
 
        train_loss_history, train_acc_history = [], []
        test_loss_history, test_acc_history = [], []
        best_acc = 0.0
 
        for epoch in range(0, epochs):
            # 下面是模型训练
            model.train()
            running_loss = 0.0
            train_acc = 0.0
            train_bar = tqdm(train_loader, file=sys.stdout)
            # 进来一个batch的数据,计算一次梯度,更新一次网络
            for step, data in enumerate(train_bar):
                # 获取图像及对应的真实标签
                images, labels = data
                # 清空过往梯度
                optimizer.zero_grad()
                # 得到预测的标签
                outputs = model(images.to(self.device))
                # 计算损失
                train_loss = loss_function(outputs, labels.to(self.device))
                # 反向传播,计算当前梯度
                train_loss.backward()
                # 根据梯度更新网络参数
                optimizer.step()  
 
                # 累加损失
                running_loss += train_loss.item()
                # 每行最大值的索引
                predict_y = torch.max(outputs, dim=1)[1]  
                # torch.eq()进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False
                train_acc += torch.eq(predict_y, labels.to(self.device)).sum().item()
                # 更新进度条
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         train_loss)
            # 下面是模型验证
            # 不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化
            model.eval()
            # accumulate accurate number / epoch
            val_acc = 0.0  
            testing_loss = 0.0
            # 张量的计算过程中无需计算梯度
            with torch.no_grad():  
                val_bar = tqdm(validate_loader, file=sys.stdout)
                for val_data in val_bar:
                    # 获取图像及对应的真实标签
                    val_images, val_labels = val_data
                    # 得到预测的标签
                    outputs = model(val_images.to(self.device))
                    # 计算损失
                    val_loss = loss_function(outputs, val_labels.to(self.device))  
                    testing_loss += val_loss.item()
                    # 每行最大值的索引
                    predict_y = torch.max(outputs, dim=1)[1]  
                    # torch.eq()进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False
                    val_acc += torch.eq(predict_y, val_labels.to(self.device)).sum().item()
 
            train_loss = running_loss / train_steps
            train_accurate = train_acc / train_num
            test_loss = testing_loss / val_steps
            val_accurate = val_acc / val_num
 
            train_loss_history.append(train_loss)
            train_acc_history.append(train_accurate)
            test_loss_history.append(test_loss)
            test_acc_history.append(val_accurate)
 
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, train_loss, val_accurate))
            # 保存最佳模型
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(model.state_dict(), self.model_name)
 
        # 记录结束时间
        end_time = time()
        run_time = end_time - begin_time
        print('该循环程序运行时间:', run_time, "s")
        # 绘制模型训练过程图
        self.show_loss_acc(train_loss_history, train_acc_history,
                           test_loss_history, test_acc_history)
        # 画热力图
        test_real_labels, test_pre_labels = self.heatmaps(model, validate_loader, class_names)
        # 计算混淆矩阵
        self.calculate_confusion_matrix(test_real_labels, test_pre_labels, class_names)

​​​​​(四)总结

以上即为整个项目的介绍,整个项目主要包括以下内容:完整的程序代码文件、训练好的模型、数据集、UI界面和各种模型指标图表等。

项目运行过程如出现问题,请及时交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2199467.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛客——xay loves or与 __builtin_popcount的使用

xay loves or 题目描述 登录—专业IT笔试面试备考平台_牛客网 运行思路 题目要求我们计算有多少个正整数 yy 满足条件 x \text{ OR } y sx OR ys。这里的“OR”是指按位或运算。为了理解这个问题,我们需要考虑按位或运算的性质。 对于任意两个位 a_iai​ 和 b_…

如何用AI绘画工具生成中国风插画?Midjourney保持风格一致出图

如何运用AI绘画工具如Midjourney,生成符合我们特定要求的艺术作品是一门精进的技巧,尤其当你想生成具有鲜明特色的国风插画时,纯文本提示词的局限性常常使我们难以达到预期效果。然而,借助Midjourney的高级参数功能——特别是sref…

中航资本:招保万金全跌停!“人气王”创历史,半日成交突破600亿

狂奔的“牛”总算迎来“回头”。 今日是新股民入市第一天。依据我国结算的安排,关于10月1日(周二)至10月8日(周二)提交请求的新开证券账户,于10月9日(周三)起可用于买卖。 不过&am…

GIS、向量、文字检索... 火山引擎 ByteHouse 集成全场景分析能力

企业业务场景增多、规模扩大,对于底层数据架构来说,可能也会愈加复杂。 比如,某企业因自身业务发展,需要引入向量检索能力,但前期选型的技术架构并不能直接支持,只能重新引入向量数据库。这意味着&#xff…

JavaWeb - 9 - MySQL

数据库:DataBase(DB),是存储和管理数据的仓库 数据库管理系统:DataBase Management System(DBMS),操纵和管理数据库的大型软件 SQL:Structured Query Language,操作关系型数据库的编程语言,定…

经纬恒润荣获2024中国汽车供应链大会创新成果奖

2024年9月24日-26日,2024中国汽车供应链大会暨第三届中国新能源智能网联汽车生态大会在武汉隆重举办。本届大会以“新挑战、新对策、新机遇——推动中国汽车供应链可持续发展”为主题,集聚政府主管领导、行业专家、汽车及零部件企业精英和主流媒体&#…

这个 JavaScript API 比你想象中更强大!

大家好,我是 ConardLi。 今天,我们来聊聊一个可能被你忽视,而且非常强大的标准 JavaScript API - AbortController 。 在过去,大家在提到 AbortController 的时候,一般会举请求中断的例子,就连 MDN 给到的…

重学SpringBoot3-集成Redis(三)之注解缓存策略设置

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Redis(三)之注解缓存策略设置 1. 引入 Redis 依赖2. 配置 RedisCacheManager 及自定义过期策略2.1 示例代码:自定…

重塑能源持续亏损近22亿:今年前五个月销量下滑,产能利用率骤降

《港湾商业观察》黄懿 9月2日,上海重塑能源集团股份有限公司(下称“重塑能源”)向港交所提交上市申请书,委任中国国际金融香港证券有限公司、招银国际融资有限公司及法国巴黎证券(亚洲)有限公司为整体协调…

Linux(不同版本系统包含Ubuntu)下安装mongodb详细教程

一、下载MongoDB 在MongoDB官网下载对应的MongoDB版本,可以点击以下链接快速跳转到下载页面: mongodb官网下载地址 注意选择和自己操作系统一致的platform,可以先查看自己的操作系统 查看操作系统详情 命令: uname -a 如图:操…

海洋大地测量基准与水下导航系列之二国外海底大地测量基准和海底观测网络发展现状(下)

2004年,英国、德国、法国等国家在欧洲“全球环境与安全监测’(Global Monitoring for Environment and Security,GMES)观测计划倡导下制定了“欧洲海底观测网络”(European Seafoor Observatory Network,ESONET)计划。ESONET是一个多学科的欧洲卓越网络(NoE &#x…

光路科技以技术创新为驱动,打造创新型企业新标杆

近日,深圳市光路在线科技有限公司(光路科技)凭借其出色的创新能力和市场表现,荣获深圳市中小企业服务局颁发的“创新型中小企业”称号。这一荣誉标志着光路科技在推动行业发展和技术进步方面取得了显著成就。 光路科技自2008年成立…

【含文档】基于Springboot+Android的在线招聘平台(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

◇【code】PPO: Proximal Policy Optimization

整理的代码库:https://github.com/Gaoshu-root/Code-related-courses/tree/main/RL2024/PPO OpenAI 文档 —— PPO-Clip OpenAI 文档 界面链接 PPO: on-policy 算法、适用于 离散 或 连续动作空间。可能局部最优 PPO 的动机与 TRPO 一样:…

Scott Brinker:企业正在更换更多的Martech,专注集成和API,不断扩大技术栈

营销技术替代因素:集成和API排在第二位 MarTech.org组织了2024年Martech替代调查,它能够深入了解营销技术栈是如何演变的。在496名受访者中,有65%的人表示他们在过去一年中更换了他们技术栈中的一个或多个营销技术解决方案。这些是最常被替代…

Tableau|三 数据连接与管理

一 Tableau的数据架构 数据连接层(Connection)、数据模型层(DataModel)和数据可视化层(VizQL)。 1.数据连接层 决定了如何访问源数据和获取哪些数据。 数据连接层的数据连接信息包括数据库、数据表、数据视…

华为大咖说 | 新时代,智能电动车车联网有哪些发展趋势?(下篇)

本文作者:朱行健(华为专家)全文约4252字,阅读约需9分钟 近年来,汽车产业逐步向电动化、自动化、网联化、共享化发展,车联网开始成为新的竞争主体,汽车市场开始出现新的市场发展驱动力、形成新的…

E36.C语言模拟试卷1第一大题选题解析与提示(未完)

点我去下载C语言模拟试卷1的文件 备注:ZIP文件中的参考答案仅仅提供最终结果 目录 第3题 第5题 第7题 第9题 第14题 第16题 第19题 第20题 第22题 第24题 第26题 第27题 第28题 第3题 3.若有说明语句:char c ‘\64’ ; 则变量C包含: …

python19_加减乘除(二)

加减乘除 a hello b world c 2 d 4# 字符串加法 def str_add(A, B):result A Breturn result# 字符串乘法 def str_mul(A, B):result A * Breturn result# 字符串除法 def str_div(A, B):result B / Areturn result# 字符串减法 def str_sub(A, B):result B - Aretur…

A股牛市来袭,资本涌动:加密市场的出路与机遇

近期,随着A股的强劲反弹,不少加密市场的投资者,尤其是一些KOL(关键意见领袖),开始转移资金并公开建议进军A股。这种趋势反映出部分投资者对加密市场的信心动摇,尤其是在全球宏观经济不确定性加剧…