An elegant way to select one element per row in TensorFlow

Given:

  • a matrix A of shape [m, n]
  • a tensor I of shape [m]

I want to get a list J of elements from A, where J[i] = A[i, I[i]].

That is, I[i] holds the column index of the element to select from row i of A.
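For concreteness, here is a plain-Python reference of the desired semantics on a small 2×3 matrix. The function name select_per_row and the sample values are hypothetical, for illustration only:

```python
def select_per_row(A, I):
    """Reference semantics: J[i] = A[i][I[i]] for every row i."""
    # zip pairs each row of A with the column index chosen for that row
    return [row[col] for row, col in zip(A, I)]

A = [[1, 5, 2],
     [7, 3, 9]]
I = [1, 2]
J = select_per_row(A, I)  # [5, 9]
```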

Context: I already have argmax(A, 1), and now I also want the max. I know that I can just use reduce_max. After some experimenting, I also came up with this:
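A minimal sketch of the two-pass baseline, assuming TensorFlow 2.x with eager execution: argmax gives the indices, and reduce_max recomputes the values with a second scan over every row (the sample matrix is made up):

```python
import tensorflow as tf

A = tf.constant([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 9.0]])

I = tf.argmax(A, axis=1)      # column index of each row's maximum (int64 by default)
J = tf.reduce_max(A, axis=1)  # the maximum values, computed in a separate pass over A
```

This is simple, but the second pass touches all m×n elements again, whereas gathering only needs to read m of them.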

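    # Row indices 0..m-1, converted to int64 to match the dtype of I (argmax returns int64)
    rows = tf.cast(tf.range(tf.shape(A)[0]), tf.int64)
    # Stack into an [m, 2] matrix of (row, column) pairs and gather those elements
    J = tf.gather_nd(A, tf.stack([rows, I], axis=1))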

The conversion to int64 is required because tf.range produces int32 by default, while argmax produces int64 by default.
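Assuming TensorFlow 2.x, where tf.argmax accepts an output_type argument, the explicit int64 conversion can be avoided by asking argmax for int32 up front. A sketch (sample values are made up):

```python
import tensorflow as tf

A = tf.constant([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 9.0]])

# Request int32 indices so they match tf.range's default dtype
I = tf.argmax(A, axis=1, output_type=tf.int32)
rows = tf.range(tf.shape(A)[0])                   # [0, 1, ..., m-1], int32
J = tf.gather_nd(A, tf.stack([rows, I], axis=1))  # picks A[i, I[i]] for each i
```

tf.stack(..., axis=1) builds the (row, column) pairs directly, so no transpose is needed.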

Neither seems particularly elegant to me. One has computational overhead (probably around a factor of n), and the other has an unknown cognitive overhead. Did I miss something?

2 answers

This is a rather late answer, but you could do:

    mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
    elements = tf.boolean_mask(A, mask)

Does that achieve what you are looking for?

Edit: I should point out that this is NOT a good idea if A is already a very large tensor, since it creates a dense [m, n] mask.
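A self-contained sketch of the one_hot/boolean_mask approach above, assuming TensorFlow 2.x with eager execution and made-up sample values. It relies on tf.boolean_mask scanning the mask in row-major order, so with exactly one True per row the i-th selected element comes from row i:

```python
import tensorflow as tf

A = tf.constant([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 9.0]])
I = tf.constant([1, 2])
n = A.shape[1]  # number of columns, from the static shape

# Dense [m, n] boolean mask with a single True per row, at column I[i]
mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
# Selected elements come out in row-major order: one per row
elements = tf.boolean_mask(A, mask)
```

The mask has the same m×n size as A, which is exactly the memory cost the edit warns about.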


The link provided by @yaroslav-bulatov mentions this solution:

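    def get_elements(data, indices):
        # Cast so the indices can be added to tf.range's int32 output (argmax yields int64)
        indices = tf.cast(indices, tf.int32)
        # Flat offset of element (i, indices[i]) in the row-major flattened matrix
        flat_indices = tf.range(0, tf.shape(indices)[0]) * tf.shape(data)[1] + indices
        return tf.gather(tf.reshape(data, [-1]), flat_indices)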

Your solution is currently not differentiable, since gradients for tf.gather_nd are not supported at the time of writing.
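To my knowledge, newer TensorFlow releases define gradients for both tf.gather and tf.gather_nd, so the note above reflects older versions. A sketch, assuming TensorFlow 2.x, that checks gradients flow through the flatten-and-gather version (sample values are made up):

```python
import tensorflow as tf

def get_elements(data, indices):
    # Cast so the indices can be added to tf.range's int32 output (argmax yields int64)
    indices = tf.cast(indices, tf.int32)
    # Flat offset of element (i, indices[i]) in the row-major flattened matrix
    flat_indices = tf.range(tf.shape(indices)[0]) * tf.shape(data)[1] + indices
    return tf.gather(tf.reshape(data, [-1]), flat_indices)

A = tf.Variable([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 9.0]])
I = tf.argmax(A, axis=1)

with tf.GradientTape() as tape:
    J = get_elements(A, I)
    loss = tf.reduce_sum(J)

# Densify in case the gradient comes back as IndexedSlices
grad = tf.convert_to_tensor(tape.gradient(loss, A))
```

The gradient is 1 exactly at the selected positions and 0 elsewhere, as expected for a selection.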

Hopefully support for data[:, indices]-style indexing will land soon.
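Later TensorFlow releases added a batch_dims argument to tf.gather, which covers this case directly without building an index matrix or flattening. A sketch, assuming TensorFlow 2.x and made-up sample values:

```python
import tensorflow as tf

A = tf.constant([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 9.0]])
I = tf.argmax(A, axis=1)

# batch_dims=1 treats each row as its own batch, so J[i] = A[i, I[i]]
J = tf.gather(A, I, batch_dims=1)
```

This is not the data[:, indices] syntax itself, but it expresses the same per-row selection in one call.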

