I am trying to update a tf.Variable inside tf.while_loop() using tf.scatter_update(). However, the result is the initial value instead of the updated value. Here is an example of what I'm trying to do:
    from __future__ import print_function
    import tensorflow as tf

    def cond(sequence_len, step):
        return tf.less(step, sequence_len)

    def body(sequence_len, step):
        # Get (or reuse) the variable and write `step` into position 1
        begin = tf.get_variable("begin", [3], dtype=tf.int32,
                                initializer=tf.constant_initializer(0))
        begin = tf.scatter_update(begin, 1, step, use_locking=None)
        tf.get_variable_scope().reuse_variables()
        return (sequence_len, step + 1)

    with tf.Graph().as_default():
        sess = tf.Session()
        step = tf.constant(0)
        sequence_len = tf.constant(10)
        _, step = tf.while_loop(cond, body,
                                [sequence_len, step],
                                parallel_iterations=10,
                                back_prop=True,
                                swap_memory=False,
                                name=None)
        # Read the variable back after the loop
        begin = tf.get_variable("begin", [3], dtype=tf.int32)
        init = tf.initialize_all_variables()
        sess.run(init)
        print(sess.run([begin, step]))
The result is [array([0, 0, 0], dtype=int32), 10]. However, I think the result should be [0, 0, 10]. Am I doing something wrong here?
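For reference, here is a minimal sketch of the only workaround I can think of, assuming the problem is that the scatter_update op is not part of the loop's data flow and therefore never runs: wrap the body's return in tf.control_dependencies so the assignment must execute each iteration. This is an untested guess, not something I have confirmed works:

    def body(sequence_len, step):
        begin = tf.get_variable("begin", [3], dtype=tf.int32,
                                initializer=tf.constant_initializer(0))
        # Keep a handle on the assignment op instead of discarding it
        assign_op = tf.scatter_update(begin, 1, step, use_locking=None)
        tf.get_variable_scope().reuse_variables()
        # Assumption: making step + 1 depend on the assignment forces the
        # update to run on every iteration of the while_loop
        with tf.control_dependencies([assign_op]):
            return (sequence_len, step + 1)

The rest of the graph construction would stay as in the code above; only the body function changes.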