I implemented a `dimshuffle` for TensorFlow in our RETURNN framework (here). The code looks like this:
def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
    """
    Inserts a new size-1 (broadcast) axis into x for each entry of axes.

    :param tf.Tensor x:
    :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
    :param str name: scope name
    :return: y where we have a new broadcast axis for each axis in axes
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        # Insert in ascending order: inserting at a larger index never shifts an
        # axis at a smaller index, so the size-1 axes already inserted stay put.
        # NOTE(review): this only works for non-negative axes. With negative axes,
        # ascending order can misplace them (e.g. [-2, -1] on a rank-2 input yields
        # shape (a, 1, b, 1), where axis -2 is b) — confirm callers pass axes >= 0.
        for i in sorted(axes):
            x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
        return x


def dimshuffle(x, axes, name="dimshuffle"):
    """
    Like Theano's dimshuffle.
    Combines tf.transpose, tf.expand_dims and tf.squeeze.

    :param tf.Tensor x:
    :param list[int|str]|tuple[int|str] axes: each entry is either an int (an axis of x)
      or the string "x" (a new broadcast axis at that output position), as in Theano's dimshuffle
    :param str name: scope name
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        # Only ints and the marker "x" are accepted as entries of axes.
        assert all([i == "x" or isinstance(i, int) for i in axes])
        # The input axes referenced by axes, in the requested output order.
        real_axes = [i for i in axes if isinstance(i, int)]
        # Output positions that are marked "x" (new broadcast axes).
        bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
        if x.get_shape().ndims is None:
            # Static rank unknown: reshape x to the dynamic sizes of its first
            # max(real_axes) + 1 dims. Presumably this gives x a statically known
            # rank for the steps that follow.
            # NOTE(review): max() raises ValueError if axes contains no int entries.
            x_shape = tf.shape(x)
            x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)])
        # NOTE(review): the remainder of this function is truncated in this excerpt.
Albert
source share