Using tf.scatter_update on a two-dimensional tf.Variable

I followed the question "Manipulation of matrix elements in TensorFlow" and used tf.scatter_update. My problem is this: what happens if my tf.Variable is 2D? Say:

    a = tf.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])

How can I update, for example, the first element of each row and set it to 1?

I tried something like this:

    for line in range(2):
        sess.run(tf.scatter_update(a[line], [0], [1]))

but, as I expected, it fails with this error:

    TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

How can I fix such problems?

1 answer

In TensorFlow you cannot update a tensor in place, but you can update a variable.
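To make the tensor/variable distinction concrete, here is a small sketch. It assumes TensorFlow 2.x with eager execution (the question uses the 1.x graph API), and tf.tensor_scatter_nd_update is not part of the original answer: it is used here only to show that a "scatter" on a plain tensor returns a new tensor, while a variable's contents can be overwritten in place.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# A tensor is immutable: a scatter on it builds and returns a NEW tensor.
t = tf.constant([[0, 0, 0, 0], [0, 0, 0, 0]])
t_new = tf.tensor_scatter_nd_update(t, indices=[[0, 0]], updates=[1])

# A variable holds mutable state, so its value can be replaced in place.
v = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])
v.assign(t_new)

print(t.numpy().tolist())      # the original tensor is left unchanged
print(t_new.numpy().tolist())  # the scatter result has element [0, 0] set to 1
print(v.numpy().tolist())      # the variable now holds the new value
```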

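tf.scatter_update can only update slices along the first dimension of a variable, so each index selects a whole row. You must also always pass the variable itself (the reference) to tf.scatter_update, i.e. a instead of a[line]. Slicing a variable with a[line] produces an ordinary tensor rather than a reference to the variable, which is why the op complains that its 'ref' input is not an l-value.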
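The first-dimension rule can be seen side by side with an element-level update in the following sketch. It assumes TensorFlow 2.x with eager execution, using tf.compat.v1.scatter_update as the 2.x spelling of tf.scatter_update. The element-wise variant with Variable.scatter_nd_update (in the 1.x graph API, tf.scatter_nd_update) is not in the original answer; it is included as one possible way to set single elements, where each index is a [row, column] pair.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Row-wise: scatter_update indexes the FIRST dimension only, so indices 0 and 1
# each select a whole row and every update must be a full row of length 4.
a = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])
tf.compat.v1.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])
rows_result = a.numpy().tolist()

# Element-wise: an N-d scatter takes [row, col] index pairs and scalar updates,
# so only the first element of each row is written.
b = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])
b.scatter_nd_update(indices=[[0, 0], [1, 0]], updates=[1, 1])
elements_result = b.numpy().tolist()

print(rows_result)
print(elements_result)
```

Both variants end with the first element of each row set to 1. The element-wise form avoids having to restate the untouched columns of each row.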

Here's how you can set the first element of each row. Because the indices refer to rows, each update has to be a complete row:

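    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    g = tf.Graph()
    with g.as_default():
        a = tf.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
        # Indices 0 and 1 select whole rows; each update is a full row.
        b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])
        init = tf.global_variables_initializer()

    with tf.Session(graph=g) as sess:
        sess.run(init)
        print(sess.run(a))  # value before the update
        print(sess.run(b))  # running b applies the update and returns the new value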

Output:

    [[0 0 0 0]
     [0 0 0 0]]
    [[1 0 0 0]
     [1 0 0 0]]

However, if you are going to change the whole tensor anyway, it would be faster to assign a completely new value to the variable (for example with tf.assign) than to scatter into it.
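A minimal sketch of the whole-value assignment, assuming TensorFlow 2.x with eager execution, where tf.assign(a, value) from 1.x is written as a.assign(value). The second half shows sliced assignment on the variable (a[:, 0].assign(...)). That is not part of the original answer; it is added as an assumption-labelled alternative because, unlike passing a[line] to tf.scatter_update, calling .assign on a slice of the variable itself writes back into the variable.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

a = tf.Variable([[0, 0, 0, 0], [0, 0, 0, 0]])

# Replace the whole value in one step: no indices needed.
a.assign([[1, 0, 0, 0], [1, 0, 0, 0]])
after_assign = a.numpy().tolist()

# Alternative (not from the original answer): reset to zeros, then assign to
# column 0 of every row through a slice of the variable.
a.assign(tf.zeros_like(a))
a[:, 0].assign([1, 1])
after_slice = a.numpy().tolist()

print(after_assign)
print(after_slice)
```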
