使用DTW算法简单实现曲线的相似度计算

news2025/1/16 16:40:34

相对接近产品交付形态的实现:基于DTW距离的KNN算法实现股票高相似筛选案例-CSDN博客

一、问题背景和思路

问题背景:如果你有历史股票的K线图,怎么从众多股票K线图中提取出TopN相似的几支股票,用来提供给投资者或专家做分析、决策参考使用呢?

问题理解:股票K线图中除了时间点数据,还包含每个时间的4个数据 最高价、最低价、开盘价和收盘价。下文中我们均使用每日的股票收盘价作为股票走势曲线,基于此给出一个找出TopN相似股票曲线的方案和代码实现。这里先假设没有标记过的相似曲线样本,使用无监督的方法去做。

方案思路:上面我们把问题简化为找出最相似的2条曲线,那么就需要先找一个度量算法,怎么度量两条曲线的相似性或距离。如果度量算法符合我们的先验经验,比如我们人工标示最相似的两条曲线,算法给出的距离度量也是最小的,我们判断最不相似的,算法给出的距离度量也是最大的。那么度量算法就是可以拿来做排序使用的。有了判断两条曲线距离度量的算法,且算法结果具有一定的排序性,那么就可以计算出任两条曲线的距离度量值,根据此值就可以给出与指定股票曲线最相似的TopN股票曲线。

有不少曲线相似度算法可供选择,但每种都有自己的局限和适合的场景。比如:

1、欧几里德距离(Euclidean Distance):
适用场景:适用于曲线样本数相同的情况,当曲线具有明显的平移和缩放变换时表现较好。

2、动态时间规整(Dynamic Time Warping,DTW):
特点:考虑了时间轴的变化,能够捕捉曲线的形状相似性。对于时间轴缩放和平移具有一定的容忍性。
适用场景:适用于曲线在时间上存在变换、平移、扭曲等情况,比如语音识别、时间序列数据分析等。

3、余弦相似度(Cosine Similarity):
特点:忽略了曲线的振幅,只关注其方向。适用于振幅不重要的情况。
适用场景:文本分类、推荐系统中用户兴趣相似性等。

4、皮尔逊相关系数(Pearson Correlation Coefficient):
特点:衡量线性相关性,取值范围在-1到1之间。
适用场景:适用于评估两个变量之间的线性关系,不仅限于时间序列数据。

5、曼哈顿距离(Manhattan Distance):
特点:考虑了各维度之间的差异,适用于具有多维度的曲线数据。
适用场景:图像识别、多维时间序列分析等。

6、动态核相关(Dynamic Kernel Correlation,DKC):
特点:将时间序列映射到高维特征空间中,计算相关性。可以捕获非线性关系。
适用场景:适用于非线性关系较为复杂的时间序列数据。

7、平均绝对误差(Mean Absolute Error,MAE):
特点:衡量实际值和预测值之间的差异。
适用场景:用于衡量预测模型的精度,例如回归模型的性能评估。

二、代码简单实现

下面使用DTW算法作为两条曲线的相似性度量算法,mock下数据简单做一个实现。

2.1 mock股票数据

这里简单mock下数据代表x,y,z三支股票的K线图。或者申请下股票网站的API接口做下数据爬虫。

假设x股票是我们在观察想投资的股票,y,z股票是希望去比较的股票。

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

# We define two sequences x, y as numpy array
# where y is actually a sub-sequence from x
x = np.array([2, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
y = np.array([1, 1, 2, 4, 2, 1, 2, 0, 0, 0]).reshape(-1, 1)
z = np.array([3, 2, 2, 4, 2, 3, 2, 0, 2, 3]).reshape(-1, 1)

plt.plot(x, label='x', color= (0.7,0.2,0.1))
plt.plot(y, label='y',  color= (0.2,0.1,0.7))
plt.plot(z, label='z', color= (0.2,0.7,0.1))
plt.title('Our two temporal sequences')
plt.legend()

2.2 DTW算法度量曲线相似性

根据这里mock的数据,人工也能判断y该是与x曲线最接近的,z次之。首先验证下DTW算法得出的曲线距离排序性的结论是不是跟人工判断的一致。

(1) 首先曲线x与自身的dtw距离是0,即相关性最大
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

x = np.array([2, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
plt.plot(x, label='x', color= (0.7,0.2,0.1))
plt.title('Our two temporal sequences')
plt.legend()

from dtw import dtw
# dtw算法中使用 L2 norm 作为元素比较距离
l2_norm = lambda x, y: (x - y) ** 2
dist11, cost_matrix11, acc_cost_matrix11, path11 = dtw(x, x, dist=l2_norm)
print("曲线x与曲线x的dtw距离:",dist11)
plt.imshow(acc_cost_matrix11.T, origin='lower', cmap='gray', interpolation='nearest')
plt.plot(path11[0], path11[1], 'w')
plt.show()

(2)曲线x与y距离 小于 x与z的距离,即x与y相关性 高于x与z
dist1, cost_matrix1, acc_cost_matrix1, path1 = dtw(x, y, dist=l2_norm)
dist2, cost_matrix2, acc_cost_matrix2, path2 = dtw(x, z, dist=l2_norm)
print("曲线x与曲线y的dtw距离:",dist1)
print("曲线x与曲线z的dtw距离:",dist1)
plt.imshow(acc_cost_matrix1.T, origin='lower', cmap='gray', interpolation='nearest')
plt.imshow(acc_cost_matrix2.T, origin='lower', cmap='gray', interpolation='nearest')
plt.plot(path1[0], path1[1], 'w')
plt.plot(path2[0], path1[1], 'w')
plt.show()

dtw算法dtw(x,y)=2 < dtw(x,z)=18 判断曲线y与曲线x的距离小于曲线z与x的距离,即相关性更高,符合期望,所以可以作为股票相关性算法使用。

(3)使用DTW算法计算出与曲线x高相关的TopN曲线 

将更多股票曲线数据与x曲线比较,找出距离最短的TopN曲线,即为TopN高相关股票曲线。

x_similaritys = {}
# 假设stock_data中存放里足够多的股票和其k线数据
for stock_name, k_line in stock_data:
    dist, cost_matrix, acc_cost_matrix, path = dtw(x, k_line, dist=l2_norm)
    if stock_name not in x_similaritys.keys():
        x_similaritys[stock_name] = dist

#选出与x曲线Top10相似曲线
res = sorted(x_similaritys.items(), key=lambda x: x[1])
print(res[:10])

Done

三、延伸阅读:关于DTW距离算法在KNN近邻算法中的应用

KNN近邻算法中使用DTW距离算法的代码实现,github代码

import sys
import collections
import itertools
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import mode
from scipy.spatial.distance import squareform

plt.style.use('bmh')
%matplotlib inline

try:
    from IPython.display import clear_output
    have_ipython = True
except ImportError:
    have_ipython = False

class KnnDtw(object):
    """K-nearest neighbor classifier using dynamic time warping
    as the distance measure between pairs of time series arrays
    
    Arguments
    ---------
    n_neighbors : int, optional (default = 5)
        Number of neighbors to use by default for KNN
        
    max_warping_window : int, optional (default = infinity)
        Maximum warping window allowed by the DTW dynamic
        programming function
            
    subsample_step : int, optional (default = 1)
        Step size for the timeseries array. By setting subsample_step = 2,
        the timeseries length will be reduced by 50% because every second
        item is skipped. Implemented by x[:, ::subsample_step]
    """
    
    def __init__(self, n_neighbors=5, max_warping_window=10000, subsample_step=1):
        self.n_neighbors = n_neighbors
        self.max_warping_window = max_warping_window
        self.subsample_step = subsample_step
    
    def fit(self, x, l):
        """Fit the model using x as training data and l as class labels
        
        Arguments
        ---------
        x : array of shape [n_samples, n_timepoints]
            Training data set for input into KNN classifer
            
        l : array of shape [n_samples]
            Training labels for input into KNN classifier
        """
        
        self.x = x
        self.l = l
        
    def _dtw_distance(self, ts_a, ts_b, d = lambda x,y: abs(x-y)):
        """Returns the DTW similarity distance between two 2-D
        timeseries numpy arrays.

        Arguments
        ---------
        ts_a, ts_b : array of shape [n_samples, n_timepoints]
            Two arrays containing n_samples of timeseries data
            whose DTW distance between each sample of A and B
            will be compared
        
        d : DistanceMetric object (default = abs(x-y))
            the distance measure used for A_i - B_j in the
            DTW dynamic programming function
        
        Returns
        -------
        DTW distance between A and B
        """

        # Create cost matrix via broadcasting with large int
        ts_a, ts_b = np.array(ts_a), np.array(ts_b)
        M, N = len(ts_a), len(ts_b)
        cost = sys.maxsize * np.ones((M, N))

        # Initialize the first row and column
        cost[0, 0] = d(ts_a[0], ts_b[0])
        for i in np.arange(1, M):
            cost[i, 0] = cost[i-1, 0] + d(ts_a[i], ts_b[0])

        for j in np.arange(1, N):
            cost[0, j] = cost[0, j-1] + d(ts_a[0], ts_b[j])

        # Populate rest of cost matrix within window
        for i in np.arange(1, M):
            for j in np.arange(max(1, i - self.max_warping_window),
                            min(N, i + self.max_warping_window)):
                choices = cost[i - 1, j - 1], cost[i, j-1], cost[i-1, j]
                cost[i, j] = min(choices) + d(ts_a[i], ts_b[j])

        # Return DTW distance given window 
        return cost[-1, -1]
    
    def _dist_matrix(self, x, y):
        """Computes the M x N distance matrix between the training
        dataset and testing dataset (y) using the DTW distance measure
        
        Arguments
        ---------
        x : array of shape [n_samples, n_timepoints]
        
        y : array of shape [n_samples, n_timepoints]
        
        Returns
        -------
        Distance matrix between each item of x and y with
            shape [training_n_samples, testing_n_samples]
        """
        
        # Compute the distance matrix        
        dm_count = 0
        
        # Compute condensed distance matrix (upper triangle) of pairwise dtw distances
        # when x and y are the same array
        if(np.array_equal(x, y)):
            x_s = np.shape(x)
            dm = np.zeros((x_s[0] * (x_s[0] - 1)) // 2, dtype=np.double)
            
            p = ProgressBar(shape(dm)[0])
            
            for i in np.arange(0, x_s[0] - 1):
                for j in np.arange(i + 1, x_s[0]):
                    dm[dm_count] = self._dtw_distance(x[i, ::self.subsample_step],
                                                      y[j, ::self.subsample_step])
                    
                    dm_count += 1
                    p.animate(dm_count)
            
            # Convert to squareform
            dm = squareform(dm)
            return dm
        
        # Compute full distance matrix of dtw distnces between x and y
        else:
            x_s = np.shape(x)
            y_s = np.shape(y)
            dm = np.zeros((x_s[0], y_s[0])) 
            dm_size = x_s[0]*y_s[0]
            
            p = ProgressBar(dm_size)
        
            for i in np.arange(0, x_s[0]):
                for j in np.arange(0, y_s[0]):
                    dm[i, j] = self._dtw_distance(x[i, ::self.subsample_step],
                                                  y[j, ::self.subsample_step])
                    # Update progress bar
                    dm_count += 1
                    p.animate(dm_count)
        
            return dm
        
    def predict(self, x):
        """Predict the class labels or probability estimates for 
        the provided data

        Arguments
        ---------
          x : array of shape [n_samples, n_timepoints]
              Array containing the testing data set to be classified
          
        Returns
        -------
          2 arrays representing:
              (1) the predicted class labels 
              (2) the knn label count probability
        """
        
        dm = self._dist_matrix(x, self.x)

        # Identify the k nearest neighbors
        knn_idx = dm.argsort()[:, :self.n_neighbors]

        # Identify k nearest labels
        knn_labels = self.l[knn_idx]
        
        # Model Label
        mode_data = mode(knn_labels, axis=1)
        mode_label = mode_data[0]
        mode_proba = mode_data[1]/self.n_neighbors

        return mode_label.ravel(), mode_proba.ravel()

class ProgressBar:
    """This progress bar was taken from PYMC
    """
    def __init__(self, iterations):
        self.iterations = iterations
        self.prog_bar = '[]'
        self.fill_char = '*'
        self.width = 40
        self.__update_amount(0)
        if have_ipython:
            self.animate = self.animate_ipython
        else:
            self.animate = self.animate_noipython

    def animate_ipython(self, iter):
        print('\r', self,)
        sys.stdout.flush()
        self.update_iteration(iter + 1)

    def update_iteration(self, elapsed_iter):
        self.__update_amount((elapsed_iter / float(self.iterations)) * 100.0)
        self.prog_bar += '  %d of %s complete' % (elapsed_iter, self.iterations)

    def __update_amount(self, new_amount):
        percent_done = int(round((new_amount / 100.0) * 100.0))
        all_full = self.width - 2
        num_hashes = int(round((percent_done / 100.0) * all_full))
        self.prog_bar = '[' + self.fill_char * num_hashes + ' ' * (all_full - num_hashes) + ']'
        pct_place = (len(self.prog_bar) // 2) - len(str(percent_done))
        pct_string = '%d%%' % percent_done
        self.prog_bar = self.prog_bar[0:pct_place] + \
            (pct_string + self.prog_bar[pct_place + len(pct_string):])

    def __str__(self):
        return str(self.prog_bar)

DTW算法计算两条曲线的距离:

time = np.linspace(0,20,1000)
amplitude_a = 5*np.sin(time)
amplitude_b = 3*np.sin(time + 1)

m = KnnDtw()
distance = m._dtw_distance(amplitude_a, amplitude_b)

fig = plt.figure(figsize=(12,4))
_ = plt.plot(time, amplitude_a, label='A')
_ = plt.plot(time, amplitude_b, label='B')
_ = plt.title('DTW distance between A and B is %.2f' % distance)
_ = plt.ylabel('Amplitude')
_ = plt.xlabel('Time')
_ = plt.legend()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961286.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

任意空间平面点云旋转至与水平面平行(python)

1、背景介绍 将三维空间中位于任意平面上的点云数据&#xff0c;通过一系列的坐标变换&#xff08;平移旋转&#xff09;&#xff0c;使其投影到与XOY平面平行&#xff0c;同时点云形状保持不变。具体效果如下&#xff0c;对于原始点集&#xff08;蓝色点集&#xff09;&#x…

关于 AGGLIGATOR(猛禽)网络宽频聚合器

AGGLIGATOR 是一个用于多个链路UDP/IP带宽聚合的工具软件&#xff0c;类似MTCP的作用&#xff0c;不过它是针对UDP/IP宽频聚合的。 举个例子&#xff1a; 中国大陆有三台公网服务器&#xff0c;中国香港有一台大带宽服务器。 那么&#xff1a; AGGLIGATOR 允许中国大陆的客户…

【C++高阶】:深入探索C++11

✨ 心似白云常自在&#xff0c;意如流水任东西 &#x1f30f; &#x1f4c3;个人主页&#xff1a;island1314 &#x1f525;个人专栏&#xff1a;C学习 &#x1f680; 欢迎关注&#xff1a;&#x1f44d;点赞 &#x1f4…

Prometheus+Grafana+Alertmanager监控告警

PrometheusGrafanaAlertmanager告警 Alertmanager开源地址&#xff1a;github.com/prometheus Prometheus是一款基于时序数据库的开源监控告警系统&#xff0c;它是SoundCloud公司开源的&#xff0c;SoundCloud的服务架构是微服务架构&#xff0c;他们开发了很多微服务&#xf…

【实际源码】工厂进销存管理系统(仓库、采购、生产、销售)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本&#xff0c;并实时掌握各环节的运营状况。 在采购管理方面&#xff0c;系统能够处理采购订单、供应商管理和采购入库等流程&#xff…

亚马逊云科技 re:Inforce 2024中国站大会

亚马逊云科技 re:Inforce 2024中国站大会 - 生成式AI时代的全面安全&#xff0c;将于7月25日本周四在北京富力万丽酒店揭幕

向量数据库(二):Qdrant

写在前面 我们借助 Qdrant 来了解向量数据库的一些内容 内容 什么是 Qdrant? Qdrant 是一个开源的针对向量相似性搜索的引擎,它提供了一系列的 API 用于对向量数据进行存储、搜索和管理等功能。 下面是来自 Qdrant 官网的一个架构图: 初步了解 Qdrant 里的一些概念 …

h5页面 打开自动跳转小程序

移动端项目中打开页面&#xff0c;直接跳转到小程序功能效果图 pc mobile 代码 样式代码css&#xff1a; <style>* {margin: 0;padding: 0;}html,body {height: 100%;}.full {position: absolute;top: 0;bottom: 0;left: 0;right: 0;}.public-web-container {displa…

每日OJ_牛客CM26 二进制插入

目录 牛客CM26 二进制插入 解析代码 牛客CM26 二进制插入 二进制插入_牛客题霸_牛客网 解析代码 m:1024&#xff1a;100000000 00 n:19 &#xff1a; 10011 要把n的二进制值插入m的第j位到第i位&#xff0c;只需要把n先左移j位&#xff0c;然后再进行或运算&#xff08;|&am…

JS知识点巩固

目录 前言 一、reduce方法 二、二维数组的行和列交换 总结 前言 这里的知识点记录的是日常生活中容易搞忘的知识点 一、reduce方法 function(total,currentValue, index,arr) redece可以用作累加&#xff1a;可以传入初始值&#xff0c;如果传入初始值&#xff0c;则从累加…

【FAS】《The Research of RGB Image Based Face Anti-Spoofing》

文章目录 1、原文2、相关工作3、基于特征解耦的人脸活体检测算法3.1、方法3.2、实验结果 4、基于解构与组合的人脸活体检测方法4.1、方法4.2、实验 5、作者总结6、结论&#xff08;own&#xff09;7、附录7.1、CycleGAN7.2、InfoGAN7.3、3D Face 1、原文 [1]张克越.基于RGB图像…

项目成功秘诀:工单管理系统如何加速进程

国内外主流的10款项目工单管理系统对比&#xff1a;PingCode、Worktile、浪潮云工单管理系统、华为企业智能工单系统、金蝶云苍穹、紫光软件管理系统、Jira、Asana、ServiceNow、Smartsheet。 在管理日益复杂的个人项目时&#xff0c;找到一款能够真正符合需求的管理软件&#…

Stable Diffusion 图生图

区别于文生图&#xff0c;所谓的图生图&#xff0c;俗称的垫图&#xff0c;就是比文生图多了一张参考图&#xff0c;由参考一张图来生成图片&#xff0c;影响这个图片的要素不仅只靠提示词了&#xff0c;还有这个垫图的因素&#xff0c;这个区域就上上传垫图的地方&#xff0c;…

二叉树--堆(下卷)

二叉树–堆&#xff08;下卷&#xff09; 如果有还没看过上卷的&#xff0c;可以看这篇&#xff0c;链接如下&#xff1a; http://t.csdnimg.cn/HYhax 向上调整算法 堆的插⼊ 将新数据插⼊到数组的尾上&#xff0c;再进⾏向上调整算法&#xff0c;直到满⾜堆。 &#x1f4…

Monaco 使用 LinkedEditingRangeProvider

Monaco LinkEdit 功能是指同时修改同样的字符串&#xff0c;例如在编辑 Html 时&#xff0c;修改开始标签时会同时修改闭合标签。Monaco 支持自定义需要一起更新的字符串列表。最终效果如下&#xff1a; 首先&#xff0c;通过 registerLinkedEditingRangeProvider 注册 LinkEd…

day17(nginx反向代理)

反向代理 安装nginx 1.26.1 平滑升级 负载均衡 1.nginx 反向代理配置 反向代理&#xff1a;⽤户直接访问反向代理服务器就可以获得⽬标服务器 &#xff08;后端服务器&#xff09;的资源。 反向代理效果&#xff1a;当访问200主机&#xff08;web1&#xff09;&#xff0c;&a…

vite instanceof 失效

背景&#xff1a;给一个巨石单体项目进行标准化模块拆分&#xff0c;封装出来的模块代码用 vite 进行构建&#xff0c;但模块启动后页面上的表现一直和 webpack 那版不一致 一步步 debug 后&#xff0c;发现问题出在下面这个判断条件 const GeneratorFunction function* () …

解决jenkins配置extendreport不展示样式

下载插件&#xff1a;Groovy 、 HTML Publisher plugin 配置&#xff1a; 1&#xff09;Post Steps &#xff1a; 增加 Execute system Groovy script &#xff0c; 内容&#xff1a; System.setProperty("hudson.model.DirectoryBrowserSupport.CSP", "&qu…

【React】详解 React Router

文章目录 一、React Router 的基本概念1. 什么是 React Router&#xff1f;2. React Router 的主要特性 二、React Router 的核心组件1. BrowserRouter2. Route3. Link4. Switch 三、React Router 的使用方法1. 安装 React Router2. 定义路由组件3. 配置路由4. 启动应用 四、Re…

再谈istio

微服务之间调用观测&#xff0c; istio的版本是对k8s 版本有要求的&#xff0c;案例中 istioshi 1.15.2 版本的 一、下载 Istio 二、部署 egressgateway 和 ingressgateway 分别控制进出 istio 通过 Envoy proxy&#xff0c;也就是pod加边车的方式来控制用户对svc的访问 这样…