【ASTGCN】模型调试学习笔记--数据生成详解(超详细)

news2025/1/12 9:00:02

利用滑动窗口生成时间序列

原理图示:

PEMS04数据集为例。

  • 该数据集维度为:(16992,307,3)16992表示时间序列的长度,307为探测器个数,即图的顶点个数,3为特征数,即流量,速度、平均占用率。
  • 现在利用滑动窗口生成新时间序列,假设滑动窗口大小(每次滑动所取时间序列的多少)为4,滑动窗口步长(每次滑动几格)为1,如图1所示,每次取4个长度的数据(总长度为16992,也就是图1中的L),滑动1个长度取一次,之后将滑动窗口取到的数据合并成新数据,如图2所示。
    在这里插入图片描述
    图1
    在这里插入图片描述
    图2

函数操作

函数调用关系:

def read_and_generate_dataset(graph_signal_matrix_filename,
                              num_of_weeks, num_of_days,
                              num_of_hours, num_for_predict,
                              points_per_hour=12, save=False):
   def get_sample_indices(data_sequence, num_of_weeks, num_of_days, num_of_hours,
                       label_start_idx, num_for_predict, points_per_hour=12):
             def search_data(sequence_length, num_of_depend, label_start_idx,
                 num_for_predict, units, points_per_hour):

read_and_generate_dataset函数调用get_sample_indices函数,get_sample_indices函数再调用search_data函数。

search_data函数

  • 函数功能:获取每个滑动生成的窗口的索引的首尾。
  • 函数具体操作:
def search_data(sequence_length, num_of_depend, label_start_idx,
                num_for_predict, units, points_per_hour):
 ####参数说明          
 #sequence_length在源码中接收的参数是get_sample_indices传递过来的data_sequence.shape[0],即原始数据的shape(16992,307,3)
 的第一个维度,即16992
 #num_of_depend:生成近期周期或日周期或周周期,源码中默认为num_of_hours = 1
 #label_start_idx在源码中接收的参数是get_sample_indices传递过来的label_start_idx,而get_sample_indices中的
 label_start_idx是read_and_generate_dataset函数传递过来的idx,在read_and_generate_dataset中,idx
 是range(data_seq.shape[0]),即0~16991,所以search_data中的label_start_idx是0~16991,search_data是处于for循环中被调用的。
 #num_for_predict:要预测的时间步长,源码中为12,也就是一个小时
 #units在get_sample_indices函数中传过来的值有三个,分别是7 * 24241,即前文所说的滑动窗口的步长,也就是论文原文中的
 近期周期、日周期、周周期,1代表一个小时;24代表24个小时,即一天;7*24代表一周。
 #points_per_hour:一个小时的步长,12

 	# 如果points_per_hour小于0,则抛出一个ValueError异常,提示points_per_hour应该大于0
    if points_per_hour < 0:
        raise ValueError("points_per_hour should be greater than 0!")
        
    # 检查预测目标的起始索引加上要预测的时间步长是否超出了历史数据的长度,如果超出了历史数据的长度,则返回None,表示无法生成
    有效的索引范围;例如循环进行到idx(label_start_idx)=16981,此时16981+12>16992,则返回空。
    if label_start_idx + num_for_predict > sequence_length:
        return None
        
    # 创建一个空列表,用于存储生成的索引范围
    x_idx = []
    
    # 遍历依赖的数据点数量范围。在每次迭代中,计算当前依赖序列的起始索引start_idx和结束索引end_idx
    for i in range(1, num_of_depend + 1):#源码中num_of_hours为1,此循环只执行一次
        # 计算当前依赖序列的起始索引
        start_idx = label_start_idx - points_per_hour * units * i  # idx-12*1*1
        # 计算当前依赖序列的结束索引
        end_idx = start_idx + num_for_predict # start_idx+12
        # 检查计算得到的起始索引是否大于等于0。如果大于等于0,说明该序列在历史数据中是有效的,可以加入到结果列表中
        if start_idx >= 0:
            x_idx.append((start_idx, end_idx))
        else:
            return None
    # 检查生成的索引范围的数量是否与预期的依赖数据点数量相等。如果不相等,则说明生成的索引范围数量不正确,返回None
    if len(x_idx) != num_of_depend:
        return None
    # 将生成的索引范围列表进行反转,并返回
    return x_idx[::-1]

举例说明:

  • for i in range(1, num_of_depend + 1):中num_of_depend为1时(源码中就为1),表示使用一个小时(近期)的历史数据来预测。
    因此代码为:for i in range(1, 2):,则此循环只执行一次。

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=13
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=13,start_idx =13-12*1*1=1end_idx=start_idx +num_for_predict =1+12=13
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 x_idx=[1,13],len(x_idx)==1
      if len(x_idx) != num_of_depend:,1=1

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=14
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=14,start_idx =14-12*1*1=2end_idx=start_idx +num_for_predict =2+12=14
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 x_idx=[2,14],len(x_idx)==1
      if len(x_idx) != num_of_depend:,1=1

  • for i in range(1, num_of_depend + 1):中num_of_depend为2时,表示使用2个小时(近期)的历史数据来预测。
    因此代码为:for i in range(1, 3):,则此循环执行两次。

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=12for i in range(1, 3):循环执行到i=1
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=12,start_idx =12-12*1*1=0end_idx=start_idx +num_for_predict =0+12=12
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 x_idx=[0,12],len(x_idx)==1
      if len(x_idx) != num_of_depend:,1=1

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=12for i in range(1, 3):循环执行到i=2
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=13,start_idx =14-12*1*2=-10end_idx=start_idx +num_for_predict =-10+12=2
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 start_idx =-10<0,索引不会加入x_idx中。
      因此本次search_data函数调用(函数内部进行了两次for循环)返回None

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=24for i in range(1, 3):循环执行到i=1
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=24,start_idx =24-12*1*1=12end_idx=start_idx +num_for_predict =12+12=24
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 x_idx=[12,24],

    • 假设当前read_and_generate_dataset函数中的循环执行到idx=24for i in range(1, 3):循环执行到i=2
      start_idx = label_start_idx - points_per_hour * units * i :,则此时label_start_idx=24,start_idx =24-12*1*2=0end_idx=start_idx +num_for_predict =0+12=12
      if start_idx >= 0: x_idx.append((start_idx, end_idx)),此时 x_idx=[0,12],
      循环2次之后, x_idx=[12,24],[0,12],最后执行return x_idx[::-1]
      则本次函数调用返回的索引序列是[[0,12],[12,24]]

代码测试:

if __name__ == '__main__':
    data_seq = np.load(graph_signal_matrix_filename)['data']
    for idx in range(data_seq.shape[0]):
        hour_indices=search_data(data_seq.shape[0],1,idx,12,1,12)#一个小时的索引
        print("hour_indice:",hour_indices)

输出:

hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: [(0, 12)]
hour_indice: [(1, 13)]
hour_indice: [(2, 14)]
hour_indice: [(3, 15)]
hour_indice: [(4, 16)]
hour_indice: [(5, 17)]
hour_indice: [(6, 18)]
hour_indice: [(7, 19)]
hour_indice: [(8, 20)]
hour_indice: [(9, 21)]
hour_indice: [(10, 22)]
hour_indice: [(11, 23)]
hour_indice: [(12, 24)]
...
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
if __name__ == '__main__':
    data_seq = np.load(graph_signal_matrix_filename)['data']
    for idx in range(data_seq.shape[0]):
        hour_indices=search_data(data_seq.shape[0],2,idx,12,1,12)#两个小时的索引
        print("hour_indice:",hour_indices)

输出:

hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: [(0, 12), (12, 24)]
hour_indice: [(1, 13), (13, 25)]
hour_indice: [(2, 14), (14, 26)]
hour_indice: [(3, 15), (15, 27)]
...
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None

get_sample_indices函数

  • 函数功能:按近期、日周期、周周期获得样本数据。
  • 函数具体操作:
def get_sample_indices(data_sequence, num_of_weeks, num_of_days, num_of_hours,
                       label_start_idx, num_for_predict, points_per_hour=12):
                       
 ####参数说明          
 #data_sequence在源码中接收的参数是read_and_generate_dataset传递过来的data_seq,即原始数据(16992,307,3)
 #num_of_weeks:0
 #num_of_days:0
 #num_of_hours:1
 #label_start_idx在源码中接收的参数是read_and_generate_dataset传递过来的idx,在read_and_generate_dataset中,idx
 是range(data_seq.shape[0]),即0~16991,所以get_sample_indices中的label_start_idx是0~16991,get_sample_indices是
 处于for循环中被调用的。
 #num_for_predict:要预测的时间步长,源码中为12,也就是一个小时
 #points_per_hour:一个小时的步长,12
 
    week_sample, day_sample, hour_sample = None, None, None
    # 构建sample的区间限制,分界点
    
    #如果索引越界了,直接return,例如循环进行到idx(label_start_idx)=16981,此时16981+12>16992
    if label_start_idx + num_for_predict > data_sequence.shape[0]:
        return week_sample, day_sample, hour_sample, None
        
 	#num_of_weeks ,num_of_days ,num_of_hours 只能有一个大于0,因为只能同时构造一种时间序列数据
    if num_of_weeks > 0:
        week_indices = search_data(data_sequence.shape[0], num_of_weeks,
                                   label_start_idx, num_for_predict,
                                   7 * 24, points_per_hour)
        if not week_indices:
            return None, None, None, None

        week_sample = np.concatenate([data_sequence[i: j]
                                      for i, j in week_indices], axis=0)

    if num_of_days > 0:
        day_indices = search_data(data_sequence.shape[0], num_of_days,
                                  label_start_idx, num_for_predict,
                                  24, points_per_hour)
        if not day_indices:
            return None, None, None, None

        day_sample = np.concatenate([data_sequence[i: j]
                                     for i, j in day_indices], axis=0)
	#如果num_of_hours >0
    if num_of_hours > 0:
    	#生成hours切片数据,search_data函数的返回值为:[0,12][1,13],...,[[0,12],[1,13],...]等索引,
        hour_indices = search_data(data_sequence.shape[0], num_of_hours,
                                   label_start_idx, num_for_predict,
                                   1, points_per_hour)
        if not hour_indices:
            return None, None, None, None
		#按照索引在原始数据中提取数据
        hour_sample = np.concatenate([data_sequence[i: j]
                                      for i, j in hour_indices], axis=0)
    #生成标签
    target = data_sequence[label_start_idx: label_start_idx + num_for_predict]
    return week_sample, day_sample, hour_sample, target

举例说明

if __name__ == '__main__':
    data_seq = np.load(graph_signal_matrix_filename)['data']
    sample = []
    targetlist= []
    for idx in range(data_seq.shape[0]):
        hour_indice=search_data(data_seq.shape[0],1,idx,12,1,12)
       
        if not hour_indice:
            continue
        #从原始数据集中按索引取出对应的数据,并像图2那样拼在一起
        hour_sample = np.concatenate([data_seq[i: j]
                              for i, j in hour_indice], axis=0)
        sample.append(hour_sample)
        #从hour_sample的后num_for_predict个步长取出数据,作为标签
        target = data_seq[idx: idx + num_for_predict]
        targetlist.append(target)
        print("idx:",idx)
        print("hour_sample.shape:",hour_sample.shape)
    print("len(sample):",len(sample))
    print("sample[0].shape):",sample[0].shape)
    print("sample[0][0].shape:",sample[0][0].shape)
    print("len(targetlist):",len(targetlist))
    print("targetlist[0].shape:",targetlist[0].shape)
    print("targetlist[0][0].shape:",targetlist[0][0].shape)

部分输出

idx: 12
hour_sample.shape: (12, 307, 3)
idx: 13
hour_sample.shape: (12, 307, 3)
idx: 14
hour_sample.shape: (12, 307, 3)
idx: 15
hour_sample.shape: (12, 307, 3)
idx: 16
hour_sample.shape: (12, 307, 3)
idx: 17
hour_sample.shape: (12, 307, 3)
idx: 18
hour_sample.shape: (12, 307, 3)
...
len(sample): 16969
sample[0].shape): (12, 307, 3)
sample[0][0].shape: (307, 3)
len(targetlist): 16969
targetlist[0].shape: (12, 307, 3)
targetlist[0][0].shape: (307, 3)

这里的

idx: 12
hour_sample.shape: (12, 307, 3)

就是根据search_data函数生成的hour_indice: [(0, 12)]索引在原数据集中取得的。

read_and_generate_dataset函数

  • 函数功能:调用search_dataget_sample_indices函数,按近期、日周期、周周期获得样本标签,并把样本、标签、时间步都放入all_samples列表。
  • 函数具体操作:
def read_and_generate_dataset(graph_signal_matrix_filename,
                              num_of_weeks, num_of_days,
                              num_of_hours, num_for_predict,
                              points_per_hour=12, save=False):
 
    all_samples = []
    for idx in range(data_seq.shape[0]):
    
        sample = get_sample_indices(data_seq, num_of_weeks, num_of_days,
                                    num_of_hours, idx, num_for_predict,
                                    points_per_hour)
        if ((sample[0] is None) and (sample[1] is None) and (sample[2] is None)):
            continue
        week_sample, day_sample, hour_sample, target = sample
        sample = [] 
        
        # N表示传感器,F表示特征数,T表示时间段
        if num_of_weeks > 0:
            week_sample = np.expand_dims(week_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(week_sample)

        if num_of_days > 0:
            day_sample = np.expand_dims(day_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(day_sample)
		#把hour_sample(sample_i)进行维度变换
        if num_of_hours > 0:
            hour_sample = np.expand_dims(hour_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(hour_sample)

        target = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :]  # (1,N,T)
        sample.append(target)

        time_sample = np.expand_dims(np.array([idx]), axis=0)  # (1,1)
        sample.append(time_sample)

        all_samples.append(
            sample)  # sampe:[(week_sample),(day_sample),(hour_sample),target,time_sample] = [(1,N,F,Tw),(1,N,F,Td),(1,N,F,Th),(1,N,Tpre),(1,1)]
    # 60%作为训练,20%作为验证,20%作为测试
    split_line1 = int(len(all_samples) * 0.6)
    split_line2 = int(len(all_samples) * 0.8)
     
	training_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[:split_line1])]  # [(B,N,F,Tw),(B,N,F,Td),(B,N,F,Th),(B,N,Tpre),(B,1)]
    validation_set = [np.concatenate(i, axis=0)
                      for i in zip(*all_samples[split_line1: split_line2])]
    testing_set = [np.concatenate(i, axis=0)
                   for i in zip(*all_samples[split_line2:])]
    train_x = np.concatenate(training_set[:-2], axis=-1)  # (B,N,F,T')
    val_x = np.concatenate(validation_set[:-2], axis=-1)
    test_x = np.concatenate(testing_set[:-2], axis=-1)

    train_target = training_set[-2]  # (B,N,T)
    val_target = validation_set[-2]
    test_target = testing_set[-2]

    train_timestamp = training_set[-1]  # (B,1)
    val_timestamp = validation_set[-1]
    test_timestamp = testing_set[-1]

    (stats, train_x_norm, val_x_norm, test_x_norm) = normalization(train_x, val_x, test_x)

    all_data = {
        'train': {
            'x': train_x_norm,
            'target': train_target,
            'timestamp': train_timestamp,
        },
        'val': {
            'x': val_x_norm,
            'target': val_target,
            'timestamp': val_timestamp,
        },
        'test': {
            'x': test_x_norm,
            'target': test_target,
            'timestamp': test_timestamp,
        },
        'stats': {
            '_mean': stats['_mean'],
            '_std': stats['_std'],
        }
    }
    print('train x:', all_data['train']['x'].shape)
    print('train target:', all_data['train']['target'].shape)
    print('train timestamp:', all_data['train']['timestamp'].shape)
    print()
    print('val x:', all_data['val']['x'].shape)
    print('val target:', all_data['val']['target'].shape)
    print('val timestamp:', all_data['val']['timestamp'].shape)
    print()
    print('test x:', all_data['test']['x'].shape)
    print('test target:', all_data['test']['target'].shape)
    print('test timestamp:', all_data['test']['timestamp'].shape)
    print()
    print('train data _mean :', stats['_mean'].shape, stats['_mean'])
    print('train data _std :', stats['_std'].shape, stats['_std'])

    if save:
        file = os.path.basename(graph_signal_matrix_filename).split('.')[0]
        dirpath = os.path.dirname(graph_signal_matrix_filename)
        filename = os.path.join(dirpath, file + '_r' + str(num_of_hours) + '_d' + str(num_of_days) + '_w' + str(
            num_of_weeks)) + '_astcgn'
        print('save file:', filename)
        np.savez_compressed(filename,
                            train_x=all_data['train']['x'], train_target=all_data['train']['target'],
                            train_timestamp=all_data['train']['timestamp'],
                            val_x=all_data['val']['x'], val_target=all_data['val']['target'],
                            val_timestamp=all_data['val']['timestamp'],
                            test_x=all_data['test']['x'], test_target=all_data['test']['target'],
                            test_timestamp=all_data['test']['timestamp'],
                            mean=all_data['stats']['_mean'], std=all_data['stats']['_std']
                            )
    return all_data

sampletarget 在前面两个函数已经分析过,对于sample,在本函数中多了一个操作hour_sample = np.expand_dims(hour_sample, axis=0).transpose((0, 2, 3, 1))target = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :]下面具体解释该操作。

  • np.expand_dims(hour_sample, axis=0):在指定的轴上插入一个新维度,hour_sample的原始形状是(12,307,3),经过expand_dims后变成(1,12,307,3)

  • target的原始形状是(12,307,3),经过expand_dims后变成(1,12,307,3)

  • .transpose((0, 2, 3, 1)):将hour_sample的维度按照指定的顺序重新排列,原来是(1,12,307,3),处理之后是(1,307,3,12)

  • .transpose((0, 2, 3, 1))[:, :, 0, :]:先将 target 变为(1,307,3,12),再提取 target 第三个维度的第一个特征,target 变为(1,307,1,12),即(1,307,12)

使用原始数据的数字对比:

  • sample[i]=(1,12,307,3)转为csv文件(12,307,1)只保存流量特征
#把sample[i]12, 307, 3)转为csv文件(12,307,1)只保存流量特征
import pandas as pd
def sampletocsv(i,sample):
    sample = sample[i]
    print("sample.shape:",sample.shape)
    #只提取流量
    reshaped_sample = sample[:, :, :1, :]
    print("Reshaped shape:", reshaped_sample.shape)
    data_2d = reshaped_sample.reshape(-1, reshaped_sample.shape[-1])
    df = pd.DataFrame(data_2d)
    df.to_csv(f'npytocsv/sample_dim1_{i}.csv', index=False)
if __name__ == '__main__':
    sampletocsv(0,sample)

输出如下,这里的每一份sample_i文件的维度都是(307,12),根据上面的get_sample_indices函数代码举例的输出,这样的sample_i一共有16969个。
在这里插入图片描述
sample_0:(部分)
在这里插入图片描述
sample_1:(部分)
在这里插入图片描述
sample_2:(部分)
在这里插入图片描述

  • targetlist[i](1, 307, 12)转为csv文件
#把target[i]1, 307, 12)转为csv文件
#运行search_data和get_sample_indices测试函数,不运行read_and_generate_dataset测试函数
import pandas as pd
def targetlisttocsv(i,targetlist):
    targetlist=targetlist[i]
    print("targetlist[i].shape:",targetlist.shape)
    targetlist_2d = targetlist.reshape(targetlist.shape[1], targetlist.shape[2])
    df = pd.DataFrame(targetlist_2d)
    df.to_csv(f'npytocsv/targetlist_{i}.csv', index=False)
if __name__ == '__main__':
    targetlisttocsv(2,targetlist)

在这里插入图片描述
targetlist_0:(部分)
在这里插入图片描述
targetlist_1:(部分)
在这里插入图片描述
targetlist_2:(部分)
在这里插入图片描述

  • 提取原始数据PEMS04.npz中的一部分转为csv文件
#把原始数据的流量的前i条的第j+1个特征转为csv
import pandas as pd
def dataseqtocsv(i,j):
    data_seq = np.load(graph_signal_matrix_filename)['data']
    print("data_seq.shape",data_seq.shape)

    subset_data_seq = data_seq[:i, :, j]
    print("subset_data_seq.shape:",subset_data_seq.shape)
    
    df = pd.DataFrame(subset_data_seq)
    df.to_csv(f'subset_data_seq{j+1}.csv', index=False)
if __name__ == '__main__':
    dataseqtocsv(50,0)

为了方便只输出16992条数据中的前50条,且只取流量特征,输出如下:
在这里插入图片描述
sampletarget联合起来与原数据集对比
在这里插入图片描述
在这里插入图片描述

  • 总结:get_sample_indices函数就是根据search_data函数所形成的索引,去原始数据集中提取对应的数据和标签,组成我们需要的近期(num_of_hours),日周期(num_of_days)、周周期(num_of_weeks)数据。

处理完hour_sample之后,把time_sample target 加入到all_samples中,

此时all_samples=[[hour_sample],[target],[time_sample],...,[hour_sample],[target],[time_sample]]

其中每一个 hour_sample=(1,307,3,12),每一个target =(1,307,),每一个time_sample =(1,1),且time_sample[0]=[[12]]time_sample[16969]=[[16980]]

接下来是按比例划分all_samples60%training_set20%validation_set20%testing_set

training_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[:split_line1])] 
validation_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[split_line1: split_line2])]
testing_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[split_line2:])]

training_set中取出从第一个到倒数第二个之间元素(不包括倒数第二个),即[hour_sample]=(1,307,3,12),并沿着最后一个轴(时间)连接起来,组成train_xval_x test_x 同理。

    train_x = np.concatenate(training_set[:-2], axis=-1)
    val_x = np.concatenate(validation_set[:-2], axis=-1)
    test_x = np.concatenate(testing_set[:-2], axis=-1)

training_set中取出倒数第二个元素,即target=(1,307,12),组成train_targetval_target test_target同理。

	train_target = training_set[-2]  
    val_target = validation_set[-2]
    test_target = testing_set[-2]

training_set中取出最后一个元素,即time_sample =(1,1),组成train_timestamp val_timestamp test_timestamp 同理。

    train_timestamp = training_set[-1]  
    val_timestamp = validation_set[-1]
    test_timestamp = testing_set[-1]

接下来是归一化操作,先看归一化函数。

normalization函数

  • 函数功能:对这输入的数据集进行标准化处理,使得每个数据集的均值为 0,标准差为 1
  • 函数具体操作:
def normalization(train, val, test):
 	#确保 train、val 和 test 数据集在第1轴及其后面的维度上形状相同
    assert train.shape[1:] == val.shape[1:] and val.shape[1:] == test.shape[1:]  # ensure the num of nodes is the same
    mean = train.mean(axis=(0, 1, 3), keepdims=True)
    std = train.std(axis=(0, 1, 3), keepdims=True)
    print('mean.shape:', mean.shape)
    print('std.shape:', std.shape)
    def normalize(x):
        return (x - mean) / std
    train_norm = normalize(train)
    val_norm = normalize(val)
    test_norm = normalize(test)
    return {'_mean': mean, '_std': std}, train_norm, val_norm, test_norm

mean = train.mean(axis=(0, 1, 3), keepdims=True):计算 train数据集在第013轴上的均值,并保持这些维度以便后续广播。
std = train.std(axis=(0, 1, 3), keepdims=True):计算 train数据集在第013轴上的标准差,并保持这些维度以便后续广播。
返回一个字典,和标准化的结果,接下来构建一系列字典。

整体梳理

for idx in range(16992): 0-16991
  search_data函数生成hour_indice: [(0, 12)]
  get_sample_indices函数根据hour_indice,从原始数据中取出索引从011的数据,添加到hour_sample中,维度: (12,307,3),经过维度扩展操作后变为:(1,12,307,3)
  从原始数据中取出索引从12-23的数据,添加到target中,维度: (12,307,3),经过维度扩展操作后变为:(1,12,307,3)
  取idx的值加添加到time_sample中:[0]
  把hour_sampletargettime_sample添加到samples
  把samples添加到all_samples中,all_samples=[[[hour_sample],[target],[samples]],...,[[hour_sample],[target],[samples]]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1929917.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

车间数据采集网关的工作原理和应用场景-天拓四方

在智能制造日益盛行的今天&#xff0c;车间数据采集作为整个生产流程中的关键环节&#xff0c;其重要性愈发凸显。数据采集网关作为这一环节的核心设备&#xff0c;扮演着承上启下的重要角色。本文旨在深入探讨车间数据采集网关的工作原理和应用场景。 一、数据采集网关的工作…

C++:链表插入排序/删除重复节点题解

插入排序 插入排序的思路很简单&#xff0c;基本都知道。 关键是放在链表中&#xff0c; 1.要建立一个哨兵位&#xff0c;这个哨兵位的下一个节点&#xff0c;始终指向val最小的节点。 2.prev指针作为cur的前一个节点&#xff0c;始终指向val最大的节点。它的下一个节点始终…

玩转HarmonyOS NEXT之IM应用首页布局

本文从目前流行的垂类市场中&#xff0c;选择即时通讯应用作为典型案例详细介绍HarmonyOS NEXT的各类布局在实际开发中的综合应用。即时通讯应用的核心功能为用户交互&#xff0c;主要包含对话聊天、通讯录&#xff0c;社交圈等交互功能。 应用首页 创建一个包含一列的栅格布…

Eureka——Spring Cloud中的服务注册与发现组件

目录 1. 前言2. Eureka的概述2.1 Eureka的核心功能2.2 Eureka的角色与特点2.3 Eureka的使用优势 3. 创建 Spring Cloud 的注册中心3.1 创建一个父项目3.2 创建Spring Cloud的注册中心Eureka 4. 创建服务提供者5. 创建一个消费者Consumer&#xff0c;调用服务提供者Provider 1. …

利用OSMnx进行城市路网数据的速度与通行时间推算及分析

本文还是以广州市路网为例&#xff0c;通过osmmx调用ox.add_edge_speeds(G)时&#xff0c;该函数会遍历图G 中的每条边&#xff08;即每条街道&#xff09;&#xff0c;并基于一些预设的规则或街道属性&#xff08;如街道类型、是否为主要道路、是否有速度限制等&#xff09;来…

netdata 监控软件安装与学习

netdata官网 netdata操作文档 前言&#xff1a; netdata是一款监控软件&#xff0c;可以监控多台主机也可以监控单台主机&#xff0c;监控单台主机时&#xff0c;开箱即用&#xff0c;web ui很棒。 环境&#xff1a; [root192 ~]# cat /etc/system-release CentOS Linux rel…

【qt】正则表达式来判断是否为邮箱登录

正则表达式是用来匹配字符串的神器. 在Qt中我们需要使用到QRegExp这个类 用exactMatch来进行匹配. [] 使用方括号 [] 来定义字符类&#xff0c;表示匹配方括号内的任意一个字符 A-Za-z0-9是字符的匹配范围. 是用于指定字符或字符类出现的次数,常见的如下 *&#xff08;匹配 0…

树结构添加分组,向上向下添加同级,添加子级

树结构添加分组&#xff0c;向上向下添加同级&#xff0c;添加子级 效果代码实现页面js 效果 代码实现 页面 <el-tree :data"treeData" :props"defaultProps" :expand-on-click-node"false":filter-node-method"filterNode" :ref&…

pico+unity3d手部动画

在 Unity 开发中&#xff0c;输入系统的选择和运用对于实现丰富的交互体验至关重要。本文将深入探讨 Unity 中的 Input System 和 XR Input Subsystem 这两种不同的输入系统&#xff0c;并详细介绍它们在控制手部动画方面的应用。 一、Input System 和 XR Input Subsystem 的区…

有关电力电子技术的一些相关仿真和分析:⑤交-直-交全桥逆变+全波整流结构电路(MATLAB/Siumlink仿真)

全桥逆变+全波整流结构 参数:Vin=500V, Vo=200V, T=2:1:1, RL=10Ω, fs=100kHz, L=1mH, C=100uF (1)给定输入电压,输出电压和主电路参数,仿真研究电路工作原理,分析工作时序; (2)调节负载电阻,实现电流连续和断续,并仿真验证; (3)调节占空比,分析占空比与电…

【2024开发插件大赛】如何为 ONLYOFFICE 开发插件

我们发布了 2024 插件开发大赛&#xff1a;为 ONLYOFFICE 开发适合中国用户的插件&#xff0c;获得福利与证书。如果您想要参加&#xff0c;阅读本文了解如何为 ONLYOFFICE 开发插件。 关于 ONLYOFFICE ONLYOFFICE 是一个国际开源项目&#xff0c;由领先的 IT 公司 Ascensio Sy…

Microsoft Edge(简称Edge)

Microsoft Edge&#xff08;简称Edge&#xff09;是一款由微软开发的网页浏览器&#xff0c;它为用户提供了许多便捷的功能和选项。以下是Edge浏览器的使用方法&#xff1a; 一、基本使用方法 打开Edge浏览器&#xff1a; 可以在Windows的开始菜单中找到“Microsoft Edge”并点…

Flink Window 窗口【更新中】

Flink Window 窗口 在Flink流式计算中&#xff0c;最重要的转换就是窗口转换Window&#xff0c;在DataStream转换图中&#xff0c;可以发现处处都可以对DataStream进行窗口Window计算。 窗口&#xff08;window&#xff09;就是从 Streaming 到 Batch 的一个桥梁。窗口将无界流…

【数据结构取经之路】二叉搜索树的实现

目录 前言 二叉搜索树 概念 性质 二叉搜索树的实现 结点的定义 插入 查找 删除 二叉搜索树完整代码 前言 首先&#xff0c;二叉搜索树是一种数据结构&#xff0c;了解二叉搜素树有助于理解map和set的特性。 二叉搜索树 概念 二叉搜索树又称二叉排序树&#xff0c…

推荐系统之MIND用户多兴趣网络

目录 引言MIND算法原理1. 算法概述2. 模型结构3. 多兴趣提取层4. 标签感知注意力层 实践应用应用场景1. 电商平台2. 社交媒体3. 视频流媒体4. 内容分发平台 结论 引言 随着大数据和人工智能技术的快速发展&#xff0c;推荐系统已成为电商平台、社交媒体和内容分发平台的重要组成…

如何用python写接口

如何用python写接口&#xff1f;具体步骤如下&#xff1a;  1、实例化server 2、装饰器下面的函数变为一个接口 3、启动服务 开发工具和流程&#xff1a; python库&#xff1a;flask 》实例化server&#xff1a;server flask.Flask(__name__) 》server.route(/index,met…

吃空上千袋,养猫10年经验,生生不息、希喂、弗列加特谁是卷王?

身为宠物医生&#xff0c;我每天都在与猫咪和狗狗的相处中度过&#xff0c;对它们的身体变化十分敏感。当前&#xff0c;许多家养猫面临肥胖和肝脏损伤的双重困扰&#xff0c;虽然医疗手段可以介入&#xff0c;但问题的核心在于宠物主人的喂养方法是否得当。 在我职业生涯的这…

磁盘空间不足java.sql.sQLException:磁盘空间不足

java.sql.sQLException:磁盘空间不足 环境介绍1 查询表空间使用情况2 对表空间文件扩展限制进行修改(或新增表空间数据文件)3 达梦数据库学习使用列表 环境介绍 遇到此错误时,首先查看数据库服务器 , 数据库相关磁盘磁盘空间使用率;在磁盘空间充足的情况下, 业务系统操作达梦数…

React Native 自定义 Hook 获取组件位置和大小

在 React Native 中自定义 Hook useLayout 获取 View、Pressable 等组件的位置和大小的信息 import {useState, useCallback} from react import {LayoutChangeEvent, LayoutRectangle} from react-nativeexport function useLayout() {const [layout, setLayout] useState&l…

搜维尔科技:【产品推荐】Euleria Health Riablo 运动功能训练与评估系统

Euleria Health Riablo 运动功能训练与评估系统 Riablo提供一种创新的康复解决方案&#xff0c;将康复和训练变得可激励、可衡量和可控制。Riablo通过激活本体感觉&#xff0c;并通过视听反馈促进神经肌肉的训练。 得益于其技术先进和易用性&#xff0c;Riablo是骨科、运动医…