Set k-largest tensor elements to zero in TensorFlow

I want to find the k largest elements of each row of a tensor h and set those maximum elements to 0.
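For concreteness, here is a small illustration of the result I am after (the numbers are just my own example, with k = 1):

    import tensorflow as tf

    h = tf.constant([[1., 3., 2.],
                     [5., 4., 6.]])
    # With k = 1, the desired result would zero out the largest element of each row:
    # [[1., 0., 2.],
    #  [5., 4., 0.]]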

I can get the indexes of the largest elements of each row using the top_k function, for example:

top_k = tf.nn.top_k(h, 1) 

But I could not use the indexes returned by top_k to update the tensor.

How can I do this? Thanks in advance.

2 answers

This is a bit complicated; maybe there is a better solution. tf.scatter_update() does not work here, because it can only replace slices along the first dimension of a tensor (whole rows), not an individual element such as the one in the first row and second column.
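As a minimal sketch of that limitation (my own example, assuming TensorFlow 1.x, where tf.scatter_update() operates on variables):

    import tensorflow as tf

    v = tf.Variable([[6., 2., 0.], [0., 4., 5.]])
    # tf.scatter_update() addresses whole slices along the first dimension,
    # so it can replace row 0 in its entirety...
    row_update = tf.scatter_update(v, [0], [[0., 0., 0.]])
    # ...but it cannot target just the single element at position (0, 1).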

You can take the values and indices returned by tf.nn.top_k(), scatter them back into a dense tensor of the same shape (via tf.sparse_to_dense), and subtract that tensor from the original tensor x:

    import tensorflow as tf

    x = tf.constant([[6., 2., 0.], [0., 4., 5.]])  # of type tf.float32
    k = 2
    values, indices = tf.nn.top_k(x, k, sorted=False)
    # indices will be [[0, 1], [1, 2]], values will be [[6., 2.], [4., 5.]]

    # We need to create full indices like [[0, 0], [0, 1], [1, 1], [1, 2]]
    my_range = tf.expand_dims(tf.range(0, indices.get_shape()[0]), 1)  # will be [[0], [1]]
    my_range_repeated = tf.tile(my_range, [1, k])  # will be [[0, 0], [1, 1]]

    # change shapes to [N, k, 1] and [N, k, 1], to concatenate into [N, k, 2]
    full_indices = tf.concat([tf.expand_dims(my_range_repeated, 2), tf.expand_dims(indices, 2)], axis=2)
    full_indices = tf.reshape(full_indices, [-1, 2])

    to_substract = tf.sparse_to_dense(full_indices, x.get_shape(), tf.reshape(values, [-1]),
                                      default_value=0.)

    res = x - to_substract  # res should be all 0.
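For reference, evaluating the snippet in a TensorFlow 1.x session (the API level assumed above, since tf.sparse_to_dense is used) should confirm the comment about res:

    with tf.Session() as sess:
        print(sess.run(res))
        # [[0. 0. 0.]
        #  [0. 0. 0.]]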

I ran into the opposite problem and wanted an operation that supports gradients. top_k does not support gradient propagation, so implementing the function as a custom C++ op is a good way to go.

The top_k C++ code is here.

Your core operation will look like this:

    template <typename T>
    class MakeSparseOp : public OpKernel {
     public:
      explicit MakeSparseOp(OpKernelConstruction *context) : OpKernel(context) {}

      void Compute(OpKernelContext *context) override {
        // Grab the input tensors
        const auto &k_in = context->input(1);
        OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                    errors::InvalidArgument("k must be scalar, got shape ",
                                            k_in.shape().DebugString()));
        int k = k_in.scalar<int32>()();
        OP_REQUIRES(context, k >= 0,
                    errors::InvalidArgument("Need k >= 0, got ", k));

        const Tensor &x_in = context->input(0);
        OP_REQUIRES(context, x_in.dims() >= 1,
                    errors::InvalidArgument("input must be >= 1-D, got shape ",
                                            x_in.shape().DebugString()));
        OP_REQUIRES(context, x_in.dim_size(x_in.dims() - 1) >= k,
                    errors::InvalidArgument("input must have at least k columns"));

        // Flatten the input tensor
        const auto &x = x_in.flat_inner_dims<T>();
        const auto num_rows = x.dimension(0);
        const auto num_cols = x.dimension(1);

        TensorShape output_shape = x_in.shape();

        // Create an output tensor
        Tensor *x_out = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &x_out));

        // Get the top k values for each row of the input
        auto x_sparse = x_out->flat_inner_dims<T>();

        if (k == 0) return;  // Nothing to do

        // Use TopN to get the k max elements
        gtl::TopN<std::pair<T, int32>> filter(k);

        x_sparse = x;  // Copy all elements

        for (int r = 0; r < num_rows; r++) {
          // Process one row at a time
          for (int32 c = 0; c < num_cols; c++) {
            // The second element is the negated index, so that lower-index
            // elements are considered larger than higher-index elements
            // in case of ties.
            filter.push(std::make_pair(x(r, c), -c));
          }
          for (auto top_k_it = filter.unsorted_begin();
               top_k_it != filter.unsorted_end(); ++top_k_it) {
            x_sparse(r, -top_k_it->second) = 0;  // Set the top k elements to zero
          }
          filter.Reset();
        }
      }
    };

My implementation for the related problem is here.
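For completeness, once a kernel like this is compiled into a shared library and registered as an op, it would typically be loaded from Python roughly as sketched below; the file name make_sparse.so and the op name make_sparse are placeholders of my own, not names from the linked implementation:

    import tensorflow as tf

    # Hypothetical names: the actual library file and op name depend on how the
    # kernel above is built and registered.
    make_sparse_module = tf.load_op_library('./make_sparse.so')
    x = tf.constant([[6., 2., 0.], [0., 4., 5.]])
    out = make_sparse_module.make_sparse(x, 2)  # top-2 elements of each row set to 0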

