TensorFlow: poor performance when getting input gradients

I am building a simple multi-layer perceptron with TensorFlow, and I also need to get the gradients (or error signal) of the loss on the inputs of the neural network.

Here is my code that works:

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
    optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
    # ...
    for i in range(epochs):
        # ...
        for batch in batches:
            # ...
            sess.run(optimizer, feed_dict=feed_dict)
            grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]

(edited to include the training loop)

Without the last line ( grads_wrt_input... ) this runs very fast on a CUDA machine. However, the tf.gradients() call slows it down by a factor of ten or more.

I recall that the error signals at the nodes are computed as intermediate values in the backpropagation algorithm, and I have done this successfully with the Java library DeepLearning4j. I was also under the impression that this would be a minor modification of the computation graph already built by optimizer .

How can this be done faster, or is there another way to compute the gradients of the loss w.r.t. the inputs?

+7
tensorflow
1 answer

The tf.gradients() function builds a new backpropagation graph each time it is called, so the slowdown comes from TensorFlow having to analyze a new graph on every iteration of the loop. (This can be surprisingly expensive: the current version of TensorFlow is optimized for running the same graph many times.)
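To see this graph growth directly, here is a minimal sketch (a toy model, not the asker's network) that counts graph operations before and after repeated tf.gradients() calls. It is written for TF 1.x-style graph mode; under TF 2 the same behavior is reachable through the tf.compat.v1 shim used here.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy graph: a linear model with a scalar cost (hypothetical, for illustration).
x = tf.placeholder(tf.float32, [None, 4])
w = tf.Variable(tf.ones([4, 1]))
cost = tf.reduce_mean(tf.matmul(x, w))

g = tf.get_default_graph()
before = len(g.get_operations())
tf.gradients(cost, x)                  # first call adds gradient ops
after_first = len(g.get_operations())
tf.gradients(cost, x)                  # second call adds a fresh copy of them
after_second = len(g.get_operations())

print("ops added by 1st call:", after_first - before)
print("ops added by 2nd call:", after_second - after_first)
```

Each call adds a new set of gradient ops, so calling tf.gradients() inside the training loop makes the graph grow every iteration.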

Fortunately, the solution is easy: just calculate the gradients once, outside the loop. You can change the structure of your code as follows:

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
    optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
    grads_wrt_input_tensor = tf.gradients(cost, self.x)[0]
    # ...
    for i in range(epochs):
        # ...
        for batch in batches:
            # ...
            _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor],
                                          feed_dict=feed_dict)

Note that, for performance, I have also combined the two sess.run() calls into one. This ensures that the forward pass, and most of the backward pass, is reused.


As an aside, one tip for finding performance bugs like this is to call tf.get_default_graph().finalize() before starting your training loop. This will raise an exception if you accidentally add any nodes to the graph, which makes it easier to track down the cause of these bugs.
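A short sketch of that finalize() pattern (toy placeholders, not the asker's model; TF 1.x-style graph mode via the tf.compat.v1 shim):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 4])
cost = tf.reduce_sum(x)
grads = tf.gradients(cost, x)[0]   # build gradient ops once, up front

tf.get_default_graph().finalize()  # lock the graph before the training loop

raised = False
try:
    tf.gradients(cost, x)          # accidental node creation inside the loop
except RuntimeError:
    raised = True                  # finalized graphs reject new ops
print("finalized graph rejected new nodes:", raised)
```

Any code path that tries to add ops after finalize() fails immediately with a RuntimeError, pointing you straight at the line that was silently growing the graph.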

+10
