关于broadcasting的介绍,参考这篇文章。
https://blog.csdn.net/python_LC_nohtyp/article/details/104097417
import tensorflow as tf
import numpy as np
tf.__version__
#关于broadcasting的介绍,参考这篇文章
#https://blog.csdn.net/python_LC_nohtyp/article/details/104097417
tensor = tf.random.normal([4,32,32,3])
#tensor + [3],[3]可以扩展为[4,32,32,3]后相加
t = tensor + tf.random.normal([3])
print("=========tensor + [3]=======", t.shape)
#tensor +[32,32,1],可以扩展为[4,32,32,3]后相加
t = tensor + tf.random.normal([32,32,1])
print("=========tensor + [32,32,1]=======", t.shape)
#tensor + [4,1,1,1],可以扩展为[4,32,32,3]后相加
t = tensor + tf.random.normal([4,1,1,1])
print("=========tensor + [4,1,1,1]=======", t.shape)
#不使用运算符,使用broadcast_to来扩展维度
b = tf.broadcast_to(tf.random.normal([4,1,1,1]), [4,32,32,3])
print("=========tf.broadcast_to([4,1,1,1], [4,32,32,3])========", b.shape)
#tensor + [1,4,1,1],第二维度不是1,也不是32,无法相加,报错
#t = tensor + tf.random.normal([1,4,1,1])
#print("=========tensor + [1,4,1,1]=======", t.shape)
#InvalidArgumentError: {{function_node __wrapped__AddV2_device_/job:localhost/replica:0/task:0/device:CPU:0}}
#Incompatible shapes: [4,32,32,3] vs. [1,4,1,1] [Op:AddV2] name:
#使用tile方式进行复制扩展,tile方式会实际分配内存
tensor = tf.random.uniform([3,4], minval=0, maxval=10, dtype=tf.int32)
print("=========Original tensor======\n", tensor)
#使用broadcasting复制扩展,不会分配内存,但实际效果和tile一样
b = tf.broadcast_to(tensor, [2,3,4])
print("==========Broadcasting========\n", b)
#使用expand_dims扩展一个维度
t = tf.expand_dims(tensor, axis=0)
print("==========After expand_dims:", t.shape)
#使用tile复制第一个维度,参数[2,1,1]表示第一个维度复制两次,后两个维度复制1次(不动)
t1 = tf.tile(t, [2,1,1])
print("After tile:", t1.shape)
print(t1)
运行结果