文章目录
1. description 2. excel 3. pytorch code
1. description
功能:按一定比例的随机部分样本,简单来说就是按照一定的比例将行向量从小到大的顺序提取出来。 思考1: 用了均匀分布,并且按照一定比例,取前prob概率来表示 思考2:用了torch.argsort 来生成idx_shuffle 来的到快速从小到大排序的 思考3:用了torch.gather 来配合idx_shuffle 来找到最小的部分数据 思考4:用了torch.gather+idx_restore+ones_like matrix 组合的方式生成mask矩阵 小结:主要配合torch.argsort+torch.gather的方式生成相关的mask矩阵和最小矩阵和恢复矩阵,代码的巧妙运用很具备参考意义。
2. excel
3. pytorch code
import torch
import torch. nn as nn
torch. set_printoptions( precision= 3 , sci_mode= False )
torch. manual_seed( 2324 )
def random_masking ( x, mask_ratio) :
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x. shape
len_keep = int ( L * ( 1 - mask_ratio) )
noise = torch. rand( N, L, device= x. device)
ids_shuffle = torch. argsort( noise, dim= 1 )
ids_restore = torch. argsort( ids_shuffle, dim= 1 )
ids_keep = ids_shuffle[ : , : len_keep]
x_masked = torch. gather( x, dim= 1 , index= ids_keep. unsqueeze( - 1 ) . repeat( 1 , 1 , D) )
mask = torch. ones( [ N, L] , device= x. device)
mask[ : , : len_keep] = 0
mask = torch. gather( mask, dim= 1 , index= ids_restore)
return x_masked, mask, ids_restore
if __name__ == "__main__" :
run_code = 0
bs = 2
seq_len = 8
seq_dim = 10
mx_total = bs * seq_dim * seq_len
a_matrix = torch. arange( mx_total) . reshape( ( bs, seq_len, seq_dim) )
a_x_masked, a_mask, a_ids_restore = random_masking( a_matrix, mask_ratio= 0.4 )
print ( f"a_matrix=\n { a_matrix} " )
print ( f"a_x_masked=\n { a_x_masked} " )
print ( f"a_mask=\n { a_mask} " )
print ( f"a_ids_restore=\n { a_ids_restore} " )
a_matrix=
tensor( [ [ [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ,
[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ] ,
[ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ] ,
[ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ] ,
[ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ] ,
[ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ] ,
[ 60 , 61 , 62 , 63 , 64 , 65 , 66 , 67 , 68 , 69 ] ,
[ 70 , 71 , 72 , 73 , 74 , 75 , 76 , 77 , 78 , 79 ] ] ,
[ [ 80 , 81 , 82 , 83 , 84 , 85 , 86 , 87 , 88 , 89 ] ,
[ 90 , 91 , 92 , 93 , 94 , 95 , 96 , 97 , 98 , 99 ] ,
[ 100 , 101 , 102 , 103 , 104 , 105 , 106 , 107 , 108 , 109 ] ,
[ 110 , 111 , 112 , 113 , 114 , 115 , 116 , 117 , 118 , 119 ] ,
[ 120 , 121 , 122 , 123 , 124 , 125 , 126 , 127 , 128 , 129 ] ,
[ 130 , 131 , 132 , 133 , 134 , 135 , 136 , 137 , 138 , 139 ] ,
[ 140 , 141 , 142 , 143 , 144 , 145 , 146 , 147 , 148 , 149 ] ,
[ 150 , 151 , 152 , 153 , 154 , 155 , 156 , 157 , 158 , 159 ] ] ] )
a_x_masked=
tensor( [ [ [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ] ,
[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ,
[ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ] ,
[ 60 , 61 , 62 , 63 , 64 , 65 , 66 , 67 , 68 , 69 ] ] ,
[ [ 90 , 91 , 92 , 93 , 94 , 95 , 96 , 97 , 98 , 99 ] ,
[ 120 , 121 , 122 , 123 , 124 , 125 , 126 , 127 , 128 , 129 ] ,
[ 140 , 141 , 142 , 143 , 144 , 145 , 146 , 147 , 148 , 149 ] ,
[ 80 , 81 , 82 , 83 , 84 , 85 , 86 , 87 , 88 , 89 ] ] ] )
a_mask=
tensor( [ [ 0. , 0. , 0. , 1. , 1. , 1. , 0. , 1. ] ,
[ 0. , 0. , 1. , 1. , 0. , 1. , 0. , 1. ] ] )
a_ids_restore=
tensor( [ [ 1 , 0 , 2 , 4 , 7 , 6 , 3 , 5 ] ,
[ 3 , 0 , 6 , 4 , 1 , 5 , 2 , 7 ] ] )