EM算法实现对iris数据集和meat数据集的分类【MATLAB版本】

news2024/11/24 2:56:41

摘要:本章实验主要是对于学习 EM 算法的原理,掌握并实现混合高斯模型非监督学习 的 EM 算法,要求在两个数据集上面实现混合高斯模型的非监督学习的EM算法。混合模型是相对于单高斯模型而言的,对于某个样本数据而言,可能单个高斯分布不能够很好的拟合其特征,而是由多个高斯模型混合加权而成(我们假设为K个),我们通过EM算法就可以求解出有未知参数情况的高斯模型的分布,进而完成无监督学习的分类任务。
关键字:混合高斯模型;EM算法;无监督学习。
Abstract: This chapter’s experiment aims to understand and implement the EM (Expectation Maximization) algorithm for unsupervised learning of the mixture Gaussian model. The EM algorithm is applied to two datasets to explore the application of the mixture Gaussian model in unsupervised learning. Compared to a single Gaussian model, the mixture model can better fit the features of sample data. The mixture model assumes that the sample data is a weighted combination of multiple Gaussian models (assumed to be K). The EM algorithm is used to estimate the Gaussian model distribution under unknown parameter conditions, thereby completing the classification task in unsupervised learning.
Keywords: Mixture Gaussian model, EM algorithm, unsupervised learning.

一、 技术论述
(1) 所使用的主要知识是无监督学习的思想,最大似然估计方法,高斯混合模型假设以及EM算法。
(a) 其中无监督学习是机器学习和模式识别的一个大类,机器学习主要分为有监督学习(带标签的学习)和无监督学习(无标签的学习),对于有标签的学习而言,我们主要是对其构造准则函数,然后根据不同的下降方法对其进行梯度下降(随机梯度下降,批量梯度下降等)寻优计算,最后找到对应的准则函数最小的那组权重就是我们所要得到的最优分类器的参数;而无监督学习是在没有数据标签的情况下,利用数据分布的特点构造分类器来完成对于样本的分类任务。
(b) 最大似然估计方法 (MLE) 是参数估计的一个重要方法,他主要利用的就是样本数据的分布特点,构造包含参数的最大似然函数,其理论认为最大似然函数的极值点所对应的参数值就我们想要求得的哪一组参数值,即“出现的事情就是最有可能发生事情”思想。
(c) EM算法,集期望最大化算法(Expectation Maximization, EM)它是最大似然的推广(MLE)其核心思想是根据已有的数据,递归估计似然函数的参 数。EM算法在样本数据中某些特征丢失的情况下,仍然可以进行参数估计
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
图2-1:K=3的混合高斯模型图2-1:K=3的混合高斯模型

在这里插入图片描述
在这里插入图片描述
2.1 实验数据
(1)鸢尾花数据集
该数据集有150个样本,分为3类。每个样本有4个特征。
link
(2)UCI手写数字数据集“0-9”
数据集mfeat是由手写体数字“0”-“9”的特征数据集。1个手写数字有200个模式样本,总共有2000样本,存储在 ASCII 文本文件中,每1行存储一个样 本的特征。其中,集合中的前200个样本是数字“0”的样本,后续200个样本 分别是数字“1”-“9”的样本。
link
在这里插入图片描述

其中,
公式(1)中的π_k是针对某个k的q_nk的均值,是集合X中的所有样本属于第 k 类的后验概率 之和;
公式(2)中的μ_k是以q_nk为权值的数据样本的均值(期望),是以后验概率为权值 加权的均值(期望)
公式(3)中的Σ_k是以q_nk为权值的加权协方差矩阵,是以后验概率为权值 加权的方差。
实验(一)
在这里插入图片描述
图3-1:混淆矩阵(迭代轮数为100)
在这里插入图片描述
图3-2:取前两个特征的可视化结果(迭代轮数为100)

Total Accuracy: 0.62667
Class Accuracy:
0.0600
0.8600
0.9600
实验(二):mfeat-fou数据集
在这里插入图片描述
图3-3:混淆矩阵(迭代轮数为50)
在这里插入图片描述
图3-4:取前两个特征的可视化结果(迭代轮数为5)
代码
实验一

clear all
clc
fid = fopen('iris.data', 'r');
data = textscan(fid, '%f,%f,%f,%f,%s');
fclose(fid);

% 提取特征数据,去除类别标签
features = cell2mat(data(:, 1:4));
% 将类别标签转换为数值标签
labels = data{:, 5}; % 获取类别标签数据
numericLabels = categorical(labels); % 转换为分类变量
% 初始化参数
numClusters = 3; % 分类簇的数量
maxIterations = 200; % 最大迭代次数

[numSamples, numFeatures] = size(features);
probabilities = zeros(numSamples, numClusters); % 各样本属于各分类的概率
means = repmat(min(features), numClusters, 1) + rand(numClusters, numFeatures) .* repmat(range(features), numClusters, 1);
covariances = repmat(diag(var(features)), 1, 1, numClusters); % 各分类的协方差矩阵

weights = ones(1, numClusters) / numClusters; % 各分类的权重

for iteration = 1:maxIterations
    % E 步骤:计算样本属于各分类的概率
    for i = 1:numClusters
        probabilities(:, i) = weights(i) * mvnpdf(features, means(i, :), squeeze(covariances(:, :, i)));
    end
    probabilities = probabilities ./ sum(probabilities, 2);
    
    % M 步骤:更新分类的均值、协方差矩阵和权重
    for i = 1:numClusters
        totalProb = sum(probabilities(:, i));
        weights(i) = totalProb / numSamples;
        means(i, :) = sum(probabilities(:, i) .* features) / totalProb;
        
        diff = features - means(i, :);
        covariances(:, :, i) = (diff' * (diff .* probabilities(:, i))) / totalProb;
    end
    
    % 计算分类准确率和混淆矩阵
    trueLabels = grp2idx(numericLabels);
    %对数据进行倒序
    trueLabels = flip( trueLabels);
    predictedLabels = zeros(numSamples, 1);

    for i = 1:numSamples
        [~, predictedLabels(i)] = max(probabilities(i, :));
    end

    totalAccuracy = sum(predictedLabels == trueLabels) / numSamples;
    classAccuracy = zeros(numClusters, 1);

    for i = 1:numClusters
        idx = trueLabels == i;
        classAccuracy(i) = sum(predictedLabels(idx) == trueLabels(idx)) / sum(idx);
    end


end
    disp(['Total Accuracy: ', num2str(totalAccuracy)]);
    disp('Class Accuracy:');
    disp(classAccuracy);
    
    % 绘制混淆矩阵
    confusionMatrix = confusionmat(trueLabels, predictedLabels);

    figure;
    heatmap(confusionMatrix, 'ColorbarVisible', 'off');
    xlabel('Predicted Class');
    ylabel('True Class');
    title('Confusion Matrix');
[~, labels] = max(probabilities, [], 2);

% 显示分类结果
figure;
gscatter(features(:, 1), features(:, 2), data{:, 5});
title('EM Algorithm - Iris Dataset');
xlabel('Feature 1');
ylabel('Feature 2');

实验二

clear all
clc
j=0;
load('-ascii','mfeat-kar.mat');
% 提取特征数据,去除类别标签
features = mfeat_kar;
trueLabels=zeros(2000,1);
% 将类别标签转换为数值标签
for i=1:1999
    trueLabels(i,1)=fix(i/200);
end
 trueLabels(2000,1)=9;
% 初始化参数
numClusters = 10; % 分类簇的数量
maxIterations = 50; % 最大迭代次数

[numSamples, numFeatures] = size(features);
probabilities = zeros(numSamples, numClusters); % 各样本属于各分类的概率
means = repmat(min(features), numClusters, 1) + rand(numClusters, numFeatures) .* repmat(range(features), numClusters, 1);
covariances = repmat(diag(var(features)), 1, 1, numClusters); % 各分类的协方差矩阵
weights = ones(1, numClusters) / numClusters; % 各分类的权重
%求出的协方差矩阵因为维数之间有相关的关系,所以求出来的协方差矩阵不是正定矩阵,我们需要对其进行处理
for iteration = 1:maxIterations
    % E 步骤:计算样本属于各分类的概率
    for i = 1:numClusters
        covariances(:, :, i) = covariances(:, :, i) + 1e-7 * eye(numFeatures);%添加扰动,反之其非正定,避免Mvpdf函数报错
        probabilities(:, i) = weights(i) * mvnpdf(features, means(i, :), squeeze(covariances(:, :, i)));
    end
    probabilities = probabilities ./ sum(probabilities, 2);
    
    % M 步骤:更新分类的均值、协方差矩阵和权重
    for i = 1:numClusters
        totalProb = sum(probabilities(:, i));
        weights(i) = totalProb / numSamples;
        means(i, :) = sum(probabilities(:, i) .* features) / totalProb;
        
        diff = features - means(i, :);
        covariances(:, :, i) = (diff' * (diff .* probabilities(:, i))) / totalProb;
    end
    
    % 计算分类准确率和混淆矩阵
    predictedLabels = zeros(numSamples, 1);

    for i = 1:numSamples
        [~, predictedLabels(i)] =max(probabilities(i, :));
    end
      for i = 1:numSamples
          predictedLabels(i)=predictedLabels(i)-1;
      end

    totalAccuracy = sum(predictedLabels == trueLabels) / numSamples;
    classAccuracy = zeros(numClusters, 1);

    for i = 1:numClusters
        idx = trueLabels == i;
        classAccuracy(i) = sum(predictedLabels(idx) == trueLabels(idx)) / sum(idx);
    end


end
    disp(['Total Accuracy: ', num2str(totalAccuracy)]);
    disp('Class Accuracy:');
    disp(classAccuracy);

    % 绘制混淆矩阵
    % 计算分类准确率和混淆矩阵
    labelsAdjusted = 0:numClusters-1; % 调整后的标签,从0开始
    confusionMatrix = confusionmat(trueLabels, predictedLabels);
    figure;
    heatmap(labelsAdjusted, labelsAdjusted,   confusionMatrix, 'ColorbarVisible', 'off');
    xlabel('Predicted Class');
    ylabel('True Class');
    title('Confusion Matrix');
[~, labels] = max(probabilities, [], 2);

% 显示分类结果
figure;
gscatter(features(:, 1), features(:, 2), labels-1);
title('EM Algorithm - Iris Dataset');
xlabel('Feature 1');
ylabel('Feature 2');




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/616810.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【高级篇】分布式事务

分布式事务 1.分布式事务问题 1.1.本地事务 本地事务,也就是传统的单机事务。在传统数据库事务中,必须要满足四个原则: 1.2.分布式事务 分布式事务,就是指不是在单个服务或单个数据库架构下,产生的事务&#xff0c…

Nginx:Tomcat部署及优化(二)

Nginx:Tomcat部署及优化(二) 一、Tomcat 优化1.1 内核参数优化1.2 Tomcat 配置文件参数优化1.3 Java 虚拟机(JVM)调优 二、NginxTomcat 负载均衡、动静分离 一、Tomcat 优化 Tomcat 默认安装下的缺省配置并不适合生产…

9款超级实用的网页设计工具,快来看看有没有你用过的

随着网络时代的快速发展,游戏、购物、音乐、影视和社交网站的兴起都表明了网页设计的重要性! 网页设计工具作为网页设计师的生产工具,自然要选择好的。 让我们分享9个高质量的网页设计工具,让您的设计效率悄然提高! …

【Python TDD和BDD】零基础也能轻松掌握的学习路线与参考资料

Python TDD和BDD的学习路线 TDD(测试驱动开发)和BDD(行为驱动开发)在软件开发中的作用越来越受到重视。TDD通过先写测试代码,再编写生产代码的方式,使得开发者可以在开发过程中确保代码质量和正确性&#…

黑客学习-xss漏洞总结

1、什么是xss 先来看案例 在一个输入框中,输入js代码,存放alter()其弹窗,结果可以看到,代码成功执行。这个就是xss漏洞 XSS攻击全称跨站脚本攻击,是一种在Web应用中常见的安全漏洞,它允许用户将恶意代码植入到Web页面…

分布式事务 2PC

tip:作为程序员一定学习编程之道,一定要对代码的编写有追求,不能实现就完事了。我们应该让自己写的代码更加优雅,即使这会费时费力。 文章目录 一、简介二、2PC 的运行流程三、2PC 一定能保证数据的一致性吗?四、2PC 的…

软件测试——未来软件测试的5个主要趋势

全球各地的企业每天都在发展变化着,以应对市场挑战,满足日益成熟的客户需求。即使是正在进行的技术进步也会使软件测试专家在实践的过程中更加专注和精确。 2021年给软件测试领域带来了新的技术解决方案,以及质量保证和软件测试的实现。与此同…

Springcloud--异步通信RabbitMq快速入门

RabbitMQ 1.初识MQ 1.1.同步和异步通讯 微服务间通讯有同步和异步两种方式: 同步通讯:就像打电话,需要实时响应。 异步通讯:就像发邮件,不需要马上回复。 两种方式各有优劣,打电话可以立即得到响应&am…

YOLOv5改进系列(8)——添加SOCA注意力机制

【YOLOv5改进系列】前期回顾: YOLOv5改进系列(0)——重要性能指标与训练结果评价及分析 YOLOv5改进系列(1)——添加SE注意力机制

动态查找表

动态查找表 1.二叉排序树1.1. 定义1.2. 查找过程1.3. 插入过程1.4. 创建二叉排序树1.5. 删除操作(1)被移除的结点是叶子结点(2)被移除的结点只有左子树或者只有右子树;(3)被移除的结点既有左子树…

两张图理解MR与XR

我们知道,AR是在现实世界上叠加虚拟信息和图像,VR是完全模拟的虚拟世界,那么对于MR和XR的概念会稍显复杂,本文试图通过2张图来理解它们,如有不对,祈请纠正。 MR 关于MR,先来看看下面第一张图。 …

vue 3 第三十四章:nextTick

nextTick是Vue3中的一个非常有用的函数&#xff0c;它可以在下一次DOM更新循环结束后执行回调函数。这个函数可以用来解决一些异步更新视图的问题&#xff0c;例如在修改数据后立即获取更新后的DOM节点。以下是一个简单的示例&#xff1a; <template><div><p&g…

华硕无畏灵耀破晓原装Windows10/11系统

第一步&#xff1a;下载原装系统文件 第二步&#xff1a;灵耀/无畏/破晓需要自备16g空u盘安装 第三步&#xff1a;创建u盘分区&#xff0c;第一个分区格式为FAT32(存放TLK引导文件)&#xff0c;第二个分区大小为NTFS&#xff08;存放底包&#xff1a;HDI.OFS.SWP.EDN.KIT&…

Unity Package Manager 使用

项目组开发的工具可以托管到远程仓库里&#xff0c;别的项目 也可以使用。 在Unity工程Assets 下 创建自己的插件目录 运行时 代码 和 编辑器代码 &#xff0c;创建 对应的 程序集&#xff0c;以及package.json 文件 package.json内容&#xff1a;可参考官方的&#xff0c;n…

测试人何去何从?2023年测试工程师突破自我,卷出测试圈...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 2023年测试行业现…

被上司问“测得怎么样了?”我心里慌到不行

目录 前言 你测的怎么样了&#xff1f; 这样回答 初入测试职场 结尾&#xff1a; 前言 说实话&#xff0c;我真想从上面去掉"似乎"两个字&#xff0c;软件测试人&#xff0c;就是苦逼&#xff01;有的人曾抱怨过开发很糟糕&#xff0c;但我们没办法要求开发在会写代…

360浏览器如何屏蔽某搜索网站的热搜

1.安装油猴&#xff08;Tampermonkey插件&#xff09; 下载油猴&#xff1a;官网油猴tampermonkey官网_油猴脚本手机版油猴插件下载 安装&#xff1a;360浏览器安装可以参考这边文章。 地址&#xff1a;http://www.xz7.com/article/86938.html 其实就是下载crx文件后&#xff…

linuxOPS基础_linux沾滞位T(sticky bit)

命令&#xff1a;chmod 语法&#xff1a;# chmod [选项] 文件夹 作用&#xff1a;只允许文件的创建者和root用户删除文件 常用选项&#xff1a;ot 添加粘滞位 ​ o-t 去掉粘滞位 ​ 用法&#xff1a;chmod ot 目录名 示例代码&#xff1a; #chmod ot 含义&#xff1a;给…

复习之linux系统中的文件传输

一、实验环境设定 本节实验需要两台虚拟机&#xff0c;ip与主机在同一网段&#xff0c;可实现ssh连接&#xff01; 1.创建虚拟机westosb 因为之前实验已存在一台虚拟机westosa,因此还需创建一台虚拟机westosb! 使用# westos-vmctl create westosb 创建虚拟机出错&#…

一个软件测试工程师的岗位职责

其实软件测试入门并不难 我们自己生活中就有接触过很多跟软件测试相关的操作。而要是从事软件测试的工作&#xff0c;就是需要对软件进行更加系统的测试&#xff0c;并把你所测试的东西进行归纳总结&#xff0c;对软件整个使用和运行情况做一个系统、规范的报告。 软件测试的学…