Une façon Simple de visualiser un graphique de TensorFlow dans Jupyter?
la façon officielle de visualiser un graphe TensorFlow est avec TensorBoard, mais parfois je veux juste un rapide coup d'oeil au graphe quand je travaille à Jupyter.
Existe-t-il une solution rapide, idéalement basée sur des outils TensorFlow, ou des paquets scipy standard (comme matplotlib), mais si nécessaire basée sur des bibliothèques tierces?
4 réponses
Voici une recette que j'ai copiée d'un rêve profond D'Alex Mordvintsev ordinateur portable à un certain point
from IPython.display import clear_output, Image, display, HTML
def strip_consts(graph_def, max_const_size=32):
"""Strip large constant values from graph_def."""
strip_def = tf.GraphDef()
for n0 in graph_def.node:
n = strip_def.node.add()
n.MergeFrom(n0)
if n.op == 'Const':
tensor = n.attr['value'].tensor
size = len(tensor.tensor_content)
if size > max_const_size:
tensor.tensor_content = "<stripped %d bytes>"%size
return strip_def
def show_graph(graph_def, max_const_size=32):
"""Visualize TensorFlow graph."""
if hasattr(graph_def, 'as_graph_def'):
graph_def = graph_def.as_graph_def()
strip_def = strip_consts(graph_def, max_const_size=max_const_size)
code = """
<script>
function load() {{
document.getElementById("{id}").pbtxt = {data};
}}
</script>
<link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
<div style="height:600px">
<tf-graph-basic id="{id}"></tf-graph-basic>
</div>
""".format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))
iframe = """
<iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
""".format(code.replace('"', '"'))
display(HTML(iframe))
puis visualiser le graphe actuel
show_graph(tf.get_default_graph().as_graph_def())
Si votre graphique est enregistré comme pbtxt, vous pourriez faire
gdef = tf.GraphDef()
from google.protobuf import text_format
text_format.Merge(open("tf_persistent.pbtxt").read(), gdef)
show_graph(gdef)
Vous verrez quelque chose comme ceci
j'ai écrit une extension Jupyter pour l'intégration tensorboard. Il peut:
- Démarrer tensorboard juste en cliquant sur un bouton dans Jupyter
- gérer plusieurs instances de tensorboard.
- intégration transparente avec L'interface Jupyter.
j'ai écrit un simple helper qui démarre une planche de tenseur à partir du carnet jupyter. Il suffit d'ajouter cette fonction quelque part en haut de votre ordinateur portable
def TB(cleanup=False):
import webbrowser
webbrowser.open('http://127.0.1.1:6006')
!tensorboard --logdir="logs"
if cleanup:
!rm -R logs/
Et ensuite l'exécuter TB()
chaque fois que vous avez généré votre sommaire. Au lieu d'ouvrir un graphique dans le même jupyter fenêtre:
- démarre un tensorboard
- ouvre un nouvel onglet avec tensorboard
- naviguez vers cet onglet
après avoir terminé l'exploration, cliquez simplement sur l'onglet, et de cesser d'interrompre le noyau. Si vous voulez nettoyer votre répertoire log, après la course, Lancez TB(1)
une version libre Tensorboard / iframes de cette visualisation qui, il est vrai, est encombrée rapidement peut
import pydot
from itertools import chain
def tf_graph_to_dot(in_graph):
dot = pydot.Dot()
dot.set('rankdir', 'LR')
dot.set('concentrate', True)
dot.set_node_defaults(shape='record')
all_ops = in_graph.get_operations()
all_tens_dict = {k: i for i,k in enumerate(set(chain(*[c_op.outputs for c_op in all_ops])))}
for c_node in all_tens_dict.keys():
node = pydot.Node(c_node.name)#, label=label)
dot.add_node(node)
for c_op in all_ops:
for c_output in c_op.outputs:
for c_input in c_op.inputs:
dot.add_edge(pydot.Edge(c_input.name, c_output.name))
return dot
qui peut alors être suivi de
from IPython.display import SVG
# Define model
tf_graph_to_dot(graph).write_svg('simple_tf.svg')
SVG('simple_tf.svg')
pour rendre le graphe sous forme d'enregistrements dans un fichier SVG statique