How to get TensorFlow 'import_graph_def' to return tensors

If I try to import a saved TensorFlow graph definition using

    import tensorflow as tf
    from tensorflow.python.platform import gfile

    with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        x, y, y_ = tf.import_graph_def(graph_def,
                                       return_elements=['data/inputs',
                                                        'output/network_activation',
                                                        'data/correct_outputs'],
                                       name='')

the return values are not Tensor objects as expected, but something else. For example, instead of getting x as

    Tensor("data/inputs:0", shape=(?, 784), dtype=float32)

I get

    name: "data/inputs_1"
    op: "Placeholder"
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
        }
      }
    }

That is, instead of the expected tensor x I get x.op. This bothers me because the documentation seems to say that I should get a Tensor (although its description is full of "or"s that make it hard to be sure).

How can I get tf.import_graph_def to return specific Tensors that I can use (for example, to feed the loaded model or to run analyses)?

1 answer

The names 'data/inputs', 'output/network_activation' and 'data/correct_outputs' are actually operation names. To make tf.import_graph_def() return tf.Tensor objects, you must append an output index to each operation name, which is ':0' for single-output operations:

    x, y, y_ = tf.import_graph_def(graph_def,
                                   return_elements=['data/inputs:0',
                                                    'output/network_activation:0',
                                                    'data/correct_outputs:0'],
                                   name='')
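To see the difference end to end, here is a minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 graph API (recent 1.x releases that ship tf.compat.v1 should behave the same). It builds a toy stand-in graph that reuses two op names from the question; the graph itself, the `* 2.0` activation and the variable names are hypothetical. It then imports the same GraphDef once with plain op names and once with the ':0' suffix, and feeds the resulting tensors through a session.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x

# Hypothetical toy graph reusing op names from the question.
source = tf.Graph()
with source.as_default():
    with source.name_scope('data'):
        inputs = tf1.placeholder(tf.float32, shape=(None, 784), name='inputs')
    with source.name_scope('output'):
        tf.identity(inputs * 2.0, name='network_activation')
graph_def = source.as_graph_def()

# Plain op names: import_graph_def returns tf.Operation objects.
with tf.Graph().as_default():
    ops = tf1.import_graph_def(
        graph_def,
        return_elements=['data/inputs', 'output/network_activation'],
        name='')

# "<op name>:<output index>": import_graph_def returns tf.Tensor objects.
with tf.Graph().as_default() as g:
    x, y = tf1.import_graph_def(
        graph_def,
        return_elements=['data/inputs:0', 'output/network_activation:0'],
        name='')

    # Two other ways to reach the same tensors once the graph is imported.
    x_by_name = g.get_tensor_by_name('data/inputs:0')
    y_from_op = g.get_operation_by_name('output/network_activation').outputs[0]

    # The returned tensors can be fed and fetched directly.
    with tf1.Session(graph=g) as sess:
        activations = sess.run(y, feed_dict={x: np.ones((2, 784), np.float32)})

print([type(o).__name__ for o in ops])
print(x.name, y.name)
print(activations.shape)
```

If you already hold an Operation (as in the question), op.outputs[0] or Graph.get_tensor_by_name() with the ':0' name gives the corresponding tensor without re-importing the graph.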
