Enregistrement TensorFlow dans / Chargement d'un graphique à partir d'un fichier
D'après ce que j'ai recueilli jusqu'à présent, il y a plusieurs façons différentes de décharger un graphique de TensorFlow dans un fichier et ensuite de le charger dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples clairs/des informations sur la façon dont ils fonctionnent. Ce que je sais déjà c'est ceci:
- Enregistrer le modèle des variables dans un fichier de point de contrôle (.ckpt) en utilisant un
tf.train.Saver()
et les restaurer plus tard ( source ) - Enregistrer un modèle dans un .pb de fichier et la charge à l'aide de
tf.train.write_graph()
ettf.import_graph_def()
( source ) - charge dans un modèle de A.pb de fichier, de recycler, et la déverser dans une nouvelle .pb de fichier à l'aide avec les sections de bazel ( source )
- geler le graphe pour sauvegarder le graphe et les poids ensemble ( source )
- utiliser
as_graph_def()
pour sauvegarder le modèle, et pour les poids/variables, les mapper en constantes ( source )
cependant, je n'ai pas été en mesure de clarifier plusieurs questions concernant ces différentes méthodes:
- en ce qui concerne les fichiers des points de contrôle, ne sauvegardent-ils que les poids formés d'un modèle? Les fichiers de checkpoint peuvent-ils être chargés dans un nouveau programme, et être utilisés pour exécuter le modèle, ou servent-ils simplement à sauvegarder les poids dans un modèle à un moment ou à une étape donnés?
- concernant
tf.train.write_graph()
, les poids/variables sont-ils également sauvegardés? - en ce qui concerne Bazel, peut-il seulement sauver dans/Charger à partir .fichiers pb pour le recyclage? Est-ce Qu'il y a une commande Bazel simple juste pour balancer un graphe dans un .le pb?
- en ce qui concerne la congélation, un graphe congelé peut-il être chargé en utilisant
tf.import_graph_def()
? - Le Android démo pour TensorFlow charges dans Google Création du modèle à partir d'une .pb de fichier. Si je voulais remplacer mon propre .pb de fichier, comment pourrais-je aller sur faire cela? Est-ce que je devrais changer un code/méthode natif?
- En général, quelle est exactement la différence entre toutes ces méthodes? Ou plus généralement, Quelle est la différence entre
as_graph_def()
/.CKPT./le pb?
en bref, ce que je cherche est une méthode pour enregistrer à la fois un graphique (comme dans, les diverses opérations et tel) et ses poids / variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les poids dans un autre programme, pour utiliser (pas nécessairement continue/recyclage).
la Documentation sur ce sujet n'est pas très simple, donc toute réponse/information serait grandement appréciée.
2 réponses
il y a plusieurs façons d'aborder le problème de sauver un modèle dans TensorFlow, ce qui peut le rendre un peu confus. Répondre à chacune de vos sous-questions à tour de rôle:
-
les fichiers de points de contrôle (produits par exemple en appelant
saver.save()
sur un objettf.train.Saver
) ne contiennent que les poids, et toute autre variable définie dans le même programme. Les utiliser dans un autre programme, vous devez recréer la structure du graphique associée (par exemple en exécutant le code pour le compiler à nouveau, ou en appelanttf.import_graph_def()
), ce qui indique à TensorFlow ce qu'il faut faire avec ces poids. Notez que l'appelsaver.save()
produit également un fichier contenantMetaGraphDef
, qui contient un graphique et des détails sur la façon d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails. -
tf.train.write_graph()
écrit seulement la structure du graphe; pas les poids. -
Bazel n'est pas lié à la lecture ou à l'écriture de graphiques TensorFlow. (Peut-être ai-je mal compris votre question: n'hésitez pas à la clarifier dans un commentaire.)
-
un graphe congelé peut être chargé en utilisant
tf.import_graph_def()
. Dans ce cas, les poids sont (généralement) intégré dans le graphique, donc vous n'avez pas besoin de charger un checkpoint séparé. -
le changement principal serait de mettre à jour les noms du ou des tenseurs qui sont introduits dans le modèle, et les noms du ou des tenseurs qui sont extraits du modèle. Dans la démo Android de TensorFlow, cela correspondrait aux chaînes
inputName
etoutputName
qui sont passées àTensorFlowClassifier.initializeTensorFlow()
. -
le
GraphDef
est la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. En conséquence, TensorFlow utilise différents formats de stockage pour ces types de données, et L'API de bas niveau fournit différentes façons de les enregistrer et de les charger. Bibliothèques de niveau supérieur, telles que lesMetaGraphDef
bibliothèques, Keras , et skflow s'appuient sur ces mécanismes pour fournir des moyens plus pratiques de sauvegarder et de restaurer un modèle entier.
vous pouvez essayer le code suivant:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)