使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
我对这篇文章进行了整理和汇总。
首先是模型的保存。直接上代码:
#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut1_save.py #Author: Wang #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 11:04:25 ############################ import tensorflow as tf # prepare to feed input, i.e. feed_dict and placeholders w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') b1 = tf.Variable(2.0, name = 'bias1') feed_dict = {w1:[10,3], w2:[5,5]} # define a test operation that will be restored w3 = tf.add(w1, w2) # without name, w3 will not be stored w4 = tf.multiply(w3, b1, name = "op_to_restore") #saver = tf.train.Saver() saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) sess = tf.Session() sess.run(tf.global_variables_initializer()) print sess.run(w4, feed_dict) #saver.save(sess, 'my_test_model', global_step = 100) saver.save(sess, 'my_test_model') #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
需要说明的有以下几点:
1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。
2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。
3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。
下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:
#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut2_import.py #Author: Wang #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 14:16:38 ############################ import tensorflow as tf sess = tf.Session() new_saver = tf.train.import_meta_graph('my_test_model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) print sess.run('w1:0')
使用加载的模型,输入新数据,计算输出,还是直接上代码:
#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut3_reuse.py #Author: Wang #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 14:33:35 ############################ import tensorflow as tf sess = tf.Session() # First, load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) # Second, access and create placeholders variables and create feed_dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name('w1:0') w2 = graph.get_tensor_by_name('w2:0') feed_dict = {w1:[-1,1], w2:[4,6]} # Access the op that want to run op_to_restore = graph.get_tensor_by_name('op_to_restore:0') print sess.run(op_to_restore, feed_dict) # ouotput: [6. 14.]
在已经加载的网络后继续加入新的网络层:
import tensorflow as tf sess=tf.Session() #First let's load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) # Now, let's access and create placeholders variables and # create feed-dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0") #Add more to the current graph add_on_op = tf.multiply(op_to_restore,2) print sess.run(add_on_op,feed_dict) #This will print 120.
对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):
...... ...... saver = tf.train.import_meta_graph('vgg.meta') # Access the graph graph = tf.get_default_graph() ## Prepare the feed_dict for feeding data for fine-tuning #Access the appropriate output for fine-tuning fc7= graph.get_tensor_by_name('fc7:0') #use this if you only want to change gradients of the last layer fc7 = tf.stop_gradient(fc7) # It's an identity function fc7_shape= fc7.get_shape().as_list() new_outputs=2 weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) output = tf.matmul(fc7, weights) + biases pred = tf.nn.softmax(output) # Now, you run this with fine-tuning data in sess.run()
有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。