How to remove the space between subplots in matplotlib.pyplot?

I am working on a project in which I need to assemble a grid of 10 rows and 3 columns. Although I was able to make the plots and arrange the subplots, I could not produce a nice plot without white space between the subplots, like the one below from the gridspec documentation: [image without spaces]

I tried the following posts, but I still could not remove the white space completely, as it is in the example image. Can someone please give me some advice? Thank you.

  • Subplots of different sizes
  • How to remove the "empty" space between subplots?

Here is my image: [my image]

Below is my code (the full script is on GitHub). Note: images_2 and images_fool are numpy arrays of flattened images with shape (1032, 10), and delta (w_delta in the code) is an image array of shape (28, 28).

    import matplotlib.pyplot as plt
    from matplotlib import gridspec

    def plot_im(array=None, ind=0):
        """A function to plot one image, given an image matrix (either
        original or fool) and the index of the image in that matrix."""
        img_reshaped = array[ind, :].reshape((28, 28))
        imgplot = plt.imshow(img_reshaped)

    # Output as a grid of 10 rows and 3 cols, with the first column being original,
    # the second being delta and the third column being adversarial
    nrow = 10
    ncol = 3
    n = 0

    fig = plt.figure(figsize=(30, 30))
    gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1])

    for row in range(nrow):
        for col in range(ncol):
            plt.subplot(gs[n])
            if col == 0:
                #plt.subplot(nrow, ncol, n)
                plot_im(array=images_2, ind=row)
            elif col == 1:
                #plt.subplot(nrow, ncol, n)
                plt.imshow(w_delta)
            else:
                #plt.subplot(nrow, ncol, n)
                plot_im(array=images_fool, ind=row)
            n += 1

    plt.tight_layout()
    #plt.show()
    plt.savefig('grid_figure.pdf')
python numpy matplotlib

2 answers

A note up front: if you want full control over the spacing, avoid using plt.tight_layout(), since it will try to rearrange the subplots in your figure so that they are evenly and nicely distributed. This mostly looks good and gives nice results, but it adjusts the spacing on its own terms rather than the way you want it.

The GridSpec example you quote from the Matplotlib example gallery works so well because the aspect ratio of its subplots is not fixed. That is, the subplots simply expand to fill the grid and keep the specified spacing (in this case wspace=0.0, hspace=0.0) regardless of the figure size.
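To illustrate this, here is a minimal sketch in the spirit of the gallery example (not the gallery code itself): plain line plots, which keep the default "auto" aspect, on a GridSpec with wspace=0.0 and hspace=0.0. The function name line_grid and the sine data are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec


def line_grid(nrow=3, ncol=3):
    """Hypothetical helper: a gallery-style grid of line plots without gaps."""
    fig = plt.figure(figsize=(6, 6))
    # wspace/hspace = 0 removes the space between neighbouring subplots
    gs = gridspec.GridSpec(nrow, ncol, wspace=0.0, hspace=0.0)
    x = np.linspace(0, 2 * np.pi, 100)  # made-up example data
    for i in range(nrow):
        for j in range(ncol):
            ax = fig.add_subplot(gs[i, j])
            ax.plot(x, np.sin(x + i + j))
            # Line plots keep the default aspect "auto", so each axes
            # stretches to fill its whole grid cell
            ax.set_xticklabels([])
            ax.set_yticklabels([])
    return fig


if __name__ == "__main__":
    line_grid()
    plt.show()
```

Because nothing fixes the aspect, each axes fills its cell completely and neighbouring subplots touch, whatever figsize you choose.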

In contrast, you are drawing images with imshow, and imshow sets the axes aspect to equal by default (equivalent to ax.set_aspect("equal")). You could of course call set_aspect("auto") on each subplot (and optionally pass wspace=0.0, hspace=0.0 as arguments to GridSpec, as in the gallery example), which would produce a plot without gaps.
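A minimal sketch of the set_aspect("auto") route, assuming random 28x28 arrays as stand-ins for the real images; the function name image_grid_auto is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec


def image_grid_auto(nrow=10, ncol=3, figsize=(6, 10)):
    """Hypothetical helper: image grid without gaps, images stretched to fit."""
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(nrow, ncol, wspace=0.0, hspace=0.0)
    for i in range(nrow):
        for j in range(ncol):
            ax = fig.add_subplot(gs[i, j])
            ax.imshow(np.random.rand(28, 28))  # placeholder image data
            # imshow sets the aspect to equal; switching back to "auto" lets
            # the image fill the whole cell, at the price of non-square pixels
            ax.set_aspect("auto")
            ax.set_xticklabels([])
            ax.set_yticklabels([])
    return fig


if __name__ == "__main__":
    image_grid_auto()
    plt.show()
```

The trade-off is that the pixels are stretched and are no longer square unless the figure proportions happen to match the grid.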

However, when using images, it makes sense to keep the aspect equal, so that every pixel is as wide as it is high and a square array is shown as a square image.
You will then need to play with the figure size and the margins to get the expected result. The figsize argument to figure is the figure (width, height) in inches, and the ratio of these two numbers is what you can play with. The subplot parameters wspace, hspace, top, bottom, left and right can then be tuned manually to give the desired result. Here is an example:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import gridspec

    nrow = 10
    ncol = 3

    fig = plt.figure(figsize=(4, 10))

    gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
                           wspace=0.0, hspace=0.0, top=0.95, bottom=0.05,
                           left=0.17, right=0.845)

    for i in range(nrow):
        for j in range(ncol):
            im = np.random.rand(28, 28)
            ax = plt.subplot(gs[i, j])
            ax.imshow(im)
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    #plt.tight_layout() # do not use this!!
    plt.show()


Edit:
Of course, it is desirable not to have to tweak these parameters by hand. Instead, one can calculate suitable values from the number of rows and columns:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import gridspec

    nrow = 7
    ncol = 7

    fig = plt.figure(figsize=(ncol+1, nrow+1))

    gs = gridspec.GridSpec(nrow, ncol,
                           wspace=0.0, hspace=0.0,
                           top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1),
                           left=0.5/(ncol+1), right=1-0.5/(ncol+1))

    for i in range(nrow):
        for j in range(ncol):
            im = np.random.rand(28, 28)
            ax = plt.subplot(gs[i, j])
            ax.imshow(im)
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    plt.show()
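These numbers work because figsize=(ncol+1, nrow+1) leaves exactly 0.5 inch on every side of an ncol x nrow inch grid, so with wspace = hspace = 0 each cell is 1 x 1 inch and a square image fills it exactly. A small sketch of that arithmetic, with a hypothetical margin_in parameter that is not part of the original answer:

```python
def square_cell_layout(nrow, ncol, margin_in=0.5):
    """Return (figsize, GridSpec margins, cell size in inches).

    margin_in is a hypothetical parameter: the margin in inches on every side
    of an ncol x nrow inch grid. With margin_in=0.5 this reproduces
    figsize=(ncol+1, nrow+1) and top=1-0.5/(nrow+1), etc.
    """
    fig_w = ncol + 2 * margin_in
    fig_h = nrow + 2 * margin_in
    # Margins are given to GridSpec as fractions of the figure size
    params = dict(
        top=1 - margin_in / fig_h,
        bottom=margin_in / fig_h,
        left=margin_in / fig_w,
        right=1 - margin_in / fig_w,
    )
    # Size of one grid cell in inches, assuming wspace = hspace = 0
    cell_w = (params["right"] - params["left"]) * fig_w / ncol
    cell_h = (params["top"] - params["bottom"]) * fig_h / nrow
    return (fig_w, fig_h), params, (cell_w, cell_h)


if __name__ == "__main__":
    figsize, params, cell = square_cell_layout(7, 7)
    print(figsize, params, cell)
```

The returned values can be passed on as plt.figure(figsize=figsize) and gridspec.GridSpec(nrow, ncol, wspace=0.0, hspace=0.0, **params).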

Try adding this line to your code:

    fig.subplots_adjust(wspace=0, hspace=0)

And for each axes object:

    ax.set_xticklabels([])
    ax.set_yticklabels([])
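Putting those lines together: a minimal sketch using plt.subplots (a choice made for this sketch; the question builds its axes with GridSpec), with random 28x28 arrays as placeholder images. The name grid_with_subplots_adjust is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt


def grid_with_subplots_adjust(nrow=10, ncol=3):
    """Hypothetical helper: nrow x ncol image grid with zero padding."""
    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol, nrow))
    # Remove the horizontal and vertical space between the subplots
    fig.subplots_adjust(wspace=0, hspace=0)
    for ax in axes.flat:
        ax.imshow(np.random.rand(28, 28))  # placeholder image data
        # Hide the tick labels so they do not spill into neighbouring images
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    return fig, axes


if __name__ == "__main__":
    grid_with_subplots_adjust()
    plt.show()
```

Because imshow still keeps the equal aspect, some white space can remain unless the figure size roughly matches the proportions of the grid.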