文章目录
- 1. 生成类别矩阵如下
- 2. PyTorch 代码
- 3. 循环移动矩阵
1. 生成类别矩阵如下

2. PyTorch 代码
import torch
import torch.nn as nn
import torch.nn.functional as F
# Show tensors with 3 decimal places and never in scientific notation.
torch.set_printoptions(sci_mode=False, precision=3)
if __name__ == "__main__":
    run_code = 0

    # Demo 1: every entry of a 2x2 label grid is expanded (Kronecker product
    # with a 2x2 block of ones) into a 2x2 block holding that label.
    a_matrix = torch.arange(1, 5).reshape(2, 2)
    b_matrix = torch.ones(2, 2)
    c_matrix = torch.kron(a_matrix, b_matrix)
    for label, value in (
        ("a_matrix", a_matrix),
        ("b_matrix", b_matrix),
        ("c_matrix", c_matrix),
    ):
        print(f"{label}=\n{value}")

    # Demo 2: the same expansion on a 3x3 grid (6x6 result), then one
    # row/column is dropped from every border, leaving a 4x4 matrix.
    d_matrix = torch.arange(1, 10).reshape(3, 3)
    e_matrix = torch.ones(2, 2)
    f_matrix = torch.kron(d_matrix, e_matrix)
    g_matrix = f_matrix[1:-1, 1:-1]
    for label, value in (
        ("d_matrix", d_matrix),
        ("e_matrix", e_matrix),
        ("f_matrix", f_matrix),
        ("g_matrix", g_matrix),
    ):
        print(f"{label}=\n{value}")
a_matrix=
tensor([[1, 2],
[3, 4]])
b_matrix=
tensor([[1., 1.],
[1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
e_matrix=
tensor([[1., 1.],
[1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],
[1., 1., 2., 2., 3., 3.],
[4., 4., 5., 5., 6., 6.],
[4., 4., 5., 5., 6., 6.],
[7., 7., 8., 8., 9., 9.],
[7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],
[4., 5., 5., 6.],
[4., 5., 5., 6.],
[7., 8., 8., 9.]])
3. 循环移动矩阵
- Excel 表示

- PyTorch 源码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.set_printoptions(sci_mode=False, precision=3)  # 3 decimals, no scientific notation
class WindowMatrix(object):
    """Label matrix for a plain (non-shifted) window partition.

    ``num_patch`` windows are laid out as a square ``n x n`` grid
    (``n = sqrt(num_patch)``) and numbered 1..num_patch row by row. Each label
    is expanded into a ``size x size`` block via ``torch.kron``. For the
    defaults (``num_patch=4``, ``size=2``) the result is::

        [[1, 1, 2, 2],
         [1, 1, 2, 2],
         [3, 3, 4, 4],
         [3, 3, 4, 4]]

    Args:
        num_patch: number of windows; must be a perfect square.
        size: side length of each window block.
    """

    def __init__(self, num_patch=4, size=2):
        self.num_patch = num_patch
        self.size = size
        # Kept for backward compatibility. These are NOT, in general, the
        # shape of ``result`` (which is ``n * size`` on each side).
        self.width = self.num_patch
        self.height = self.size * self.size
        # Placeholder; replaced every time ``result`` is read.
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        """Build and return the ``(n * size, n * size)`` float label matrix.

        Raises:
            ValueError: if ``num_patch`` is not a perfect square (previously
                this surfaced as an opaque reshape error from torch).
        """
        a_size = int(math.sqrt(self.num_patch))
        if a_size * a_size != self.num_patch:
            raise ValueError(
                f"num_patch must be a perfect square, got {self.num_patch}"
            )
        # 1-based window labels on an n x n grid.
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        # int64 labels (x) float ones -> float tensor of repeated labels.
        self._result = torch.kron(input=a_matrix, other=b_matrix)
        return self._result
class ShiftedWindowMatrix(object):
    """Label matrix for a shifted window partition.

    Builds the same block-label matrix as a plain partition: ``num_patch``
    windows on an ``n x n`` grid (``n = sqrt(num_patch)``), each label
    expanded to a ``size x size`` block. It then crops one row and one column
    from every border. For the defaults (``num_patch=9``, ``size=2``) the
    result is::

        [[1, 2, 2, 3],
         [4, 5, 5, 6],
         [4, 5, 5, 6],
         [7, 8, 8, 9]]

    NOTE(review): the crop width is fixed at 1. That equals ``size // 2``
    only for ``size`` 2 or 3; confirm intent before using other sizes.

    Args:
        num_patch: number of windows; must be a perfect square.
        size: side length of each window block.
    """

    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        # Kept for backward compatibility; not the shape of ``result``.
        self.width = self.num_patch
        self.height = self.size * self.size
        # Placeholder; replaced every time ``result`` is read.
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        """Build and return the cropped float label matrix.

        Raises:
            ValueError: if ``num_patch`` is not a perfect square.
        """
        a_size = int(math.sqrt(self.num_patch))
        if a_size * a_size != self.num_patch:
            raise ValueError(
                f"num_patch must be a perfect square, got {self.num_patch}"
            )
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        # Drop the outermost row/column on each side.
        self._result = my_result[1:-1, 1:-1]
        return self._result
class RollShiftedWindowMatrix(object):
    """Shifted-window label matrix after a cyclic roll.

    First builds the shifted-window label matrix: ``n x n`` grid of 1-based
    labels (``n = sqrt(num_patch)``), each expanded to ``size x size``, with
    one row/column cropped from every border. It then applies ``torch.roll``
    with shift -1 along both axes, so the first row and first column wrap
    around to the end. For the defaults (``num_patch=9``, ``size=2``)::

        [[5, 5, 6, 4],
         [5, 5, 6, 4],
         [8, 8, 9, 7],
         [2, 2, 3, 1]]

    Args:
        num_patch: number of windows; must be a perfect square.
        size: side length of each window block.
    """

    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        # Kept for backward compatibility; not the shape of ``result``.
        self.width = self.num_patch
        self.height = self.size * self.size
        # Placeholder; replaced every time ``result`` is read.
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        """Build and return the rolled float label matrix.

        Raises:
            ValueError: if ``num_patch`` is not a perfect square.
        """
        a_size = int(math.sqrt(self.num_patch))
        if a_size * a_size != self.num_patch:
            raise ValueError(
                f"num_patch must be a perfect square, got {self.num_patch}"
            )
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        # Negative shifts move elements toward index 0; the leading row and
        # column wrap to the end.
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        self._result = roll_result
        return self._result
class BackRollShiftedWindowMatrix(object):
    """Demonstrates that rolling back undoes the cyclic shift.

    Builds the shifted-window label matrix: ``n x n`` grid of 1-based labels
    (``n = sqrt(num_patch)``), each expanded to ``size x size``, with one
    row/column cropped from every border. It then rolls by -1 and immediately
    rolls back by +1 on both axes. The returned matrix therefore equals the
    un-rolled shifted matrix.

    Args:
        num_patch: number of windows; must be a perfect square.
        size: side length of each window block.
    """

    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        # Kept for backward compatibility; not the shape of ``result``.
        self.width = self.num_patch
        self.height = self.size * self.size
        # Placeholder; replaced every time ``result`` is read.
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        """Build, print the intermediate rolled matrix, and return the restored one.

        Side effect: prints the intermediate ``roll_result`` to stdout on
        every access. The tutorial's shown output relies on this.

        Raises:
            ValueError: if ``num_patch`` is not a perfect square.
        """
        a_size = int(math.sqrt(self.num_patch))
        if a_size * a_size != self.num_patch:
            raise ValueError(
                f"num_patch must be a perfect square, got {self.num_patch}"
            )
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        print(f"roll_result=\n{roll_result}")
        # Opposite shifts on the same dims exactly invert the previous roll.
        roll_result = torch.roll(input=roll_result, shifts=(1, 1), dims=(-1, -2))
        self._result = roll_result
        return self._result
if __name__ == "__main__":
    run_code = 0
    # (printed label, matrix class) pairs, shown in order. Each class is
    # instantiated and its ``result`` read before the label line is printed,
    # so anything a ``result`` property prints itself appears first.
    demos = (
        ("my_window_matrix_result", WindowMatrix),
        ("shifed_window_matrix_result", ShiftedWindowMatrix),
        ("roll_shifed_window_matrix_result", RollShiftedWindowMatrix),
        ("back_roll_shifed_window_matrix_result", BackRollShiftedWindowMatrix),
    )
    for label, matrix_cls in demos:
        matrix = matrix_cls()
        print(f"{label}=\n{matrix.result}")
my_window_matrix_result=
tensor([[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
[4., 5., 5., 6.],
[4., 5., 5., 6.],
[7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],
[5., 5., 6., 4.],
[8., 8., 9., 7.],
[2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],
[5., 5., 6., 4.],
[8., 8., 9., 7.],
[2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
[4., 5., 5., 6.],
[4., 5., 5., 6.],
[7., 8., 8., 9.]])