guodong's blog

master@zhejiang university
   

tensorflow学习(1):tf.app模块

很多大型模型都会使用tf.app模型,其作用是:脚本函数的入口

包含有flags模块和run函数。

flags模块:

就是absl.flags的路由,即在app.flags上定义flag与在absl.flags定义相同。

除此之外,tf.flags也有相同功能。

run()函数:

tf.app.run(
    main=None,
    argv=None
) 

类似整个工程的main函数,程序的入口。

tf.app.run()与if __name__==”__main__”:的区别:

if __name__==”__main__”:表示当前的文件被作为一个脚本在shell里执行,而不是被import成module。

tf.app.run()可以通过源代码app.py

def run(main=None, argv=None):

"""Runs the program with an optional 'main' function and 'argv' list."""

f = flags.FLAGS

# Extract the args from the optional `argv` list.

args = argv[1:] if argv else None

# Parse the known flags from that list, or from the command

# line otherwise.

# pylint: disable=protected-access

flags_passthrough = f._parse_flags(args=args)

# pylint: enable=protected-access

main = main or sys.modules['__main__'].main

# Call the main function, passing through any arguments

# to the final program.

sys.exit(main(sys.argv[:1] + flags_passthrough))

其中这一行:

flags_passthrough = f._parse_flags(args=args) 

这确保了通过命令传入的参数都是有效的。例如

python my_model.py --data_dir='...' --max_iteration=10000
main = main or sys.modules['__main__'].main

等号后的第一个main是run(main=NONE,argv=None)里面的第一个参数。sys.modules[‘__main__’]表示当前运行的文件,例如上面的my_model.py。

所以这里有两种情况:

  1. 如果在my_model.py没有main函数,这时需要调用 tf.app.run(my_main_running_function)
  2. 有main函数,直接跑。(这是最常见的)

最后一行代码确保使用已解析正确的参数调用函数。

tf.app()中的flags还有一个重要的功能:全局性。我们可以自定义自己的flags,比如

 tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.') 

然后使用tf.app.run()进行设置。

使用flag的目的:(转自Stack Overflow)

The tf.app.flags module is a functionality provided by Tensorflow to implement command line flags for your Tensorflow program. As an example, the code you came across would do the following:

flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

The first parameter defines the name of the flag while the second defines the default value in case the flag is not specified while executing the file.

So if you run the following:

$ python fully_connected_feed.py --learning_rate 1.00

then the learning rate is set to 1.00 and will remain 0.01 if the flag is not specified.

As mentioned in this article, the docs are probably not present because this might be something that Google requires internally for its developers to use.

Also, as mentioned in the post, there are several advantages of using Tensorflow flags over flag functionality provided by other Python packages such as argparse especially when dealing with Tensorflow models, the most important being that you can supply Tensorflow specific information to the code such as information about which GPU to use.

另外 tf.app.flags.DEFINE_xxx()就是添加命令行的optional argument(可选参数),而tf.app.flags.FLAGS可以从对应的命令行参数取出参数。




上一篇:
下一篇:

guodong

没有评论


你先离开吧:)



发表评论

电子邮件地址不会被公开。 必填项已用*标注