7

理解tf.transpose

 3 years ago
source link: http://www.banbeichadexiaojiubei.com/index.php/2020/12/05/理解tf-transpose/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

本文介绍Tensorflow中的transpose()函数。函数的原型如下:

tf.transpose(a, perm=None, name='transpose')

tf.transpose函数中文意思是转置,主要用于交换Input Tensor的不同维度。

tf.transpose的第一个参数a是Input Tensor。

tf.transpose的第二个参数perm指定Tensor的维度。tf.transpose(x, perm=[1,0,2])表示将三维Input Tensor的第一维和第二维进行交换。如果没有显式指定perm,默认perm = [n – 1, n -2, …, 0],其中n = rank(a)。

如果Input Tensor是二维的,就相当于线性代数中的转置。

x = tf.constant([[1, 2, 3], [4, 5, 6]])
tf.transpose(x)  # [[1, 4]
                 #  [2, 5]
                 #  [3, 6]]

# Equivalently
tf.transpose(x, perm=[1, 0])  # [[1, 4]
                              #  [2, 5]
                              #  [3, 6]]

对于多维Tensor,相对比较难理解。下面代码使用tf.transpose实现将shape=(1, 4, 3)的Tensor转换为shape=(3, 4, 1)的Tensor。

# 'perm' is more useful for n-dimensional tensors, for n > 2
tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]])

tf.transpose(input_data, perm=[2, 1, 0])

 [[[ 1]
   [ 4]
   [ 7]
   [10]]
 
  [[ 2]
   [ 5]
   [ 8]
   [11]]
 
  [[ 3]
   [ 6]
   [ 9]
   [12]]]

如何理解tf.transpose的原理

上述将shape=(1, 4, 3)的Tensor转换为shape=(3, 4, 1)的Tensor的Demo在tf.transpose内部的操作实质就是将第一维和第三维进行交换,换句话说,就是对Tensor的所有元素按照下面的公式进行交换:

$a_{i,j,k} =a_{k,j,i}$

直观的理解就是交换了x轴和z轴,y轴保持不变,然后按照交换后的坐标轴重新解读Tensor数据。

356f92e9b2a6c7078465a9fa03798d69.jpg

参考材料

https://www.tensorflow.org/api_docs/python/tf/transpose

https://www.programmersought.com/article/1555520786/

除非注明,否则均为[半杯茶的小酒杯]原创文章,转载必须以链接形式标明本文链接

本文链接: http://www.banbeichadexiaojiubei.com/index.php/2020/12/05/%e7%90%86%e8%a7%a3tf-transpose/


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK