Tensorflow ——分布式数据异步并行代码(单机多卡、多机多卡)

tech2023-05-26  103

Tensorflow ——分布式数据异步并行代码(单机多卡、多机多卡)

在分布式计算中创建Session需要用到MonitoredTrainingSession,区别于普通Session最主要的参数是is_chief。

Session( target='', graph=None, config=None ) MonitoredTrainingSession( master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=USE_DEFAULT, save_summaries_secs=USE_DEFAULT, config=None, stop_grace_period_secs=120, log_step_count_steps=100 )

is_chief:用于分布式系统中,用于判断该系统是否是chief,如果为True,它将负责初始化并恢复底层TensorFlow会话。如果为False,它将等待chief初始化或恢复TensorFlow会话。当杀死chief除外的其中一个进程,再开启,会接着以当前train_step开始训练,而不是重头训练,而普通Session会从step=0重新开始。

Tensorflow分布式数据异步(单机多卡、多机多卡)代码框架:

"""分布式训练""" ps_hosts = ["127.0.0.1:2222", "127.0.0.1:2223"] worker_hosts = ["127.0.0.1:2224", "127.0.0.1:2225"] cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) flags = tf.app.flags flags.DEFINE_string('job_name', 'ps', "One of 'ps', 'worker'") flags.DEFINE_string('task_index', 0, "Index of task within the job") FLAGS = flags.FLAGS if __name__ == '__main__': server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if flags.job_name == "ps": server.join() elif flags.job_name == "worker": with tf.device("/job:worker/task:" + str(FLAGS.task_index)): print('build model') config = tf.ConfigProto( allow_soft_placement=True ) hooks = [tf.train.StopAtStepHook(last_step=1000000)] with tf.train.MonitoredTrainingSession(master="grpc://" + worker_hosts[FLAGS.task_index], is_chief=(FLAGS.task_index == 0), checkpoint_dir="", summary_dir="", save_checkpoint_steps=0, hooks=hooks) as sess: while not sess.should_stop(): # 不需要初始化 print('train model') # 需要执行四次,并指定gpu # python distributed_tf.py --job_name=ps --task_index=0 # python distributed_tf.py --job_name=ps --task_index=1 # python distributed_tf.py --job_name=worker --task_index=0 # python distributed_tf.py --job_name=worker --task_index=1 这样的方式每个worker运行可能会获取相重复的一批数据,训练时损失会抖动,建议学习率取小。

测试

第一个worker

第二个worker

可以看出两个worker的train_step是不重复的,值得注意的是,train_step不能直接赋值,而是要用全局global_step:

self.global_step = tf.train.get_or_create_global_step() train_steps = sess.run([self.global_step],feed_dict=feed_dict)
最新回复(0)