This is a bit complicated; maybe there is a better solution. `tf.scatter_update()` does not work here because it can only update whole slices along the first dimension of a tensor. It cannot change individual elements, such as one entry in the first column and another in the second.
Instead, take the values and indices returned by `tf.nn.top_k()`, use them to build a sparse tensor, and subtract it from the initial tensor `x`:
```python
import tensorflow as tf  # TF 1.x API (tf.sparse_to_dense is tf.compat.v1 in TF 2)

x = tf.constant([[6., 2., 0.], [0., 4., 5.]])  # of type tf.float32
k = 2
values, indices = tf.nn.top_k(x, k, sorted=False)
# e.g. indices = [[0, 1], [1, 2]], values = [[6., 2.], [4., 5.]]
# (with sorted=False the order within each row is not guaranteed)

# We need full (row, column) indices like [[0, 0], [0, 1], [1, 1], [1, 2]]
my_range = tf.expand_dims(tf.range(0, tf.shape(indices)[0]), 1)  # [[0], [1]]
my_range_repeated = tf.tile(my_range, [1, k])  # [[0, 0], [1, 1]]

# Expand both to [N, k, 1] and concatenate into [N, k, 2]
full_indices = tf.concat([tf.expand_dims(my_range_repeated, 2),
                          tf.expand_dims(indices, 2)], axis=2)
full_indices = tf.reshape(full_indices, [-1, 2])

# validate_indices=False because top_k does not return lexicographically sorted indices
to_subtract = tf.sparse_to_dense(full_indices, tf.shape(x),
                                 tf.reshape(values, [-1]), default_value=0.,
                                 validate_indices=False)

res = x - to_subtract  # res should be all 0.
```
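In TensorFlow 2, `tf.sparse_to_dense` only survives as `tf.compat.v1.sparse_to_dense`. A minimal sketch of the same idea, assuming TF 2.x with eager execution, swaps in `tf.scatter_nd`, which builds the dense "to subtract" tensor directly from `[N, k, 2]` indices and does not require them to be sorted. The helper name `zero_top_k` is hypothetical and the input is assumed to be 2-D.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def zero_top_k(x, k):
    """Hypothetical helper: return 2-D x with the k largest entries of each row set to 0."""
    values, indices = tf.math.top_k(x, k=k, sorted=False)  # both have shape [N, k]
    n = tf.shape(x)[0]
    # Row index for every selected entry, e.g. [[0, 0], [1, 1]] for N = 2, k = 2
    rows = tf.tile(tf.expand_dims(tf.range(n), 1), [1, k])
    # Pair each row index with its column index -> shape [N, k, 2]
    full_indices = tf.stack([rows, indices], axis=-1)
    # Dense tensor holding `values` at those positions and 0 everywhere else
    to_subtract = tf.scatter_nd(full_indices, values, tf.shape(x))
    return x - to_subtract


x = tf.constant([[6., 2., 0.], [0., 4., 5.]])
res = zero_top_k(x, 2)
print(res.numpy())  # every entry is 0.0 for this x
```

Because each row's top-k column indices are distinct, `tf.scatter_nd` never has to sum duplicate positions here, so the result matches the `sparse_to_dense` version.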