一、说明
ARIMA 时间序列预测模型非常适合具有趋势和季节性的序列。它是一种被广泛采用的经典模型,通常作为现代深度学习方法基准测试的基准。然而,估计其准确的参数具有挑战性。研究人员和开发人员通常使用包括视觉绘图在内的试错方法。
二、什么是ARIMA模型?
ARIMA 模型是“自动回归移动平均线”的缩写,是一类使用过去值来估计未来预测的模型。ARIMA 模型由三个参数定义:p、d 和 q。
ARIMA模型在文献中研究了不同的变体。在这篇文章中,我们将使用 statsmodels 库中的实现。
整个笔记本显示了此处提供的简单实现。您可以为数据集修改此实现。根据需要创建单独的训练-测试拆分。我简单概述了重要的调整步骤。
三、完整代码:使用 Mango 自动调优
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv')
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_squared_error
from mango import scheduler, Tuner
def arima_objective_function(args_list):
global data_values
params_evaluated = []
results = []
for params in args_list:
try:
p,d,q = params['p'],params['d'], params['q']
trend = params['trend']
model = ARIMA(data_values, order=(p,d,q), trend = trend)
predictions = model.fit()
mse = mean_squared_error(data_values, predictions.fittedvalues)
params_evaluated.append(params)
results.append(mse)
except:
#print(f"Exception raised for {params}")
#pass
params_evaluated.append(params)
results.append(1e5)
#print(params_evaluated, mse)
return params_evaluated, results
param_space = dict(p= range(0, 30),
d= range(0, 30),
q =range(0, 30),
trend = ['n', 'c', 't', 'ct']
)
conf_Dict = dict()
conf_Dict['num_iteration'] = 200
data_values = list(df['#Passengers'])
tuner = Tuner(param_space, arima_objective_function, conf_Dict)
results = tuner.minimize()
print('best parameters:', results['best_params'])
print('best loss:', results['best_objective'])
best parameters: {'d': 0, 'p': 17, 'q': 23, 'trend': 'ct'}
best loss: 112.06886739549542
四、调整步骤
数据集:我们将使用一个简单的航空通行证数据集来记录航空公司乘客的数量。
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv')
df.head()
绘制序列以查看趋势和季节性
from matplotlib import pyplot as plt
f = plt.figure()
f.set_figwidth(15)
f.set_figheight(6)
plt.plot(df['#Passengers'], linewidth = 4, label = "original Series")
plt.legend(fontsize=25)
plt.xlabel('Months', fontsize = 25)
plt.ylabel('Count', fontsize = 25)
plt.show()
该数据集呈上升趋势,季节性为 12 个月。
传统上,一种方法是使用领域知识从原始序列中删除趋势和季节性,然后使用残差序列来预测未来。但是,我们将研究一种更直接的自动化方法。
五、如何自动调整参数?
我们将使用一个名为 Mango 的最先进的优化库来为我们的数据集找到最佳参数。让我们首先定义参数的范围。在这种优化方法中,我们定义了可能的参数范围。这个范围可以非常大,不需要精确。这些参数是从 statsmodels 库中定义的。
param_space = dict(p= range(0, 30),
d= range(0, 30),
q =range(0, 30),
trend = ['n', 'c', 't', 'ct']
)
参数空间是使用 python 构造定义的:range 和 list。参数的总可能组合集为 30*30*30*4 = 108,000。因此,详尽的网格搜索非常耗时。我们将使用贝叶斯搜索优化器方法在 ~100 次迭代内自动执行搜索。注意:根据您的数据集,范围的大小及其搜索空间可能会有所不同。定义一个大的搜索空间是可以的;让优化器为您完成艰巨的工作。
六、训练 ARIMA 模型
要使用 Mango,我们必须定义一个目标函数,该函数返回给定参数集的 ARIMA 模型误差。
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_squared_error
from mango import scheduler, Tuner
def arima_objective_function(args_list):
global data_values
params_evaluated = []
results = []
for params in args_list:
try:
p,d,q = params['p'],params['d'], params['q']
trend = params['trend']
model = ARIMA(data_values, order=(p,d,q), trend = trend)
predictions = model.fit()
mse = mean_squared_error(data_values, predictions.fittedvalues)
params_evaluated.append(params)
results.append(mse)
except:
#print(f"Exception raised for {params}")
#pass
params_evaluated.append(params)
results.append(1e5)
#print(params_evaluated, mse)
return params_evaluated, results
我们从 Mango 库中获取参数,并返回参数及其结果。结果包括经过训练的 ARIMA 模型的误差。在这种情况下,错误是mean_squared_error。 我们还包括 try-catch 语句,因为 ARIMA 模型可能不会收敛于每个参数组合/选择。我们只返回模型工作的参数集。Mango 在内部以最佳方式使用这些参数,以在很少的迭代(本例中为 100 次)内找到最佳模型。我们的目标是找到最小化误差函数的参数。
Mango 的控制迭代:Config 参数。
from mango import scheduler, Tuner
conf_Dict = dict()
conf_Dict['num_iteration'] = 200
tuner = Tuner(param_space, arima_objective_function, conf_Dict)
七、可视化最佳模型预测
总体而言,我们看到参数的可能组合总数非常大(108,000)。
def plot_arima(data_values, order = (1,1,1), trend = 'c'):
print('final model:', order, trend)
model = ARIMA(data_values, order=order, trend = trend)
results = model.fit()
error = mean_squared_error(data_values, results.fittedvalues)
print('MSE error is:', error)
from matplotlib import pyplot as plt
f = plt.figure()
f.set_figwidth(15)
f.set_figheight(6)
plt.plot(data_values, label = "original Series", linewidth = 4)
plt.plot(results.fittedvalues, color='red', label = "Predictions", linestyle='dashed', linewidth = 3)
plt.legend(fontsize = 25)
plt.xlabel('Months', fontsize = 25)
plt.ylabel('Count', fontsize = 25)
plt.show()
print(results['best_params'])
order = (results['best_params']['p'], results['best_params']['d'], results['best_params']['q'])
plot_arima(data_values, order=order, trend = results['best_params']['trend'])
如上所述,预测与基本事实完全吻合。有兴趣了解有关 Mango checkout 的更多信息,其 GitHub 存储库包含一组不同的示例。Sandeep Singh Sandha博士