In TensorFlow, how do I unravel the flattened indices returned by tf.nn.max_pool_with_argmax?

I ran into a problem: after calling tf.nn.max_pool_with_argmax, I get the indices in argmax, which the docs describe as "A Tensor of type Targmax. 4-D. The flattened indices of the max values chosen for each output."

How can I convert these flattened indices back into a list of coordinates in TensorFlow?

Many thanks.
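For reference, the current TensorFlow docs for tf.nn.max_pool_with_argmax state that, with the default include_batch_in_index=False, a maximum at position [b, y, x, c] is flattened to (y * width + x) * channels + c (with include_batch_in_index=True it becomes ((b * height + y) * width + x) * channels + c). Older TF releases may behave differently, so treat this as an assumption to check against your version. The sketch below reproduces that encoding in NumPy (used only so it runs without a TF graph; the helper names are hypothetical) and inverts it with np.unravel_index, which treats the index as row-major over (height, width, channels).

```python
import numpy as np

def flat_index(y, x, c, width, channels):
    # Encoding documented for include_batch_in_index=False: (y * width + x) * channels + c
    return (y * width + x) * channels + c

def unravel(idx, height, width, channels):
    # Row-major decoding of the same encoding; returns (y, x, c)
    return np.unravel_index(idx, (height, width, channels))

H, W, C = 4, 5, 3
idx = flat_index(2, 3, 1, W, C)
print(idx)  # 40
y, x, c = unravel(idx, H, W, C)
print(int(y), int(x), int(c))  # 2 3 1
```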

I had the same problem today and I ended up with this solution:

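    import tensorflow as tf

    def unravel_argmax(argmax, shape):
        # shape is the pooling op's input shape: [batch, height, width, channels]
        output_list = []
        output_list.append(argmax // (shape[2] * shape[3]))             # row (y)
        output_list.append(argmax % (shape[2] * shape[3]) // shape[3])  # column (x)
        return tf.stack(output_list)  # tf.pack in TensorFlow versions before 1.0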
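To sanity-check the arithmetic, here is a NumPy mirror of the same formulas (a hypothetical helper, unravel_argmax_np, written only for checking). It assumes shape is the pooling op's input shape and that argmax was produced with include_batch_in_index=False. It also recovers the channel with argmax % channels, which the function above does not return. The argmax values are made up for illustration.

```python
import numpy as np

def unravel_argmax_np(argmax, shape):
    # NumPy mirror of the TF answer's arithmetic (hypothetical, for checking only).
    # shape is the pooling input shape [batch, height, width, channels].
    _, _, width, channels = shape
    y = argmax // (width * channels)             # row
    x = argmax % (width * channels) // channels  # column
    c = argmax % channels                        # channel (not returned by the answer)
    return np.stack([y, x, c])

shape = (1, 4, 4, 2)
# Hypothetical flattened indices for an output of shape [1, 1, 2, 2]
argmax = np.array([[[[0, 9], [20, 31]]]])
coords = unravel_argmax_np(argmax, shape)
print(coords.shape)               # (3, 1, 1, 2, 2)
print(coords[:, 0, 0, 1, 1])      # [3 3 1]
```

The result agrees with np.stack(np.unravel_index(argmax, shape[1:])); inside a TF graph the same lines work with tf.stack in place of np.stack.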

I use it in an IPython notebook to pass the argmax positions from pooling to my unpooling method.
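The unpooling method itself isn't shown. One common approach is to write each pooled value back to the position recorded in argmax and leave everything else at zero. Below is a minimal NumPy sketch of that idea under stated assumptions: both helpers are hypothetical, max_pool_2x2_with_argmax only stands in for tf.nn.max_pool_with_argmax (2x2 windows, stride 2, 'VALID' padding, include_batch_in_index=False), and its tie-breaking (first maximum in a window) may not match TensorFlow's.

```python
import numpy as np

def max_pool_2x2_with_argmax(inp):
    # NumPy stand-in for tf.nn.max_pool_with_argmax with 2x2 windows, stride 2,
    # 'VALID' padding and include_batch_in_index=False (illustration only).
    batch, height, width, channels = inp.shape
    out_shape = (batch, height // 2, width // 2, channels)
    pooled = np.zeros(out_shape, dtype=inp.dtype)
    argmax = np.zeros(out_shape, dtype=np.int64)
    for n in range(batch):
        for i in range(height // 2):
            for j in range(width // 2):
                for ch in range(channels):
                    window = inp[n, 2 * i:2 * i + 2, 2 * j:2 * j + 2, ch]
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    row, col = 2 * i + dy, 2 * j + dx
                    pooled[n, i, j, ch] = inp[n, row, col, ch]
                    # Flattening as in the TF docs: (y * width + x) * channels + c
                    argmax[n, i, j, ch] = (row * width + col) * channels + ch
    return pooled, argmax

def unpool_with_argmax(pooled, argmax, input_shape):
    # Hypothetical unpooling helper: scatters each pooled value back to the
    # position recorded in argmax; every other position stays zero.
    batch = input_shape[0]
    per_item = int(np.prod(input_shape[1:]))
    out = np.zeros((batch, per_item), dtype=pooled.dtype)
    for n in range(batch):
        # Indices are relative to one batch item, so scatter item by item
        out[n, argmax[n].ravel()] = pooled[n].ravel()
    return out.reshape(input_shape)

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
pooled, argmax = max_pool_2x2_with_argmax(x)
restored = unpool_with_argmax(pooled, argmax, x.shape)
print(restored[0, :, :, 0])
```

With this input, each 2x2 window's maximum is its bottom-right element, so restored keeps 5, 7, 13 and 15 in place and is zero elsewhere. Because the indices are per batch item, the scatter loops over the batch; with include_batch_in_index=True a single flat scatter over the whole tensor would do.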