How tf.scatter_update() works inside tf.while_loop()

I am trying to update a tf.Variable inside tf.while_loop() using tf.scatter_update(). However, the result is the initial value instead of the updated value. Here is an example of what I'm trying to do:

    from __future__ import print_function
    import tensorflow as tf

    def cond(sequence_len, step):
        return tf.less(step, sequence_len)

    def body(sequence_len, step):
        begin = tf.get_variable("begin", [3], dtype=tf.int32, initializer=tf.constant_initializer(0))
        begin = tf.scatter_update(begin, 1, step, use_locking=None)
        tf.get_variable_scope().reuse_variables()
        return (sequence_len, step+1)

    with tf.Graph().as_default():
        sess = tf.Session()
        step = tf.constant(0)
        sequence_len = tf.constant(10)
        _, step = tf.while_loop(cond, body, [sequence_len, step], parallel_iterations=10, back_prop=True, swap_memory=False, name=None)
        begin = tf.get_variable("begin", [3], dtype=tf.int32)
        init = tf.initialize_all_variables()
        sess.run(init)
        print(sess.run([begin, step]))

Result: [array([0, 0, 0], dtype=int32), 10]. However, I think begin should be [0, 0, 10]. Am I doing something wrong here?

1 answer

The problem is that nothing in the body of the loop depends on your tf.scatter_update() op, so it never executes. The easiest way to make it work is to add a control dependency on the update op to the return values:

    def body(sequence_len, step):
        begin = tf.get_variable("begin", [3], dtype=tf.int32, initializer=tf.constant_initializer(0))
        begin = tf.scatter_update(begin, 1, step, use_locking=None)
        tf.get_variable_scope().reuse_variables()
        # Ops created inside this block (here, step+1) run only after the update.
        with tf.control_dependencies([begin]):
            return (sequence_len, step+1)

Note that this problem is not unique to loops in TensorFlow. If you had just defined a tf.scatter_update() op called begin, but never called sess.run() on it or on something that depends on it, the update would not happen either. When you use tf.while_loop(), there is no way to run the operations defined in the loop body directly, so the easiest way to get the side effect is to add a control dependency.
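
To make the "only what is fetched gets executed" rule concrete, here is a toy model in plain Python. It is not TensorFlow and does not reflect how TensorFlow is implemented; the names Op, run, scatter, next_plain and next_guarded are all made up for this sketch. It only illustrates that an op with a side effect is skipped unless the fetched op depends on it, either through data or through a control input.

```python
class Op:
    """Toy graph node (hypothetical): an action, data inputs, control inputs."""

    def __init__(self, name, fn, inputs=(), control_inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)
        self.control_inputs = list(control_inputs)


def run(fetch):
    """Execute `fetch` and only the ops it transitively depends on."""
    results = {}   # op -> computed value, so each op runs at most once
    executed = []  # names of ops in the order they actually ran

    def visit(op):
        if op in results:
            return results[op]
        for dep in op.control_inputs:  # control inputs run first, value unused
            visit(dep)
        args = [visit(dep) for dep in op.inputs]
        results[op] = op.fn(*args)
        executed.append(op.name)
        return results[op]

    return visit(fetch), executed


state = [0, 0, 0]  # stands in for the variable `begin`


def scatter(step_value):
    """Side effect, loosely analogous to scatter_update(begin, 1, step)."""
    state[1] = step_value
    return list(state)


step = Op("step", lambda: 3)
update = Op("update", scatter, inputs=[step])

# Nothing links this increment to the update, so the update is skipped.
next_plain = Op("next_plain", lambda s: s + 1, inputs=[step])
value_plain, ran_plain = run(next_plain)
state_after_plain = list(state)

# A control input forces the update to run before the increment.
next_guarded = Op("next_guarded", lambda s: s + 1, inputs=[step],
                  control_inputs=[update])
value_guarded, ran_guarded = run(next_guarded)
state_after_guarded = list(state)

print(ran_plain, state_after_plain)
print(ran_guarded, state_after_guarded)
```

Both fetches return the same value, but only the second one changes the state, which mirrors the difference between the original body and the one wrapped in tf.control_dependencies().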

Note that the final result is [0, 9, 0]: each iteration assigns the current step to begin[1], and in the last iteration the value of step is 9 (the condition is false when step == 10).
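
The arithmetic can be checked with a plain Python loop that mirrors the fixed body. This is only an analogue of the graph's behavior, not TensorFlow code, and the function name simulate is made up for the illustration.

```python
def simulate(sequence_len=10):
    """Plain-Python analogue of the while_loop with the update enforced."""
    begin = [0, 0, 0]
    step = 0
    while step < sequence_len:  # cond: tf.less(step, sequence_len)
        begin[1] = step         # update: scatter_update(begin, 1, step)
        step += 1               # loop variable returned by body: step + 1
    return begin, step


final_begin, final_step = simulate()
print(final_begin, final_step)  # prints: [0, 9, 0] 10
```

The last value written is 9 because the update uses the step value from before the increment, and the loop stops as soon as step reaches 10.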
