How to update BatchNorm variable on multiple GPUs in Tensorflow

I have a network that trains batch normalization (BN) layers. The batch size is 16, so I have to use multiple GPUs. I followed the Inception v3 example, which can be summarized as

# Multi-GPU ("tower") training graph, condensed from the Inception v3 example.
# `...` marks code that was elided from this excerpt; names such as `images`,
# `labels`, `opt`, `loss`, `tower_grads` and `global_step` are defined there.
with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Split the input batch along axis 0 into one shard per GPU.
    images_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=images)
    labels_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=labels)
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
                ...
                # Reuse variables for the next tower.
                # (The statement this comment annotated appears to have been
                # dropped from the excerpt.)

                # Fetch only the batch-norm update ops filtered by this
                # tower's name scope. The name is rebound on every iteration,
                # so after the loop it refers to the last tower's ops only.
                batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION, scope)

                # Per-tower gradients, collected for averaging below.
                grads = opt.compute_gradients(loss)
                tower_grads.append(grads)

    # Combine the per-tower gradients and apply them once.
    grads = _average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Maintain exponential moving averages of the trainable variables and of
    # the variables already registered as moving averages.
    variable_averages = tf.train.ExponentialMovingAverage(
        inception.MOVING_AVERAGE_DECAY, global_step)
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # A single train op that applies gradients, updates the moving averages
    # and runs the collected batch-norm updates.
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

Unfortunately, that example uses the slim library for the BN layer, whereas I use the standard BN layer tf.contrib.layers.batch_norm

 def _batch_norm(self, x, name, is_training, activation_fn, trainable=False):
     """Apply `tf.contrib.layers.batch_norm` to `x` under `<name>/BatchNorm`.

     Args:
         x: Input tensor handed to the batch-norm layer.
         name: Prefix of the variable scope; '/BatchNorm' is appended to it.
         is_training: Forwarded to the layer's `is_training` argument.
         activation_fn: Forwarded to the layer's `activation_fn` argument.
         trainable: Forwarded to the layer's `trainable` argument
             (defaults to False).

     Returns:
         The output of the batch-norm layer.
     """
     bn_scope_name = name + '/BatchNorm'
     with tf.variable_scope(bn_scope_name) as bn_scope:
         # The layer is always built with a learnable scale (scale=True).
         return tf.contrib.layers.batch_norm(
             x,
             scale=True,
             activation_fn=activation_fn,
             is_training=is_training,
             trainable=trainable,
             scope=bn_scope,
         )

To collect the moving_mean and moving_variance updates, I used tf.GraphKeys.UPDATE_OPS

 # Fetch every op registered in the graph's UPDATE_OPS collection (no scope
 # filter is given, so the whole collection is returned).
 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 # Create the grouped train op under control dependencies on those ops, so
 # that running `self.train_op` also runs the update ops.
 with tf.control_dependencies(update_ops):
     self.train_op = tf.group(train_op_conv, train_op_fc)

Finally, borrowing the idea from Inception v3, BN on multiple GPUs can be implemented as

 # Build a multi-GPU ("tower") training graph: one replica of the network per
 # GPU with shared variables, per-tower gradients averaged and applied once.
 # `loss`, `l2_losses`, `encoder_trainable` and `decoder_trainable` come from
 # code elided from this excerpt.

 # Split the input batch along axis 0 into one shard per GPU.
 split_image_batch = tf.split(self.image_batch, self.conf.num_gpus, 0)
 split_label_batch = tf.split(self.label_batch, self.conf.num_gpus, 0)

 global_step = tf.train.get_or_create_global_step()
 opt = tf.train.MomentumOptimizer(self.learning_rate, self.conf.momentum)

 tower_grads_encoder = []
 tower_grads_decoder = []
 update_ops = []

 with tf.variable_scope(tf.get_variable_scope()):
     for i in range(self.conf.num_gpus):
         with tf.device('/gpu:%d' % i):
             net = Resnet(split_image_batch[i], self.conf.num_classes)  # Build BN layer

             # Loss function
             self.reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

             # Reuse variables for the next GPU.
             tf.get_variable_scope().reuse_variables()

             # Fixed: the call previously read `update_ops.extend)tf...`,
             # which is a syntax error.
             # NOTE(review): the collection is read without a scope filter,
             # so each iteration also re-adds the ops registered by earlier
             # towers and `update_ops` ends up holding duplicates. Confirm
             # whether per-tower or single-tower updates are intended.
             update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

             # Compute grads
             grads_encoder = opt.compute_gradients(self.reduced_loss, var_list=encoder_trainable)
             grads_decoder = opt.compute_gradients(self.reduced_loss, var_list=decoder_trainable)
             tower_grads_encoder.append(grads_encoder)
             tower_grads_decoder.append(grads_decoder)

 # Average the gradients over all towers.
 grads_encoder = self._average_gradients(tower_grads_encoder)
 grads_decoder = self._average_gradients(tower_grads_decoder)

 # Update params
 # NOTE(review): both calls receive `global_step`, so it presumably advances
 # twice per training step -- confirm this is intended.
 train_op_conv = opt.apply_gradients(grads_encoder, global_step=global_step)
 train_op_fc = opt.apply_gradients(grads_decoder, global_step=global_step)

 # Maintain exponential moving averages of the trainable variables.
 variable_averages = tf.train.ExponentialMovingAverage(self.conf.MOVING_AVERAGE_DECAY, global_step)
 variables_averages_op = variable_averages.apply(tf.trainable_variables())

 # Create the final train op under control dependencies on the collected
 # update ops, so that running it also runs them.
 with tf.control_dependencies(update_ops):
     self.train_op = tf.group(train_op_conv, train_op_fc, variables_averages_op)

Although the code runs without errors, the performance is very poor. It seems that I am not collecting the BN parameters correctly. Could you take a look at my code and give me some direction on training BN layers on multiple GPUs? Thanks

0
deep-learning machine-learning tensorflow
source share
1 answer

I suspect that performance problems are related to the fact that you are doing several variable updates for each step (from each batch norm in each tower).

Is there a reason you need the batch norm updates from every GPU? I recommend using the statistics from only one tower to update the batch norm moving averages: unless the way your batch is split across GPUs is biased (which would cause other problems), the result should turn out to be the same.

If you limit the batch norm updates to those from a single tower, you reduce the number of variable updates by a factor of num_gpus.

0
source share

All Articles