猿问

当切片本身是张量流中的张量时如何进行切片分配

我想在张量流中进行切片分配。我知道我可以使用:


my_var = my_var[4:8].assign(tf.zeros(4))

基于此链接。


正如您在中看到的,my_var[4:8]我们在这里有特定的索引 4、8 用于切片然后分配。


我的情况不同,我想根据张量进行切片,然后进行分配。


out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))


 rows_tf = tf.constant (

[[1, 2, 5],

 [1, 2, 5],

 [1, 2, 5],

 [1, 4, 6],

 [1, 4, 6],

 [2, 3, 6],

 [2, 3, 6],

 [2, 4, 7]])


columns_tf = tf.constant(

[[1],

 [2],

 [3],

 [2],

 [3],

 [2],

 [3],

 [2]])


changed_tensor = [[8.3356,    0.,        8.457685 ],

                  [0.,        6.103182,  8.602337 ],

                  [8.8974,    7.330564,  0.       ],

                  [0.,        3.8914037, 5.826657 ],

                  [8.8974,    0.,        8.283971 ],

                  [6.103182,  3.0614321, 5.826657 ],

                  [7.330564,  0.,        8.283971 ],

                  [6.103182,  3.8914037, 0.       ]]

此外,这是sparse_indices张量,它是需要更新的整个索引的连接rows_tf和制作(以防它可以提供帮助:)columns_tf


sparse_indices = tf.constant(

[[1 1]

 [2 1]

 [5 1]

 [1 2]

 [2 2]

 [5 2]

 [1 3]

 [2 3]

 [5 3]

 [1 2]

 [4 2]

 [6 2]

 [1 3]

 [4 3]

 [6 3]

 [2 2]

 [3 2]

 [6 2]

 [2 3]

 [3 3]

 [6 3]

 [2 2]

 [4 2]

 [4 2]])

我想做的是做这个简单的任务:


out[rows_tf, columns_tf] = changed_tensor

为此,我正在这样做:


out[rows_tf:column_tf].assign(changed_tensor)

但是,我收到了这个错误:


tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/

这是预期的输出:


[[0.        0.        0.        0.       ]

 [0.        8.3356    0.        8.8974   ]

 [0.        0.        6.103182  7.330564 ]

 [0.        0.        3.0614321 0.       ]

 [0.        0.        3.8914037 0.       ]

 [0.        8.457685  8.602337  0.       ]

 [0.        0.        5.826657  8.283971 ]

 [0.        0.        0.        0.       ]]

知道如何完成这个任务吗?


先感谢您:)


一只甜甜圈
浏览 181回答 1
1回答

哈士奇WWW

tf.scatter_nd_update 此示例(从此处的 tf 文档扩展)应该有所帮助。您想首先将您的 row_indices 和 column_indices 组合成一个二维索引列表,这indices是tf.scatter_nd_update. 然后你输入了一个期望值列表,即updates.ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))indices = tf.constant([[0, 2], [2, 2]])updates = tf.constant([1.0, 2.0])update = tf.scatter_nd_update(ref, indices, updates)with tf.Session() as sess:  sess.run(tf.initialize_all_variables())  print sess.run(update)Result:[[ 0.  0.  1.  0.] [ 0.  0.  0.  0.] [ 0.  0.  2.  0.] [ 0.  0.  0.  0.] [ 0.  0.  0.  0.] [ 0.  0.  0.  0.] [ 0.  0.  0.  0.] [ 0.  0.  0.  0.]]专门针对您的数据,ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))changed_tensor = [[8.3356,    0.,        8.457685 ],                  [0.,        6.103182,  8.602337 ],                  [8.8974,    7.330564,  0.       ],                  [0.,        3.8914037, 5.826657 ],                  [8.8974,    0.,        8.283971 ],                  [6.103182,  3.0614321, 5.826657 ],                  [7.330564,  0.,        8.283971 ],                  [6.103182,  3.8914037, 0.       ]]updates = tf.reshape(changed_tensor, shape=[-1])sparse_indices = tf.constant([[1, 1], [2, 1], [5, 1], [1, 2], [2, 2], [5, 2], [1, 3], [2, 3], [5, 3], [1, 2], [4, 2], [6, 2], [1, 3], [4, 3], [6, 3], [2, 2], [3, 2], [6, 2], [2, 3], [3, 3], [6, 3], [2, 2], [4, 2], [4, 2]])update = tf.scatter_nd_update(ref, sparse_indices, updates)with tf.Session() as sess:  sess.run(tf.initialize_all_variables())  print sess.run(update)Result:[[ 0.          0.          0.          0.        ] [ 0.          8.3355999   0.          8.8973999 ] [ 0.          0.          6.10318184  7.33056402] [ 0.          0.          3.06143212  0.        ] [ 0.          0.          0.          0.        ] [ 0.          8.45768547  8.60233688  0.        ] [ 0.          0.          5.82665682  8.28397083] [ 0.          0.          0.          0.        ]]
随时随地看视频慕课网APP

相关分类

Python
我要回答