Une façon Simple de visualiser un graphique de TensorFlow dans Jupyter?

la façon officielle de visualiser un graphe TensorFlow est avec TensorBoard, mais parfois je veux juste un rapide coup d'oeil au graphe quand je travaille à Jupyter.

Existe-t-il une solution rapide, idéalement basée sur des outils TensorFlow, ou des paquets scipy standard (comme matplotlib), mais si nécessaire basée sur des bibliothèques tierces?

40
demandé sur MiniQuark 2016-07-04 19:33:32

4 réponses

Voici une recette que j'ai copiée d'un rêve profond D'Alex Mordvintsev ordinateur portable à un certain point

from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = "<stripped %d bytes>"%size
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

puis visualiser le graphe actuel

show_graph(tf.get_default_graph().as_graph_def())

Si votre graphique est enregistré comme pbtxt, vous pourriez faire

gdef = tf.GraphDef()
from google.protobuf import text_format
text_format.Merge(open("tf_persistent.pbtxt").read(), gdef)
show_graph(gdef)

Vous verrez quelque chose comme ceci

enter image description here

76
répondu Yaroslav Bulatov 2017-02-25 14:42:57

j'ai écrit une extension Jupyter pour l'intégration tensorboard. Il peut:

  1. Démarrer tensorboard juste en cliquant sur un bouton dans Jupyter
  2. gérer plusieurs instances de tensorboard.
  3. intégration transparente avec L'interface Jupyter.

Github:https://github.com/lspvic/jupyter_tensorboard

9
répondu Liu Shengpeng 2017-08-22 09:18:01

j'ai écrit un simple helper qui démarre une planche de tenseur à partir du carnet jupyter. Il suffit d'ajouter cette fonction quelque part en haut de votre ordinateur portable

def TB(cleanup=False):
    import webbrowser
    webbrowser.open('http://127.0.1.1:6006')

    !tensorboard --logdir="logs"

    if cleanup:
        !rm -R logs/

Et ensuite l'exécuter TB() chaque fois que vous avez généré votre sommaire. Au lieu d'ouvrir un graphique dans le même jupyter fenêtre:

  • démarre un tensorboard
  • ouvre un nouvel onglet avec tensorboard
  • naviguez vers cet onglet

après avoir terminé l'exploration, cliquez simplement sur l'onglet, et de cesser d'interrompre le noyau. Si vous voulez nettoyer votre répertoire log, après la course, Lancez TB(1)

4
répondu Salvador Dali 2017-04-23 07:41:47

une version libre Tensorboard / iframes de cette visualisation qui, il est vrai, est encombrée rapidement peut

import pydot
from itertools import chain
def tf_graph_to_dot(in_graph):
    dot = pydot.Dot()
    dot.set('rankdir', 'LR')
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')
    all_ops = in_graph.get_operations()
    all_tens_dict = {k: i for i,k in enumerate(set(chain(*[c_op.outputs for c_op in all_ops])))}
    for c_node in all_tens_dict.keys():
        node = pydot.Node(c_node.name)#, label=label)
        dot.add_node(node)
    for c_op in all_ops:
        for c_output in c_op.outputs:
            for c_input in c_op.inputs:
                dot.add_edge(pydot.Edge(c_input.name, c_output.name))
    return dot

qui peut alors être suivi de

from IPython.display import SVG
# Define model
tf_graph_to_dot(graph).write_svg('simple_tf.svg')
SVG('simple_tf.svg')

pour rendre le graphe sous forme d'enregistrements dans un fichier SVG statique Integrated Tensorflow Graph in Dot

3
répondu kmader 2017-05-11 16:37:39