I am migrating my TensorFlow code from the old queue interface to the new Dataset API. With the old interface, I could specify the num_threads argument to the tf.train.shuffle_batch queue. However, the only way to control the number of threads in the Dataset API seems to be the num_parallel_calls argument of map. Instead of map, though, I use flat_map, which does not have such an argument.
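For reference, this is roughly how I set the thread count with the old interface (a minimal sketch with illustrative values; the shapes and constants are placeholders, not my real pipeline):

import tensorflow as tf

# Old queue-based pipeline: num_threads controls how many threads run the
# enqueue op, and therefore the preprocessing that feeds it.
example = tf.ones(shape=(512,), dtype=tf.float32)  # placeholder for one preprocessed sample
batch = tf.train.shuffle_batch(
    [example],
    batch_size=32,
    capacity=1000,
    min_after_dequeue=500,
    num_threads=8)  # <-- the knob I am looking for in the Dataset API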
Question: Is there a way to control the number of threads/processes for flat_map? Or is there a way to use map in conjunction with flat_map and still specify the number of parallel calls (see the sketch after the code example below)?
Please note that it is very important to run multiple threads in parallel, as I intend to run heavy preprocessing on the CPU before the data enters the queue.
Here are two related posts on GitHub (here and here), but I don't think they answer this question.
The following is a minimal code example of my use case:
import tensorflow as tf

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # Stand-in for the heavy CPU preprocessing.
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
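To illustrate the second part of the question, here is a rough, untested sketch of what I mean by combining map with flat_map (the split into a parallel map step for the heavy work plus a cheap flattening flat_map step, and the value of num_parallel_calls, are my own assumptions):

import tensorflow as tf

NUM_PARALLEL_CALLS = 4  # hypothetical thread count, to be tuned

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")

    def heavy_pre_processing(data_):
        # Stand-in for the expensive CPU work; returns a plain tensor
        # (not a Dataset), so it can run inside map().
        return tf.expand_dims(data_, axis=0)

    dataset_source = tf.data.Dataset.from_tensor_slices((data,))
    # The heavy work runs in parallel here ...
    dataset = dataset_source.map(heavy_pre_processing,
                                 num_parallel_calls=NUM_PARALLEL_CALLS)
    # ... while flat_map only performs the (cheap) flattening.
    dataset = dataset.flat_map(
        lambda t: tf.data.Dataset.from_tensor_slices((t,)))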