全国免费咨询:

13245491521

VR图标白色 VR图标黑色
X

中高端软件定制开发服务商

与我们取得联系

13245491521     13245491521

2018-10-05_令人困惑的 TensorFlow!(II)

您的位置:首页 >> 新闻 >> 行业资讯

令人困惑的 TensorFlow!(II) 选自Jacobbuckman 作者:Jacob Buckman 机器之心编译 参与:高璇、王淑婷 六月底,机器之心发布了「令人困惑的 TensorFlow!」,讲述了初上手 TensorFlow 时会遇到的麻烦。日前,作者更新博客,对该文写了续篇,主要讲述了保存和加载 TensorFlow 模型以及上下文管理器的一些问题。 命名和域 命名变量和张量 正如我们在第一部分讨论的,每次调用 tf.get_variable() 时,都需要为变量赋予一个新的唯一名称。实际上,图中的每个张量也需要一个唯一的名称。可以通过张量、操作和变量的 .name 属性访问该名称。绝大多数情况下,名称会自动创建;例如,一个常量节点会以 Const 命名,当创建更多常量节点时,其名称将是 Const_1,Const_2 等。还可以通过 name=的属性设置节点名称,列举后缀仍会自动添加: 代码: importtensorflowastf a=tf.constant(0.) b=tf.constant(1.) c=tf.constant(2.,name="cool_const") d=tf.constant(3.,name="cool_const") printa.name,b.name,c.name,d.name 输出: Const:0Const_1:0cool_const:0cool_const_1:0 虽然节点命名并非必要,但在调试时非常有用。当 Tensorflow 代码崩溃时,error trace 将指向一个特定的操作。如果有很多同类型的操作,那么很难确定是哪一个出了问题。而通过明确命名每个节点,可以获得信息详细的 error trace,并更快地识别问题。 使用范围 随着图形越来越复杂,手动命名所有内容变得愈加困难。Tensorflow 提供 tf.variable_scope 对象,它通过将图形细分为更小的组块,使图形更易梳理。通过将一段图形创建代码封装在 with tf.variable_scope(scope_name):语句中,创建的所有节点名称都将自动以 scope_name 字符串作为前缀。此外,这些作用域堆栈,在另一个范围内创建的作用域会简单地将前缀链接在一起,用斜杠分隔。 代码: importtensorflowastf a=tf.constant(0.) b=tf.constant(1.) withtf.variable_scope("first_scope"): c=a+b d=tf.constant(2.,name="cool_const") coef1=tf.get_variable("coef",[],initializer=tf.constant_initializer(2.)) withtf.variable_scope("second_scope"): e=coef1*d coef2=tf.get_variable("coef",[],initializer=tf.constant_initializer(3.)) f=tf.constant(1.) g=coef2*f printa.name,b.name printc.name,d.name printe.name,f.name,g.name printcoef1.name printcoef2.name 输出: Const:0Const_1:0 first_scope/add:0first_scope/cool_const:0 first_scope/second_scope/mul:0first_scope/second_scope/Const:0first_scope/second_scope/mul_1:0 first_scope/coef:0 first_scope/second_scope/coef:0 我们能够使用代码 coef 创建两个名称相同的变量。这是因为作用域可以将名称转换为 first_scope/coef:0 和 first_scope/second_scope/coef:0,它们是不同的。 保存和加载 训练好的神经网络包括两个基本组成部分: 已经学习过某些任务优化的网络权重 说明如何利用权重获得结果的网络图 Tensorflow 将这两个组件分开,但很明显它们需要紧密匹配。如果没有图结构进行说明,那权重也无用,而带有随机权重的图也效果也不好。事实上,即使仅交换两个权重矩阵也可能完全破坏模型。这通常会让 Tensorflow 初学者感觉很挫败。使用预先训练好的模型作为神经网络的一个组成部分不失为加速训练的好方法,但是也有可能搞砸一切。 保存模型 当只有单个模型时,Tensorflow 用于保存和加载的内置工具使用很方便:只需创建一个 tf.train.Saver()。类似于 tf.train.Optimizer,tf.train.Saver 本身并不是一个节点,而是在已有图形上执行有用功能的更高级类别。你可能已经预料到 tf.train 的「有用功能」了,即保存和加载模型。 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.save(sess,'./tftcp.model') 输出 四个新文件: checkpoint tftcp.model.data-00000-of-00001 tftcp.model.index tftcp.model.meta 具体内容分析如下: 首先:当我们只保存一个模型时,为什么会输出四个文件?重建模型所需的信息被分散到它们当中。如果想复制或者备份模型,需要有四个文件(前缀为文件名)。下面简述答案: tftcp.model.data-00000-of-00001 包含模型权重(上述第一个要点)。它可能这里最大的文件。 tftcp.model.meta 是模型的网络结构(上述第二个要点)。它包含重建图形所需的所有信息。 tftcp.model.index 是连接前两点的索引结构。用于在数据文件中找到对应节点的参数。 checkpoint 实际上不需要重建模型,但如果在整个训练过程中保存了多个版本的模型,那它会跟踪所有内容。 其次,我为什么一定要为该示例创建 tf.Session 和 tf.global_variables_initializer 呢? 因为,如果要保存一个模型,我们需要保存相关的内容。计算存于图中,但数值存于会话中。tf.train.Saver 可以通过指向图表的全局指针访问网络结构。但当我们保存变量的值(即网络权重)时,我们需要访问 tf.Session 来确定这些值;这就是为什么 sess 作为 save 函数的第一个参数传入。此外,尝试保存未初始化的变量会引发错误,因为尝试访问未初始化变量的值总是会引发错误。因此,我们需要一个会话和一个初始化程序(或等价的 tf.assign)。 加载模型 既然我们已经保存了模型,现在重新加载它。第一步是重新创建变量:我们希望变量的名称、形状和类型都与保存时一致。第二步是创建与之前一样的 tf.train.Saver,并调用 restore 函数。 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[]) saver=tf.train.Saver() sess=tf.Session() saver.restore(sess,'./tftcp.model') sess.run([a,b]) 输出: [1.3106428,0.6413864] 在运行之前,我们不需要初始化 a 或 b!这是因为 restore 运算将值从文件移动到会话的变量中。由于会话不再包含任何空值变量,因此不再需要初始化。(如果不小心,会适得其反:还原后运行 init 会使随机初始化的值覆盖加载的值。) 选择变量 当一个 tf.train.Saver 程序初始化后,它会查看当前图形并获取变量列表;这是 saver「关心」的永久存储的变量列表。我们可以用._var_list 属性来检查: 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[]) saver=tf.train.Saver() c=tf.get_variable('c',[]) printsaver._var_list 输出: [tf.Variable'a:0'shape=()dtype=float32_ref,tf.Variable'b:0'shape=()dtype=float32_ref] 因为在创建 saver 时 c 还没有出现,所以它并没有成为函数的一部分。一般来说,你要在创建 saver 之前确保已经创建了所有的变量。 当然,在某些特定的情况下,可能只需保存变量的一个子集。当创建 var_list 以期望它跟踪可用变量子集时,tf.train.Saver 允许传递 var_list。 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[]) c=tf.get_variable('c',[]) saver=tf.train.Saver(var_list=[a,b]) printsaver._var_list 输出: [tf.Variable'a:0'shape=()dtype=float32_ref,tf.Variable'b:0'shape=()dtype=float32_ref] 加载修正模型 上面例子中涵盖的模型加载方案类似于物理中的「真空中无摩擦的完美球体」(perfect sphere in frictionless vacuum)场景。只要你使用自己的代码保存和加载模型,且不擅自更改二者,实现保存和加载轻而易举。但很多情况下,并不会有如此完美的场景。在这些情况下,我们需要多加思量。 让我们通过几个场景来说明这些问题。首先,如果我们想保存一个完整的模型,但只想加载其中的一部分怎么办?(在下面代码示例中,我依次运行两个脚本。) 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.save(sess,'./tftcp.model')importtensorflowastf a=tf.get_variable('a',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.restore(sess,'./tftcp.model') sess.run(a) 输出: 1.1700551 OK。当我们在相反的场景里,就会出现失败的状况:我们希望将一个模型作为大型模型的组件加载。 代码: importtensorflowastf a=tf.get_variable('a',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.save(sess,'./tftcp.model') importtensorflowastf a=tf.get_variable('a',[]) d=tf.get_variable('d',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.restore(sess,'./tftcp.model') 输出: Keydnotfoundincheckpoint [[{{nodesave/RestoreV2}}=RestoreV2[dtypes=[DT_FLOAT,DT_FLOAT,DT_FLOAT],_device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0,save/RestoreV2/tensor_names,save/RestoreV2/shape_and_slices)]] 我们只想加载 a,却忽略了新变量 b。我们犯了一个错误,却抱怨 d 没有出现在 checkpoint 中。 第三种情况是,我们想将一个模型的参数加载到另一个模型的计算图中。这也会引发一个错误,原因很明显:Tensorflow 不知道把加载的所有参数放置在何处。幸好有个方法可以给它点提示。 还记得 var_list 吗?或者更准确来说是「var_list_or_dictionary_mapping_names_to_vars」,但这个名字有点拗口,所以他们使用第一个。 保存模型是 Tensorflow 要求使用全局唯一变量名的关键原因之一。在保存-模型-文件中,每个保存变量的名称都与其形状和值有关。将其加载到新的计算图中与将想要加载的变量的原始名称映射到当前模型的变量中一样简单。示例如下: 代码: importtensorflowastf a=tf.get_variable('a',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.save(sess,'./tftcp.model') importtensorflowastf d=tf.get_variable('d',[]) init=tf.global_variables_initializer() saver=tf.train.Saver(var_list={'a':d}) sess=tf.Session() sess.run(init) saver.restore(sess,'./tftcp.model') sess.run(d) 输出: -0.9303965 这是一种关键机制,通过这个机制,可以将没有相同计算图的模型组合在一起。例如,你可能从网上获得了一个预训练好的语言模型,希望重用词嵌入。或者你可能在两次训练之间改变了模型的参数化,想让这个新版本在旧版本的基础上继续前进;但你又不想重新训练整个过程。在这两种情况下,你只需手动创建一个字典,将旧变量名称映射到新变量即可。 需要注意的是:你要明确地知道正在加载的参数是如何使用的。如果可以,你应该使用原作者用来构建模型的确切代码,以确保计算图的组件与训练时看起来一样。如果需要复现模型,务必记住,无论多微小的更改,都可能严重损害预训练网络的性能。所以始终要将复现结果和原来的结果进行对比。 模型检查 如果想加载的模型来源于网络或由自己创建(两个月前),那你很可能不知道原始变量是如何命名的。要检查保存的模型,需要使用官方 Tensorflow 库的一些工具。 链接:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/framework/python/framework/checkpoint_utils.py 代码: importtensorflowastf a=tf.get_variable('a',[]) b=tf.get_variable('b',[10,20]) c=tf.get_variable('c',[]) init=tf.global_variables_initializer() saver=tf.train.Saver() sess=tf.Session() sess.run(init) saver.save(sess,'./tftcp.model') printtf.contrib.framework.list_variables('./tftcp.model') 输出: [('a',[]),('b',[10,20]),('c',[])] 利用这些工具(结合原始代码库一起使用)通常可以找到你想要的变量名称。 结论 希望本文能帮你了解关于保存和加载 Tensorflow 模型的基础知识。还有其他一些高级技巧,比如自动 checkpoint 和保存/恢复元图,可能会在以后的文章中提到;但是根据我的经验,这些并不常用,特别是对于初学者来说。 原文链接:https://jacobbuckman.com/post/tensorflow-the-confusing-parts-2/ 本文为机器之心编译,转载请联系本公众号获得授权。 ?------------------------------------------------ 加入机器之心(全职记者 / 实习生):hr@jiqizhixin.com 投稿或寻求报道:content@jiqizhixin.com 广告 & 商务合作:bd@jiqizhixin.com

上一篇:2023-12-11_对话清华武廷海:技术是衡量人类文明的尺子 | WeCityX系列专家访谈 下一篇:2025-09-23_「转」椰树的倾诉欲,太强了

TAG标签:

20
网站开发网络凭借多年的网站建设经验,坚持以“帮助中小企业实现网络营销化”为宗旨,累计为4000多家客户提供品质建站服务,得到了客户的一致好评。如果您有网站建设网站改版域名注册主机空间手机网站建设网站备案等方面的需求...
请立即点击咨询我们或拨打咨询热线:13245491521 13245491521 ,我们会详细为你一一解答你心中的疑难。
项目经理在线

相关阅读 更多>>

猜您喜欢更多>>

我们已经准备好了,你呢?
2022我们与您携手共赢,为您的企业营销保驾护航!

不达标就退款

高性价比建站

免费网站代备案

1对1原创设计服务

7×24小时售后支持

 

全国免费咨询:

13245491521

业务咨询:13245491521 / 13245491521

节假值班:13245491521()

联系地址:

Copyright © 2019-2025      ICP备案:沪ICP备19027192号-6 法律顾问:律师XXX支持

在线
客服

技术在线服务时间:9:00-20:00

在网站开发,您对接的直接是技术员,而非客服传话!

电话
咨询

13245491521
7*24小时客服热线

13245491521
项目经理手机

微信
咨询

加微信获取报价