使用Hog特征进行字母和数字的分类

news2024/9/28 9:24:37

目的:使用字母数字的二值图像,进行识别:

整体思路:

1)对图像进行预处理;

对收集的单个字符进行二值化,进行数据均衡,并且将所有的字符图片直接resize为20*20(有过进行等比例缩放后padding为20*20,最终算法效果较差)

2)提取hog特征

3)进行svm(lightGBM)分类

查阅的资料:

使用python代码,提取相关的hog特征:

 参数设置:pythonscikit-image库HOG提取特征(参数解释) - 百度文库

本人使用的hog特征提取代码:

def get_hog():
    """ Get hog descriptor """

    # cv2.HOGDescriptor(winSize, blockSize, blockStride, cellSize, nbins, derivAperture, winSigma, histogramNormType,
    # L2HysThreshold, gammaCorrection, nlevels, signedGradient)
    hog = cv2.HOGDescriptor((SZ, SZ), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
    print("get descriptor size: {}".format(hog.getDescriptorSize()))
    return hog

def get_data(train_data_path, train_dir, result_num):
    '''
    加载训练样本
    :param train_data_path:
    :param result_num:
    :return:
    '''
    hog = get_hog()
    # 识别中文
    # ------加载训练样本
    chars_train = []
    chars_train_label = []
    files = get_file_list(train_data_path)  # 获取所有图片的绝对路径
    files2 = get_file_list(train_dir)  # 获取所有图片的绝对路径
    files += files2
    for filepath in files:
        digit_img = cv2.imread(filepath)
        digit_img = cv2.cvtColor(digit_img, cv2.COLOR_BGR2GRAY)
        chars_train.append(hog.compute(deskew(digit_img)))
        # chars_train.append(preprocess_hog(deskew(digit_img)))
        classTag = result_num[filepath.split("/")[-2]]  # 得到 类标签(数字)
        chars_train_label.append(classTag)

    # chars_train = np.squeeze(chars_train)
    chars_train = np.squeeze(np.float32(chars_train))
    chars_label = np.array(chars_train_label)
    return chars_train, chars_label

 使用SVM进行分类的代码:

class SVM(StatModel):
    def __init__(self, C = 1, gamma = 0.5):
        self.model = cv2.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv2.ml.SVM_RBF)
        # self.model.setKernel(cv2.ml.SVM_LINEAR)
        self.model.setType(cv2.ml.SVM_C_SVC)
        # 定义算法终止条件
        # self.model.setTermCriteria((cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 500, 1e-6)) #迭代次数超过阈值max_iter时停止,#cv2.TERM_CRITERIA_COUNT |

    # train svm
    def train(self, samples, responses):
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, chars_test):
        r = self.model.predict(chars_test)
        # result = r[1].ravel()
        return r

    # inference
    def predict_svm(self, test_imgList, tcName, model_path, c, g):
        self.model = SVM(C=c, gamma=g)
        if os.path.exists(model_path):
            self.model.load(model_path)
            # cv2.ml.SVM_load("svmtest.mat")
            # self.model = joblib.load(model_path)
        else:
            print('model_path is missing!')

        allErrCount = 0
        class_num = 37
        ErrCount = np.zeros(class_num, int)
        TrueCount = np.zeros(class_num, int)

        for chars_test, chars_test_label, index in test_imgList:
            # print(chars_test.shape, index)
            first = time.time()
            r = self.predict(chars_test)
            end = time.time()
            print("Testing one pic spent {:.6f}s.".format((end - first)/len(chars_test)))
            result = r[1].ravel()
            # for x in result:
            #     if x != index:
            #         print('pred error:', result_index[index], result_index[x])

            errCount = len([x for x in result if x != index])
            ErrCount[index] = errCount
            TrueCount[index] = len(chars_test_label) - errCount
            print("errorCount: {}.".format(errCount), "trueCount: {}.".format(len(chars_test_label) - errCount))
            allErrCount += errCount

        # tet = time.time()
        # print("Testing All class total spent {:.6f}s.".format(tet - tst))
        print("All error Count is: {}.".format(allErrCount))

        print("number", " TrueCount", " ErrCount")
        mean_acc = 0
        num_div = 0
        for tcn in tcName:
            # tcn = int(tcn)

            # if all_num > 0:
            num_div += 1
            index = result_num[tcn]
            all_num = (TrueCount[index] + ErrCount[index])
            acc = TrueCount[index] /all_num
            mean_acc += acc
            print(tcn, "     ", TrueCount[index], "      ", ErrCount[index], 'acc:', acc)
        mean_acc /= class_num
        return mean_acc
        # plt.figure(figsize=(12, 6))
        # x = list(range(37))
        # plt.plot(x, TrueCount, color='blue', label="TrueCount")  # 将正确的数量设置为蓝色
        # plt.plot(x, ErrCount, color='red', label="ErrCount")  # 将错误的数量为红色
        # plt.legend(loc='best')  # 显示图例的位置,这里为右下方
        # plt.title('Projects')
        # plt.xlabel('number')  # x轴标签
        # plt.ylabel('count')  # y轴标签
        # plt.xticks(np.arange(37), list(tcName))
        # plt.show()

        # inference




    def train_svm(self, chars_train, chars_label, c, g):
        #识别英文字母和数字
        self.model = SVM(C=c, gamma=g)
        self.model.train(chars_train, chars_label)
        # joblib.dump(self.model.model, path)
        return self.model


    def save_trainmodel(self, path):
        if not os.path.exists(path):
            self.model.save(path)

            # joblib.dump(self.model, path)
        # if not os.path.exists("./train_dat/svmchinese.dat"):
        #     self.modelchinese.save("./train_dat/svmchinese.dat")

调用svm分类,并且进行最优参数查找:

    best_score = 0
    best_parameters = {'gamma': 0.001, 'C': 0.001}
    for g in [0.001, 0.01, 0.1, 1, 10, 100]:
        for c in [0.001, 0.01, 0.1, 1, 10,   100]:
            # if g in [100]:
            #     continue
            if g == 0.01 and c == 0.001:
                continue
            model_path = os.path.join(model_path_save, 'svm_hog20221123_'+'_g_' +str(g)+'_c_' +str(c)+'.dat')
            svm_model = SVM(C=c, gamma=g) #12.5

            svm_model.train_svm(chars_train, chars_label, c, g)
            svm_model.save_trainmodel(model_path)

            acc = svm_model.predict_svm(test_imgList, tcName, model_path, c, g)
            print('c:', c, 'gamma:', g, 'MEANacc:', acc)
            print('============================================')
            if acc > best_score:  # 找到表现最好的参数
                best_score = acc
                best_parameters = {'gamma': g, 'C': c}
    print('best_score:', best_score, 'best_parameters:', best_parameters)

使用lightGBM分类

import lightgbm as lgb
    import datetime
    import sklearn
    # import warnings
    # warnings.filterwarnings('ignore')
    # folds = KFold(n_splits=5, shuffle=True, random_state=1996)

    # 模型参数设定
    model = lgb.LGBMClassifier(boosting_type='gbdt' #'dart' #'goss' #学习器类型,通常选取gbdt
                               , class_weight=None
                               , colsample_bytree=1.0
                               , importance_type='split'
                               , learning_rate=0.1
                               , max_depth=5 # *   指定了每棵树的最大深度或者它能够生长的层数上限,数据量小,4-10都无所谓。
                               , min_child_samples=20
                               , min_child_weight=0.001
                               , min_split_gain=0.0
                               , n_estimators=200#迭代次数
                               , n_jobs=2
                               , num_leaves=31 # * 用来设置组成每棵树的叶子的数量,由于lightGBM是leaves_wise生长,官方说法是要小于2^max_depth
                               , objective='multi:softmax'
                               , random_state=None
                               , reg_alpha=0.0
                               , reg_lambda=0.0
                               , silent=True
                               , subsample=1.0
                               , subsample_for_bin=255)
    '''
    n_estimators:拟合的树的棵树,相当于训练轮数 
    n_jobs:并行运行多线程核心数

    '''
    # model.fit(pca_feats, chars_label, eval_set=[(pca_feats, chars_label), (test_imgs, test_labels)],
    #           eval_metric=['logloss'], verbose=True) #
    model.fit(chars_train, chars_label, eval_set=[(chars_train, chars_label), (test_imgs, test_labels)],
              eval_metric=['logloss'], early_stopping_rounds=20, verbose=True)  # early_stopping_rounds=20,

    # lgb.LGBMClassifier(n_estimators=200, n_jobs=1, objective='multi:softmax')
    model_path = "/home/fuxueping/4tdisk/data/certificate_reader/传统算法处理mrz/GBM_model/model.txt"
    # model.booster_.savemodel(model_path)
    joblib.dump(model, model_path)

    start = time.time()
    # model = joblib.load(model_path)
    pred_y_test = model.predict(test_imgs)

数据样例:

 

最终未解决的问题:O和0无法分类正确,错误率很高;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/139600.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker镜像如何上传阿里云

目录 1.前期准备 2.push(推)镜像 3.pull(拉)镜像 1.前期准备 1.注册阿里云账户 阿里云官方网站链接:https://dev.aliyun.com 2.登陆账户 3.配置Docker加速器 注:搜索“容器镜像服务” 4.创建镜像仓库的命名空间 例如:xnx 5.创建镜像仓库(创…

77、【字符串】leetcode ——151. 反转字符串中的单词(C++版本)

题目描述 原题链接:151. 反转字符串中的单词 解题思路 先预处理头部空格和中间多余空格;再将整体进行逆转。例如:the sky is blue —> eulb si yks eht;最后,分别对每个单词进行逆转,即可完成反转字符…

生产制造业订单管理软件如何做好订单变更管理?

生产制造企业,由于客户需求具有多样性和不确定性,客户订单的内容便会存在出现各种变更的可能,如数量、交期、更改具体参数等,提出变更订单,是很常见的现象。生产制造企业常见的订单变更需求1、PMC已经下完制令后&#…

户外运动耳机怎么选、五款最适合户外运动的耳机分享

对于运动爱好者来说,很多人都比较喜欢边听音乐边运动,音乐能够让运动起来更有激情,提升运动锻炼效果。那么到底什么耳机更适合户外运动呢?目前运动耳机在市面上有很多,但不是每一款都适合户外运动,自己找的…

7.Express模块基础用法

Express是做web服务器的,是一个第三方的包,官网 Express - 基于 Node.js 平台的 web 应用开发框架 - Express 中文文档 | Express 中文网 Express的部分用法与http模块类似 在我看来Express是一个轻量级的框架,如果用于做一些较复杂的业务会…

亚马逊、阿里国际、Shopee、Temu等跨境电商平台测评自养号经验分享

对于亚马逊、temu、阿里国际等平台商家来说,流量非常重要。商家需要想办法提高流量。卖家店铺没有流量怎么办? 获取流量的第一点:自然搜索 自然搜索流量的来源实际上是通过站点的优化来提高排名的效果。站点优化有很多维度,如选择合适的关键…

嵌入式工程师招聘要求有哪些?

现在有非常多的朋友会问嵌入式软件工程师在实际的招聘中都是什么样的标准与要求呢,对于这个问题空口无凭,我今天从招聘网站上找了一些典型的招聘案例。 一、嵌入式软件工程师 职位描述: 1、对需求进行分析评审,并输出开发计划&a…

基于Vue和SpringBoot的进销存管理系统的设计和实现

作者主页:Designer 小郑 作者简介:Java全栈软件工程师一枚,来自浙江宁波,负责开发管理公司OA项目,专注软件前后端开发(Vue、SpringBoot和微信小程序)、系统定制、远程技术指导。CSDN学院、蓝桥云…

【阶段二】Python数据分析Pandas工具使用08篇:探索性数据分析:数据的描述:数据的分散趋势与数据的分布形态

本篇的思维导图: 数据的分散趋势 数据的分散趋势是用来刻画数值型变量偏离中心的程度,最为常用的分散趋势指标有标准差、极差、四分位差等。通过这些指标可以反映样本之间的差异大小,如果指标值越大,说明样本之间差异越明显,反之差异越小。 方差与标准差 …

阿里云产品专家陶炳哲:Java应用最佳实验

2022 年 9 月 28 日,阿里云用户组(AUG)第 11 期活动在深圳举办。活动现场,阿里云产品专家陶炳哲向参会企业分享了《Java 应用最佳实验》。本文根据演讲内容整理而成。 大家好,我是陶炳哲,是阿里云产品专家。…

React(coderwhy)- 02(函数组件)

React的组件化开发 组件化是一种分而治之的思想: 如果我们将一个页面中所有的处理逻辑全部放在一起,处理起来就会变得非常复杂,而且不利于后续的管理以及扩展。 但如果,我们讲一个页面拆分成一个个小的功能块,每个功能…

191:vue+openlayers 选择feature,固定按钮删除selected feature

第191个 点击查看专栏目录 本示例的目的是介绍如何在vue+openlayer中使用select来选择feature元素,通过按键来删除selected的feature。 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(共95行)相关API参考专栏目标…

【回答问题】ChatGPT上线了!给我推荐20个比较流行的ROS(Robot Operating System)算法模型

给我推荐20个比较流行的ros算法模型 很好,下面是20个比较流行的ROS(Robot Operating System)算法模型: 便携式机器人操作系统(PR2) 全向轮操作系统(TurtleBot) 机器人操作系统(ROS)用于机器人操作 基于空间相对导航的机器人操作系统(SLAM) 自适应机器人…

全网惟一面向软件测试人员的Python基础教程-在Python中如何对列表进行增删改排序?

全网惟一面向软件测试人员的Python基础教程 起点:《python软件测试实战宝典》介绍 第一章 为什么软件测试人员要学习Python 第二章 学Python之前要搞懂的道理 第三章 你知道Python代码是怎样运行的吗? 第四章 Python数据类型中有那些故事呢?…

IB成绩可以申请英国大学吗?

我们都知道ALEVEL课程是英国正统的高中课程,几乎被所有英国学校作为大学招收新生的入学课程。 那么,IB课程作为与ALEVEL课程一样享受极高国际知名度的课程,是否也能够申请英国大学呢?可以确定的告诉大家,用IB课程申请英…

高压放大器的组成部分有哪些(功率放大器的性能好坏)

虽然很多电子工程师经常使用高压放大器,但是对于高压功率放大器的组成和使用都不太清楚,下面来介绍一下高压放大器的组成部分以及如何验证功率放大器的性能好坏。 一、高压放大器的介绍 高压放大器是一种理想的功率放大器,可以放大交流和直流…

[C语言]和我一起来认识“整型在内存中的存储”

目录 1.整型类型中的成员 2.整型在内存中的存储 2.1原码,反码,补码 2.2整型在内存中以补码存放数据 2.3大小端 2.3.1大小端的介绍 2.3.2通过编程判别当前机器的字节序 1.整型类型中的成员 (unsigned为无符号类型,signed为有符号类型) 1.c…

【java集合】HashMap源码解析(基于JDK1.8)

一、Hashmap简介 类继承关系图如下: HashMap实现了三个接口,一个抽象类。主要的方法都在Map接口中,AbstractMap抽象类实现了Map方法中的公共方法,例如:size(),containsKey(),clear()等,主要方法由子类自己实现。 Ha…

Linux驱动之系统移植----uboot移植_修改网络驱动(uboot无设备树版本)

uboot版本:uboot.2016.03 开发板:100ask_imx6ull_pro 修改网络驱动 须知 I.MX6UL/ULL内部有个以太网 MAC外设,也就是 ENET,需要外接一个 PHY芯片来实现网络通信功能,也就是内部 MAC外部 PHY芯片的方案。(一个MAC可对应N个PHY芯片, PHY有地址…

5G NR标准 第11章 多天线传输

第11章 多天线传输 多天线传输是 NR 的关键组成部分,尤其是在较高频率下。 本章一般性地介绍了多天线传输的背景,然后详细描述了 NR 多天线预编码。 11.1 简介 使用多个天线进行传输和/或接收可以在移动通信系统中提供巨大的好处。 发射机和/或接收…