TensorFlow: when are variable assignments executed in sess.run with a list?

I thought that variable assignments are performed only after all operations in the list given to sess.run have completed, but the following code returns different results on different runs. It seems that the operations in the list are executed in an arbitrary order, and each assignment takes effect as soon as its operation runs.

    a = tf.Variable(0)
    b = tf.Variable(1)
    c = tf.Variable(1)
    update_a = tf.assign(a, b + c)
    update_b = tf.assign(b, c + a)
    update_c = tf.assign(c, a + b)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for i in range(5):
            a_, b_, c_ = sess.run([update_a, update_b, update_c])

I would like to know when the assignments happen. Which is correct: "update_x → assign x → ... → update_z → assign z" or "update_x → update_y → update_z → assign a, b, c"? (where (x, y, z) is a permutation of (a, b, c)) In addition, if there is a way to implement the latter behavior (all assignments performed only after all operations in the list have completed), please let me know how to implement it.

Tags: python, tensorflow, variables, variable-assignment, timing
2 answers

The three operations update_a, update_b and update_c have no interdependencies in the dataflow graph, so TensorFlow may choose to execute them in any order. (In the current implementation, all three may even run in parallel on different threads.) A second subtlety is that reads of variables are cached by default, so in your program the value assigned in update_b (i.e. c + a) may use either the original or the updated value of a, depending on when the variable is first read.
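To see why this matters, here is a pure-Python sketch (not TensorFlow itself, and ignoring the read-caching subtlety) showing that executing the three interdependent assignments in different orders produces different final values, which is exactly why the output is nondeterministic:

```python
# Sketch: each assignment takes effect immediately, and the execution
# order of the three updates is up to the runtime.
from itertools import permutations

def run_updates(order):
    # Initial values mirror the question: a=0, b=1, c=1.
    v = {"a": 0, "b": 1, "c": 1}
    updates = {
        "a": lambda v: v["b"] + v["c"],   # update_a = assign(a, b + c)
        "b": lambda v: v["c"] + v["a"],   # update_b = assign(b, c + a)
        "c": lambda v: v["a"] + v["b"],   # update_c = assign(c, a + b)
    }
    for name in order:
        v[name] = updates[name](v)
    return (v["a"], v["b"], v["c"])

results = {order: run_updates(order) for order in permutations("abc")}
for order, vals in results.items():
    print(order, vals)
# e.g. order (a, b, c) gives (2, 3, 5) but order (c, b, a) gives (2, 1, 1):
# different orders give different results, so the program is nondeterministic.
```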

If you want the operations to run in a specific order, you can use with tf.control_dependencies([...]): blocks to force the operations created inside the block to run after the operations named in the list. Inside such a block you can use tf.Variable.read_value() to make explicit the point at which the variable is read.

Therefore, if you want to make sure that update_a happens before update_b and update_b happens before update_c, you can do:

    update_a = tf.assign(a, b + c)
    with tf.control_dependencies([update_a]):
        update_b = tf.assign(b, c + a.read_value())
        with tf.control_dependencies([update_b]):
            update_c = tf.assign(c, a.read_value() + b.read_value())
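With the control dependencies in place the updates always run a → b → c and each later update sees the freshly assigned values. A pure-Python sketch of that fully ordered behavior (not TensorFlow itself) shows the result is now deterministic on every iteration:

```python
# Sequential semantics of the chained version: each update runs after
# the previous one and reads the just-assigned values.
a, b, c = 0, 1, 1
history = []
for _ in range(5):
    a = b + c          # update_a
    b = c + a          # update_b (sees the new a)
    c = a + b          # update_c (sees the new a and b)
    history.append((a, b, c))

print(history[0])  # (2, 3, 5)
print(history[1])  # (8, 13, 21)
```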

Based on your example:

    v = tf.Variable(0)
    c = tf.constant(3)
    add = tf.add(v, c)
    update = tf.assign(v, add)
    mul = tf.mul(add, update)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        res = sess.run([mul, mul])
        print(res)

Output: [9, 9]

You get [9, 9], and this is actually what we asked for. Think of it this way:

At runtime, once mul is taken from the list, the runtime looks up its definition and finds tf.mul(add, update). Now it needs the value of add, which leads to tf.add(v, c). It plugs in the values of v and c and computes add as 0 + 3 = 3.

OK, now it needs the value of update, which is defined as tf.assign(v, add). It already has the value of add (just computed as 3) and v, so it updates the value of v to 3, which is also the value of update.

Now it has the values of add and update, both equal to 3, so the multiplication gives 9 for mul.
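The three steps above can be traced in plain Python (a sketch of the dataflow evaluation, not TensorFlow itself):

```python
# Trace of evaluating mul = tf.mul(add, update) from the example above.
v = 0               # tf.Variable(0)
c = 3               # tf.constant(3)
add = v + c         # tf.add(v, c): 0 + 3 = 3
v = add             # update = tf.assign(v, add): v becomes 3
update = v          # the assign op's output is the new value, 3
mul = add * update  # tf.mul(add, update): 3 * 3 = 9

print(mul)  # 9
```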

Based on the result we get, I think that for the next element (operation) in the list it simply returns the mul value just calculated. I'm not sure whether it repeats the steps, simply returns the same (cached?) value it just computed for mul after realizing we already have the result, or whether these operations happen in parallel (one per element in the list). Maybe @mrry or @YaroslavBulatov can comment on this part?


Quote from @mrry's comment:

When you call sess.run([x, y, z]) once, TensorFlow executes each op that those tensors depend on exactly once (unless there is a tf.while_loop() in your graph). If a tensor appears twice in the list (like mul in your example), TensorFlow will execute it once and return two copies of the result. To run the assignment more than once, you must either call sess.run() several times, or use tf.while_loop() to put a loop in the graph.
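The de-duplication of fetches described in that comment can be sketched in plain Python (a simplified model, not TensorFlow's actual implementation): each unique op executes once per call, and duplicated fetches receive copies of the same result.

```python
# Sketch of per-call fetch de-duplication in a sess.run-like function.
def fake_run(fetches):
    cache = {}  # per-call cache: each op executes at most once
    def evaluate(op):
        if op not in cache:
            cache[op] = op()   # execute the op exactly once
        return cache[op]
    return [evaluate(op) for op in fetches]

calls = []
def mul():
    calls.append(1)            # record each real execution
    return 9

print(fake_run([mul, mul]))    # [9, 9] -- two copies of one result
print(len(calls))              # 1 -- mul actually ran only once
```

Calling fake_run again would execute mul again, which mirrors the advice that running the assignment more than once requires multiple sess.run() calls.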

