目录
- nn.Embedding
- triu 函数
- copy.deepcopy
nn.Embedding
#参数1:词表大小(即词表单词个数)即只要输入的值在0——n-1之间就可,对于序列长度无影响。
#参数2:词映射的维度n(即将每个词映射成一个二维1*n)
input = pair_tensor[0][0]#一个词所对应的张量中的一个元素
print(pair_tensor)
print(input)
embedding=nn.Embedding(20,3)
output1 = embedding(input)#此处只要是张量即可,此时仅对一个含有一个元素的一维张量进行
output2= embedding(input).view(1, 1, -1)#变为三维1*1*n
print(output1)#输出为input元素个数*3的张量
print(output2)
print(output1.shape,output2.shape)
input1=torch.LongTensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
output3=embedding(input1)
print(output3)
triu 函数
矩阵运算中非常常见的操作,用于提取或构建一个矩阵的上三角部分。上三角部分是指矩阵的主对角线及其上方的元素(对于每行来说,即在主对角线右侧的元素)。triu 是 “triangular upper” 的缩写,意味着"上三角"。
- 基本概念
上三角矩阵:一个矩阵的上三角部分是指包括主对角线及其上方的元素的部分。对于一个矩阵 A,如果 A[i, j] 在上三角部分,那么 i ≤ j。
主对角线:矩阵的主对角线指的是从左上角到右下角的对角线,即满足 i == j 的元素 - numpy.triu 函数(torch.triu 函数同)
在 NumPy 中,triu 函数用于生成矩阵的上三角部分。它的函数签名如下:
numpy.triu(m, k=0) 参数:
- m:输入矩阵,可以是一个二维数组。
- k:整数,控制对角线的偏移量。默认值为 0,表示从主对角线开始(即主对角线以下为0,加偏移一样)。如果 k > 0,则偏移 k行上方的对角线;如果 k < 0,则偏移 k 行下方的对角线。
- 返回值:返回一个与 m具有相同形状的数组,其中所有在上三角之外的元素都被设置为 0。
示例:
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=-1)
array([[ 1, 2, 3],
[ 4, 5, 6],
[ 0, 8, 9],
[ 0, 0, 12]])
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=0)
array([[ 1, 2, 3],
[ 0, 5, 6],
[ 0, 0, 9],
[ 0, 0, 0]])
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=1)
array([[ 0, 2, 3],
[ 0, 0, 6],
[ 0, 0, 0],
[ 0, 0, 0]])
copy.deepcopy
- 深拷贝 (deepcopy)
copy.deepcopy: 这是 Python 标准库 copy 模块中的一个函数,用于创建对象的深拷贝。
- 深拷贝: 深拷贝会递归地复制对象及其内部所有对象,而不仅仅是复制对象的引用。这样,原对象和复制对象在内存中是完全独立的修改一个对象不会影响另一个对象。
- 浅拷贝: 与之相对的是浅拷贝(copy.copy),它只复制对象本身,不复制嵌套的对象。
- 代码解释
c = copy.deepcopy: 这一行代码将 copy.deepcopy 函数赋值给变量 c。
这样,c 现在是一个指向 deepcopy 函数的快捷方式。使用时可以简化调用,如将 copy.deepcopy(obj) 改为 c(obj)。 - 什么时候使用深拷贝?
深拷贝在以下情况非常有用:
- 需要完全独立的对象:
当你需要确保新对象和原对象完全独立,没有共享的子对象时。
- 对象包含嵌套的可变对象:
当对象中包含复杂的嵌套结构(如列表、字典)时,深拷贝会递归地复制这些嵌套对象。
- 避免意外修改:
在某些场景下,如果多个对象共享子对象,可能会引起意外的修改(因为这些子对象是共享的)。深拷贝可以防止这种情况。
示例
import copy
# 使用深拷贝创建一个独立的副本
original = [1, [2, 3], 4]
c = copy.deepcopy
copied = c(original)
# 修改副本,不影响原对象
copied[1][0] = 'changed'
print(original) # 输出: [1, [2, 3], 4]
print(copied) # 输出: [1, ['changed', 3], 4]