TensorFlow: creating a graph in a class and running it outside

I find it difficult to understand how graphs work in TensorFlow and how to access them. My intuition is that the lines inside the `with self.graph.as_default():` block form the graph as a whole. So I decided to create a class that builds the graph when it is instantiated and has a method that runs the graph, as shown below:

```python
class Graph(object):
    # To build the graph when instantiated
    def __init__(self, parameters):
        self.graph = tf.Graph()
        with self.graph.as_default():
            ...
            prediction = ...
            cost = ...
            optimizer = ...
            ...

    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
            ...
            sess.run(optimizer, feed_dict)
            loss = sess.run(cost, feed_dict)
            ...
            return variables
```

The next step is a main file that collects the parameters to pass to the class, builds the graph and then runs it:

```python
# Main file
...
parameters_dict = {
    'n_input': 28,
    'learnRate': 0.001,
    ...
}

# Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...
```

This seems elegant to me, but it doesn't quite work (obviously). The launchG method does not seem to have access to the nodes defined in the graph, which gives me errors such as:

```
---> 26     sess.run(optimizer, feed_dict)
NameError: name 'optimizer' is not defined
```

Perhaps my understanding of Python (and TensorFlow) is too limited, but I had the impression that once the graph (G) is created, starting a session with this graph as an argument would give access to the nodes in it, without having to reference them explicitly.

Any insight?

1 answer

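The nodes prediction, cost and optimizer are local variables created in the __init__ method, so they cannot be accessed in the launchG method. The operations themselves still exist in self.graph, but the Python names that referred to them disappear when __init__ returns, and opening a Session on the graph does not recreate those names for you.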
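This is ordinary Python scoping rather than anything specific to TensorFlow. A minimal sketch with no TensorFlow involved (the class and attribute names are hypothetical) reproduces the same NameError and shows why an attribute stored on self survives between method calls:

```python
class Example(object):
    def __init__(self):
        local_node = "op"      # local name: discarded when __init__ returns
        self.attr_node = "op"  # attribute: stored on the instance

    def use_local(self):
        # 'local_node' is looked up in this method's scope, then in the
        # module globals; it exists in neither, so this raises NameError
        return local_node

    def use_attr(self):
        return self.attr_node  # found on the instance


e = Example()
print(e.use_attr())  # prints: op
try:
    e.use_local()
except NameError as err:
    print("NameError:", err)  # NameError: name 'local_node' is not defined
```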

The easiest fix is to store them as attributes of your Graph class:

```python
class Graph(object):
    # To build the graph when instantiated
    def __init__(self, parameters):
        self.graph = tf.Graph()
        with self.graph.as_default():
            ...
            self.prediction = ...
            self.cost = ...
            self.optimizer = ...
            ...

    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
            ...
            sess.run(self.optimizer, feed_dict)
            loss = sess.run(self.cost, feed_dict)
            ...
            return variables
```
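A runnable sketch of the same pattern, assuming TensorFlow with the tf.compat.v1 graph-mode API (TF 1.14+ or 2.x). The elided model is replaced by a hypothetical one-variable linear regression, and the x/y placeholders, the init op and the extra targets/steps arguments are illustrative additions, not part of the original code:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API, available in TF 1.14+ and TF 2.x


class Graph(object):
    # Build the graph once and keep handles to its nodes on self
    def __init__(self, parameters):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Hypothetical toy model: y is approximately w * x + b
            self.x = tf1.placeholder(tf.float32, shape=[None], name="x")
            self.y = tf1.placeholder(tf.float32, shape=[None], name="y")
            w = tf1.Variable(0.0, name="w")
            b = tf1.Variable(0.0, name="b")
            self.prediction = w * self.x + b
            self.cost = tf.reduce_mean(tf.square(self.prediction - self.y))
            self.optimizer = tf1.train.GradientDescentOptimizer(
                parameters["learnRate"]).minimize(self.cost)
            self.init = tf1.global_variables_initializer()

    # Run the graph: the session reaches the nodes through self
    def launchG(self, inputs, targets, steps=200):
        with tf1.Session(graph=self.graph) as sess:
            sess.run(self.init)  # a fresh Session must initialize variables
            feed_dict = {self.x: inputs, self.y: targets}
            first_loss = sess.run(self.cost, feed_dict)
            for _ in range(steps):
                sess.run(self.optimizer, feed_dict)
            loss = sess.run(self.cost, feed_dict)
        return first_loss, loss


G = Graph({"learnRate": 0.05})
first_loss, loss = G.launchG([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
print(first_loss, loss)  # first_loss is 21.0 (w = b = 0); loss should be far smaller
```

Because each launchG call opens a new Session, the trained values of w and b are discarded when it returns; if you need them across calls, one option is to keep a Session on self instead.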

Alternatively, you can retrieve graph nodes by their exact names with graph.get_tensor_by_name and graph.get_operation_by_name.
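A minimal sketch of the name-based lookup, assuming the tf.compat.v1 API; the node names "x", "w", "cost" and "train" are hypothetical and are set explicitly with name=. A tensor name is its producing operation's name plus an output index, so the first output of the op "cost" is the tensor "cost:0", while get_operation_by_name takes the bare op name:

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    w = tf1.Variable(1.0, name="w")
    # tf.identity gives the loss a predictable op name
    cost = tf.identity(tf.reduce_mean(tf.square(w * x)), name="cost")
    tf1.train.GradientDescentOptimizer(0.1).minimize(cost, name="train")
    init = tf1.global_variables_initializer()

# Look nodes up by name instead of relying on the Python variables above
x_t = graph.get_tensor_by_name("x:0")            # tensor: op name + ":0"
cost_t = graph.get_tensor_by_name("cost:0")
train_op = graph.get_operation_by_name("train")  # operation: bare name

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    before = sess.run(cost_t, {x_t: [1.0, 2.0]})
    sess.run(train_op, {x_t: [1.0, 2.0]})
    after = sess.run(cost_t, {x_t: [1.0, 2.0]})
print(before, after)  # one gradient step lowers the cost
```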
