最大似然法(Maximum Likelihood Estimation,简称MLE)是一种统计方法,用于估计概率模型的参数。其基本思想是寻找一组参数值,使得在这组参数下,观测数据出现的概率(即似然性)最大。这种方法广泛应用于统计学、机器学习和群体遗传学等领域,因为它提供了一种从数据出发,通过优化目标函数来估计模型参数的有效途径。
原理
最大似然法的原理基于这样一个假设:
给定某个概率分布模型,如果某个参数值能使从该模型中抽取到现有样本的概率最大,那么该参数值就是最合理的估计。换句话说,我们假设数据是由某个概率分布生成的,而我们的目标是找到这个分布的最可能参数。(假设分布,基于数据,寻找参数)那么在机器学习中就是(模型已定,参数未知)。
基本步骤
(1)定义似然函数:首先,根据已知的概率分布形式和数据集,定义似然函数:
其中θ表示待估计的参数向量,x表示观测到的数据。
(2)求解最大似然估计:接下来,通过数学优化方法(如梯度上升、牛顿法等),寻找能够最大化似然函数L(θ|x)的参数值θ^。这通常涉及到对似然函数取对数,转换为对数似然函数,以便于计算和处理。
(3)评估和检验:得到最大似然估计θ^后,可以通过各种统计测试(如AIC、BIC、卡方检验等)来评估模型的拟合优度,以及进行参数的显著性检验。
简单正态分布参数估计:
下面是一个使用Python进行最大似然估计的简单例子,其中我们将对正态分布的参数进行估计:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# 生成模拟数据
np.random.seed(0) # 设置随机种子以保证结果可复现
true_mean = 0 # 真实均值
true_std = 1 # 真实标准差
data = np.random.normal(true_mean, true_std, size=100)
# 最大似然估计
def mle_normal(data):
# 对于正态分布,最大似然估计得到的均值为样本均值,方差为样本方差的n/(n-1)倍
mean = np.mean(data)
var = np.var(data, ddof=1)
std = np.sqrt(var)
return mean, std
# 应用最大似然估计
est_mean, est_std = mle_normal(data)
# 可视化
x = np.linspace(min(data), max(data), 100)
plt.hist(data, bins=30, density=True, alpha=0.5, label='Data Histogram')
plt.plot(x, norm.pdf(x, est_mean, est_std), label='Estimated PDF')
plt.title('Maximum Likelihood Estimation for Normal Distribution')
plt.legend()
plt.show()
高斯混合模型(Gaussian Mixture Model, GMM)最大似然估计
让我们考虑一个更复杂的情况,其中我们将使用最大似然估计来拟合一个高斯混合模型(Gaussian Mixture Model, GMM)。GMM是一种概率模型,它假设所有的数据点都是从有限数量的高斯分布中生成的。每个高斯分布都有自己的均值、方差和权重。
为了简化,我们将使用一个由两个高斯分布组成的GMM,并使用Python的scikit-learn
库来执行最大似然估计。
!pip install scikit-learn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
接下来,我们将生成两组服从不同正态分布的数据,并将它们混合在一起
# 生成模拟数据
np.random.seed(0) # 设置随机种子以保证结果可复现
mean1 = [-1, 0]
cov1 = [[1, 0.5], [0.5, 1]]
mean2 = [1, 0]
cov2 = [[1, -0.5], [-0.5, 1]]
# 生成两个高斯分布的数据
data1 = np.random.multivariate_normal(mean1, cov1, size=500)
data2 = np.random.multivariate_normal(mean2, cov2, size=500)
# 混合数据
data = np.vstack([data1, data2])
现在,我们将使用GaussianMixture类来拟合GMM模型,并进行最大似然估计:
# 初始化GMM模型
gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
# 使用数据拟合GMM模型
gmm.fit(data)
# 获取估计的参数
weights = gmm.weights_
means = gmm.means_
covariances = gmm.covariances_
print("Estimated Weights:", weights)
print("Estimated Means:", means)
print("Estimated Covariances:", covariances)
# 可视化
plt.scatter(data[:, 0], data[:, 1], alpha=0.5, label='Data')
# 绘制高斯分布的等高线
x, y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
xy = np.column_stack([x.ravel(), y.ravel()])
for i in range(gmm.n_components):
z = np.exp(gmm.score_samples(xy))
z = z.reshape(x.shape)
plt.contour(x, y, z, levels=[0.01, 0.1, 0.5, 1, 2], colors='red' if i == 0 else 'blue', alpha=0.5)
plt.title('Gaussian Mixture Model with Maximum Likelihood Estimation')
plt.legend()
plt.show()
Seaborn高斯混合模型的等高线可视化
为了创建一个更复杂和精美的可视化,我们可以使用seaborn
库来增强我们的图表。seaborn
是一个基于matplotlib
的高级可视化库,它提供了更多的绘图选项和更美观的默认主题。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
import seaborn as sns
# 生成模拟数据
np.random.seed(0) # 设置随机种子以保证结果可复现
mean1 = [-1, 0]
cov1 = [[1, 0.5], [0.5, 1]]
mean2 = [1, 0]
cov2 = [[1, -0.5], [-0.5, 1]]
# 生成两个高斯分布的数据
data1 = np.random.multivariate_normal(mean1, cov1, size=500)
data2 = np.random.multivariate_normal(mean2, cov2, size=500)
# 混合数据
data = np.vstack([data1, data2])
# 初始化GMM模型
gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
# 使用数据拟合GMM模型
gmm.fit(data)
# 获取估计的参数
weights = gmm.weights_
means = gmm.means_
covariances = gmm.covariances_
# 可视化
sns.set(style="white", color_codes=True)
sns.jointplot(x=data[:, 0], y=data[:, 1], kind="kde", space=0)
# 绘制高斯分布的等高线
x, y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
xy = np.column_stack([x.ravel(), y.ravel()])
for i in range(gmm.n_components):
z = np.exp(gmm.score_samples(xy))
z = z.reshape(x.shape)
plt.contour(x, y, z, levels=[0.01, 0.1, 0.5, 1, 2], colors='red' if i == 0 else 'blue', alpha=0.5)
plt.suptitle('Gaussian Mixture Model with Maximum Likelihood Estimation')
plt.show()