A Complete-ish Guide To Making Scientific Figures for Publication with Python and Matplotlib
(Also Inkscape and MS Word, unfortunately)
As a PhD student in computational neuroscience, I ended up devoting a significant share of my time to annoying little graphic design details in my figures. I like to automate everything, which means that instead of editing things in Inkscape or Adobe Illustrator like a sane person, I wanted to do everything in code, using Python and matplotlib, only resorting to graphic design programs where it was absolutely necessary (like to make a schematic of an ion channel or something). Having written several papers for publication this way, I have acquired a disconcertingly intimate familiarity with the nuances of matplotlib, and after much trial and error, I believe I have settled on a more-or-less optimal pipeline for creating scientific figures with it. While other guides focus on the aesthetics of individual visualizations, I'm going to emphasize how to structure and organize your figure-producing code so that you can easily produce all of the figures for your paper with minimum hassle. The full code for this project can be found here.
This tutorial assumes you know the basics of making plots with matplotlib and that you are comfortable with the concepts of matplotlib figures, axes, and so on. I will also not focus on the libraries built as matplotlib wrappers that add syntactic and graphical sugar, like seaborn, although the principles I present here are generally applicable and should be easy to adapt to other plotting libraries.
Before we get started, I will note that if you want to create a figure that combines Python-generated panels with external illustrations, that works fine with my approach: my code gives you an SVG file at the end that you can easily import as a link or embed into an Inkscape/Illustrator document. Either way, the panel layout of the figure is determined in matplotlib.
Step 1: Defining the Constants: fonts, save folders, etc.
If you're writing a paper, it will probably contain several figures, and often all of them will share properties such as the folder where you save them, the fonts of the axes, etc. So we're going to have a constants.py file that contains all these settings. That way we don't have to redefine these things each time; we can just import the constants file. Here's what my constants.py file looks like. If you specify folders here, make sure you actually create them (Figs and Data).
```python
from matplotlib import pyplot as plt, rcParams

FIG_WIDTH = 7.5  # width of figure in inches

TINY_SIZE = 6
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

plt.rc('font', size=SMALL_SIZE)  # controls default text sizes
plt.rc('axes', titlesize=TINY_SIZE, titleweight='bold')  # fontsize of the axes title
plt.rc('axes', labelsize=SMALL_SIZE)  # fontsize of the x and y labels
plt.rc('xtick', labelsize=TINY_SIZE)  # fontsize of the tick labels
plt.rc('ytick', labelsize=TINY_SIZE)  # fontsize of the tick labels
plt.rc('legend', fontsize=TINY_SIZE)  # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False

PLOT_FOLDER = 'Figs/'  # folder path to save plots
DATA_FOLDER = 'Data/'  # folder path for data files
```
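If you'd rather not create Figs and Data by hand, one option is to have constants.py create them when it is imported. A minimal sketch, assuming the same relative folder names as above; `os.makedirs(..., exist_ok=True)` does nothing if a folder already exists, so it is safe to run every time:

```python
import os

PLOT_FOLDER = 'Figs/'  # folder path to save plots
DATA_FOLDER = 'Data/'  # folder path for data files

# Create both folders (relative to the working directory) if they are missing
for folder in (PLOT_FOLDER, DATA_FOLDER):
    os.makedirs(folder, exist_ok=True)
```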
The choice of these values, especially FIG_WIDTH, is because the figure will eventually appear on a sheet of paper: US Letter is 8.5 × 11 inches, and A4 is 210 × 297 mm (about 8.27 × 11.69 inches). The width is thus chosen to be 7.5 inches to leave room for margins, so all the figures are a priori targeted to fit the page width on either paper size. Note that we don't specify the height of the figure in this file: although figures usually take up the entire width of the page in a paper, they often don't use the full height, so the figure height should be specified independently in each figure.
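To put numbers on those margins, here is a quick back-of-the-envelope check. The paper widths are the standard ones; the 300 DPI figure is the resolution used later in this guide for the figure and the TIFF export:

```python
MM_PER_INCH = 25.4
FIG_WIDTH = 7.5  # inches, as in constants.py

paper_widths = {
    'US Letter': 8.5,          # inches
    'A4': 210 / MM_PER_INCH,   # 210 mm is about 8.27 inches
}

# Left/right margin on each side if the figure is centered on the page
margins = {name: (width - FIG_WIDTH) / 2 for name, width in paper_widths.items()}
for name, margin in margins.items():
    print(f"{name}: {margin:.2f} in per side")  # US Letter: 0.50, A4: 0.38

# Width in pixels when the figure is rasterized at 300 DPI
DPI = 300
pixel_width = int(FIG_WIDTH * DPI)
print(pixel_width)  # 2250
```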
There are also some pre-built stylesheet libraries for scientific figures, such as this one; I have not personally used it, but maybe I should.
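One side effect of the constants.py approach is that importing it changes matplotlib's settings globally for the whole session. If you'd prefer to keep the settings in one place but apply them only where you want them, `plt.style.context` accepts a plain dict of rcParams, much like a style sheet, and restores the previous settings on exit. A minimal sketch (the `PAPER_STYLE` name is just for illustration):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

# The same kind of settings as constants.py, kept in a dict instead of applied globally
PAPER_STYLE = {
    'font.size': 8,
    'axes.titlesize': 6,
    'axes.titleweight': 'bold',
    'axes.labelsize': 8,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 6,
    'figure.titlesize': 12,
    'axes.spines.top': False,
    'axes.spines.right': False,
}

before = plt.rcParams['axes.spines.top']
with plt.style.context(PAPER_STYLE):
    # Figures created inside the block pick up PAPER_STYLE
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    inside = plt.rcParams['axes.spines.top']
after = plt.rcParams['axes.spines.top']  # restored to the value it had before
print(before, inside, after)
plt.close(fig)
```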
Step 2: Plotting helpers
There are several ancillary tasks that are generally associated with making figures, such as labeling the panels or saving the figures; these may vary depending on your use case. For these tasks, we'll make a new file called plot_helpers.py. We'll also include a function to create a blank panel, in case you want to hand-draw that panel later.
```python
import matplotlib.transforms as transforms
import constants


def label_panels_mosaic(fig, axes, xloc=0, yloc=1.0, size=constants.BIGGER_SIZE):
    """
    Labels the panels in a mosaic plot.
    Parameters:
    - fig: The figure object.
    - axes: A dictionary of axes objects representing the panels.
    - xloc: The x-coordinate for the label position, in axes coordinates (default: 0).
    - yloc: The y-coordinate for the label position, in axes coordinates (default: 1.0).
    - size: The font size of the labels (default: constants.BIGGER_SIZE).
    """
    for key in axes.keys():
        ax = axes[key]
        # Offset the label a fixed physical distance: 20 points left, 7 points up (72 points = 1 inch)
        trans = transforms.ScaledTranslation(-20/72, 7/72, fig.dpi_scale_trans)
        ax.text(xloc, yloc, key, transform=ax.transAxes + trans,
                fontsize=size, va='bottom')


def make_blank_panel(ax):
    """
    Makes a panel blank by turning off the axis and setting aspect ratio to 'auto'.
    Parameters:
    - ax: The axis object representing the panel.
    Returns:
    The modified axis object.
    """
    ax.axis('off')
    ax.set_aspect('auto')
    return ax


def save_figure(figure, fignum, folder=constants.PLOT_FOLDER):
    """
    Saves the figure as an SVG file.
    Parameters:
    - figure: The figure object to be saved.
    - fignum: The figure number or name.
    - folder: The folder path to save the figure (default: constants.PLOT_FOLDER).
    """
    figure.savefig(folder + str(fignum) + '.svg', dpi=figure.dpi)
```
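The offset in label_panels_mosaic is specified in physical units: `fig.dpi_scale_trans` works in inches, so `ScaledTranslation(-20/72, 7/72, ...)` moves each label 20 points left and 7 points up from the panel's top-left corner, no matter how big the panel is. A self-contained sketch of the same idea on a two-panel mosaic (the layout and the printed check are only for illustration):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms

fig, axes = plt.subplot_mosaic([['A', 'B']], figsize=(4, 2), dpi=100)

# 20 points left and 7 points up; dpi_scale_trans converts inches to pixels
offset = transforms.ScaledTranslation(-20/72, 7/72, fig.dpi_scale_trans)

labels = {}
for key, ax in axes.items():
    # (0, 1) in axes coordinates is the top-left corner of the panel
    labels[key] = ax.text(0, 1.0, key, transform=ax.transAxes + offset,
                          fontsize=12, va='bottom')

# In pixel space the shift is (-20/72 * dpi, 7/72 * dpi), independent of panel size
corner = axes['A'].transAxes.transform((0, 1))
shifted = (axes['A'].transAxes + offset).transform((0, 1))
dx, dy = shifted - corner
print(dx, dy)
```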
(Check which figure file formats your journal prefers. You can easily modify the save_figure function to use a different format by choosing a different file extension, e.g. '.tiff' instead of '.svg'. The advantage of saving as SVG is that each object in the saved figure stays editable if you want to edit the figure externally.)
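If you need several formats at once (say, SVG for editing plus a raster file for submission), one option is a variant of save_figure that loops over extensions. This is a sketch, not the function from plot_helpers.py: the name `save_figure_formats` and its `formats` parameter are made up for illustration, and it uses `os.path.join` so the folder doesn't need a trailing slash:

```python
import os
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt


def save_figure_formats(figure, fignum, folder='Figs', formats=('svg', 'pdf', 'png'), dpi=300):
    """Hypothetical helper: save `figure` once per extension and return the paths written."""
    os.makedirs(folder, exist_ok=True)  # create the folder if it doesn't exist yet
    paths = []
    for ext in formats:
        path = os.path.join(folder, f"{fignum}.{ext}")
        # dpi sets the resolution of raster output and of any raster images
        # embedded in vector files; the physical page size stays the same
        figure.savefig(path, dpi=dpi)
        paths.append(path)
    return paths


# Usage: a small stand-in figure saved as Figs/1.svg, Figs/1.pdf and Figs/1.png
fig, ax = plt.subplots(figsize=(7.5, 5))
ax.plot([0, 1], [0, 1])
written = save_figure_formats(fig, 1)
print(written)
```

Passing formats=('svg', 'tiff') should also work, since matplotlib writes TIFF files through Pillow, which it depends on.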
Step 3: Global plotting functions
There will often be certain types of plots that are repeated across multiple panels and/or figures in your project, where only the data changes between plots, not the type of plot. For these plots, instead of repeating all the plotting code in each figure file, we'll have one generic function for each kind of plot. We'll put all these generic plotting functions in a new file called plotting_functions.py (creative, I know). For example, I often find myself making annotated heatmaps where red means positive and blue means negative, so I'll add a function to make that kind of heatmap to this file.
```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


def create_annotated_posneg_heatmap(ax, data):
    """
    Creates an annotated positive-negative heatmap.
    Parameters:
    - ax: The axis object to plot the heatmap.
    - data: The 2D array of data for the heatmap.
    Returns:
    The modified axis object and the colorbar object.
    """
    # Map 0 to the white center of 'bwr': negative values blue, positive values red
    norm = mcolors.TwoSlopeNorm(vcenter=0)
    im = ax.imshow(data, cmap='bwr', norm=norm)
    cbar = ax.figure.colorbar(im, ax=ax)
    # Write each cell's value at the center of the cell
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            ax.text(j, i, f"{data[i, j]:.2f}", ha="center", va="center", color="k")
    ax.set_aspect('auto')
    return ax, cbar
```
We'll test this plot to make sure it works (this goes at the bottom of plotting_functions.py, so it only runs when you execute that file directly):
```python
if __name__ == "__main__":
    data = np.random.randn(5, 5)  # Random data for demonstration
    print(data)
    fig, ax = plt.subplots()
    create_annotated_posneg_heatmap(ax, data)
    plt.show()
```
And we get:
Looks decent. Note that there are no axis labels or whatnot here. This is intentional. Because this is a generic function, we’ll add those cosmetics to the axis later.
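One subtlety of `TwoSlopeNorm`: it scales the negative and positive sides independently, so the most negative value is always full blue and the most positive value full red, even when one is much larger in magnitude than the other. If you want equal magnitudes to get equally saturated colors, matplotlib's `CenteredNorm` (available since matplotlib 3.4) uses a range that is symmetric around the center instead. A quick comparison on a small made-up array:

```python
import numpy as np
import matplotlib.colors as mcolors

values = np.array([-2.0, 0.0, 4.0, 8.0])

# Each side scaled separately: -2 -> 0.0, 0 -> 0.5, 4 -> 0.75, 8 -> 1.0
two_slope = mcolors.TwoSlopeNorm(vcenter=0)
two_slope_out = np.asarray(two_slope(values))
print(two_slope_out)

# Symmetric range [-8, 8] around 0: -2 -> 0.375, i.e. a paler blue
centered = mcolors.CenteredNorm(vcenter=0)
centered_out = np.asarray(centered(values))
print(centered_out)
```

To use it in create_annotated_posneg_heatmap you would swap the norm line; which behavior is right depends on whether you want color saturation to reflect magnitude across the sign boundary.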
Step 4: Each figure gets a file with a class
In most cases, the easiest way to structure the code for producing multiple figures is just to have a separate code file for each figure, i.e. fig1.py, fig2.py, etc. (It’s also helpful to have an informative title in the file name, e.g. fig1_demo.py, because you probably won’t remember what’s in each figure).
Inside this file, there should be a class that contains all the data you will be plotting as well as the plotting functions. Depending on the use case, this class may also have functions to generate and analyze the data, or, if you have a complicated data analysis pipeline, you can load the data to be plotted from an external source. Either way, the important thing is that all of the data to be plotted in your figure should be stored as instance variables of the class. This makes life much easier because you don't have to worry about which data you're passing to which plotting function; it's all in the object.
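For the external-pipeline case, the same pattern still applies: `__init__` loads everything into instance variables, and the panel functions never touch the disk. A minimal sketch, assuming the pipeline saved its arrays with `np.savez` into the Data folder; the file name `fig2_data.npz`, the class `fig2_demo`, and the array names are all hypothetical:

```python
import os
import numpy as np

DATA_FOLDER = 'Data/'  # same value as in constants.py


class fig2_demo:
    def __init__(self, data_file='fig2_data.npz'):
        # Load every array the figure needs once, up front
        with np.load(os.path.join(DATA_FOLDER, data_file)) as data:
            self.time = data['time']
            self.voltage = data['voltage']

    def lineplot_voltage(self, ax):
        ax.plot(self.time, self.voltage)
        ax.set_xlabel('Time (ms)')
        ax.set_ylabel('Voltage (mV)')


# Stand-in for the analysis pipeline: write a small data file, then build the figure object
os.makedirs(DATA_FOLDER, exist_ok=True)
t = np.linspace(0, 100, 201)
np.savez(os.path.join(DATA_FOLDER, 'fig2_data.npz'), time=t, voltage=-65 + 5 * np.sin(t / 10))
fig2 = fig2_demo()
```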
First we’ll generate some data for our plots.
```python
import numpy as np
import matplotlib.pyplot as plt
import constants
import plot_helpers as ph
import plotting_functions as pf

np.random.seed(420)  # Set the random seed to a specific value for reproducibility


class fig1_demo:
    def __init__(self):
        self.unicorn_data = np.random.randint(0, 10, size=50)
        self.rainbow_data = np.random.randint(20, 30, size=50)
        self.ice_cream_flavors = ['Van.', 'Choc.', 'Straw.']
        self.penguin_data = np.random.randint(0, 50, size=len(self.ice_cream_flavors))
        self.horn_length_data = np.random.uniform(0, 10, size=100)
        self.magic_power_data = (3 + 0.2*self.horn_length_data) + np.random.randn(len(self.horn_length_data))
        self.months = ['January', 'February', 'March']
        self.wombat_species = ['Common', 'Hairy', 'Northern']
        self.happiness_data = np.random.randint(-10, 10, size=(len(self.months), len(self.wombat_species)))
        self.widgets = np.random.randint(0, 100, size=100)
        self.laughs_per_minute = np.random.gamma(5, scale=2, size=1000)
```
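A caveat with the module-level `np.random.seed(420)`: it runs once, at import time, so creating a second fig1_demo object (or drawing other random numbers first) gives different data. If you want every instance to regenerate identical data, one option is a per-instance NumPy Generator seeded in `__init__`. A sketch with a hypothetical stand-in class; note that `default_rng(420)` produces different numbers than `np.random.seed(420)`, so your figure's data would change:

```python
import numpy as np


class SeededDemo:
    """Hypothetical stand-in for fig1_demo that owns its random generator."""

    def __init__(self, seed=420):
        # A private Generator: the data no longer depends on any other use of np.random
        rng = np.random.default_rng(seed)
        self.unicorn_data = rng.integers(0, 10, size=50)
        self.horn_length_data = rng.uniform(0, 10, size=100)


first = SeededDemo()
np.random.rand(10)  # unrelated draws from the global random state
second = SeededDemo()
print(np.array_equal(first.unicorn_data, second.unicorn_data))  # True
```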
Step 5: Each figure panel gets its own function inside the class
Now we'll add the plotting functions inside the class. Each panel in the figure gets a function, which is responsible for the plotting and decoration of that panel. Each panel function takes an axis as a parameter; in other words, we don't create the panel itself in these functions, but pass in the panel from the figure (see next step).
```python
    def lineplot_unicorns_vs_rainbows(self, ax):
        x = np.arange(len(self.unicorn_data))
        ax.plot(x, self.unicorn_data, label='Unicorns')
        ax.plot(x, self.rainbow_data, label='Rainbows')
        ax.set_xlabel('Time')
        ax.set_ylabel('Count')
        ax.set_title('Unicorns vs. Rainbows')
        ax.legend(loc='center left', bbox_to_anchor=(0, 1.5))

    def barplot_ice_cream_flavors(self, ax):
        colors = plt.cm.tab20(np.arange(len(self.ice_cream_flavors)))
        ax.bar(self.ice_cream_flavors, self.penguin_data, color=colors)
        ax.set_xlabel('Ice Cream Flavors')
        ax.set_ylabel('Penguin ct.')
        ax.set_title('Flavors Consumed by Penguins')

    def scatterplot_horn_length_vs_magic_power(self, ax):
        ax.scatter(self.horn_length_data, self.magic_power_data, c='hotpink')
        ax.set_xlabel("Unicorn's Horn Length (cm)")
        ax.set_ylabel('Magic Power')
        ax.set_title("Horn Length vs. Magic Power")

    def annotated_heatmap_wombat_happiness(self, ax):
        ax, cbar = pf.create_annotated_posneg_heatmap(ax, self.happiness_data)
        ax.set_xticks(np.arange(len(self.wombat_species)))
        ax.set_yticks(np.arange(len(self.months)))
        ax.set_xticklabels(self.wombat_species)
        ax.set_yticklabels(self.months)
        ax.set_xlabel('Wombat Species')
        ax.set_ylabel('Months')
        ax.set_title('Wombat Happiness')
        cbar.set_label('Happiness')
        return ax

    def lineplot_widgets(self, ax):
        ax.plot(self.widgets)
        ax.set_xlabel('Time')
        ax.set_ylabel('Number of Widgets')
        ax.set_title('Widgets over Time')

    def histogram_laughs(self, ax):
        # Plot the histogram on the specified axis
        ax.hist(self.laughs_per_minute, bins=30, color='orange', edgecolor='black')
        # Set labels and title
        ax.set_xlabel('Laughs per Minute')
        ax.set_ylabel('Frequency')
        ax.set_title('Laughs per Minute')
```
Note that the annotated_heatmap_wombat_happiness function calls the create_annotated_posneg_heatmap function from the plotting functions file (imported as 'pf' here). The other plots are single-use only, so we just create the full plots in these functions.
Step 6: Create your figure using subplot_mosaic
Now we're ready to create the figure (or 'dashboard'). We'll use subplot_mosaic, which is pretty much the only figure layout scheme you should ever use unless you're prototyping. This layout is based on a 3×3 grid, but as you can see in the code, two panels (F and G) each take up two cells of that grid. The labels we give in the mosaic become the labels of the panels in the final figure.
```python
    def plot_dashboard(self):
        """
        Plots the whole figure.
        """
        # This code defines a 2D list called mosaic that represents
        # the layout of the subplots in the figure.
        # Each element in the list corresponds to a subplot position.
        # Note that F takes up two rows and one column, and G takes up one row and two columns.
        mosaic = [['A', 'B', 'C'],
                  ['D', 'E', 'F'],
                  ['G', 'G', 'F']]
        # Create the figure and axes objects using the subplot_mosaic function
        fig, ax_dict = plt.subplot_mosaic(mosaic,  # Specify the layout of subplots using the mosaic parameter
                                          figsize=(constants.FIG_WIDTH, 5),  # Set the size of the figure in inches
                                          dpi=300,  # Set the resolution of the figure in dots per inch
                                          constrained_layout=True,  # Enable constrained layout for automatic adjustment
                                          gridspec_kw={'height_ratios': [1, 1, 1.5],  # Set the relative heights of the rows
                                                       'width_ratios': [1.5, 1.5, 1]})  # Set the relative widths of the columns
```
Now, still inside the function, we actually populate the panels of the figure using the functions we created for each panel earlier. Note that we use our make_blank_panel function from plot_helpers ('ph') to make panel C blank, and our label_panels_mosaic function to label the panels.
```python
        self.lineplot_unicorns_vs_rainbows(ax_dict['A'])
        self.barplot_ice_cream_flavors(ax_dict['B'])
        ph.make_blank_panel(ax_dict['C'])  # Panel C is left blank; we can fill it in by hand later
        self.scatterplot_horn_length_vs_magic_power(ax_dict['D'])
        self.annotated_heatmap_wombat_happiness(ax_dict['E'])
        self.histogram_laughs(ax_dict['F'])
        self.lineplot_widgets(ax_dict['G'])
        ph.label_panels_mosaic(fig, ax_dict, size=14)
        return fig, ax_dict
```
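In recent matplotlib versions, subplot_mosaic also accepts the layout as a single string with rows separated by semicolons, and a '.' marks a cell that gets no axes at all. A quick sketch that rebuilds the layout above that way and checks which grid cells the spanning panels occupy (the `rowspan`/`colspan` lookups are only there to confirm the layout):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

# Same layout as plot_dashboard: one character per cell, ';' between rows
fig, ax_dict = plt.subplot_mosaic('ABC;DEF;GGF', figsize=(7.5, 5))
print(sorted(ax_dict))  # ['A', 'B', 'C', 'D', 'E', 'F', 'G']

# 0-based grid rows/columns covered by the spanning panels
f_spec = ax_dict['F'].get_subplotspec()
g_spec = ax_dict['G'].get_subplotspec()
print(f_spec.rowspan, f_spec.colspan)  # range(1, 3) range(2, 3)
print(g_spec.rowspan, g_spec.colspan)  # range(2, 3) range(0, 2)

# '.' leaves a cell empty: no axes is created there, so there is no 'C' key
fig_empty, ax_dict_empty = plt.subplot_mosaic('AB.;DEF;GGF')
print('.' in ax_dict_empty, len(ax_dict_empty))  # False 6
```

The trade-off versus make_blank_panel is that an empty '.' cell has no axes, so label_panels_mosaic would not put a "C" label there for your hand-drawn panel.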
Step 7: Make and save the figure
Now we just need to instantiate the figure object, make the figure, and save it. That takes three lines of code:
```python
fig1 = fig1_demo()  # create the figure object
fig, axes = fig1.plot_dashboard()  # create the figure
ph.save_figure(fig, 1)  # and save it
```
And let’s take a look at the figure in our Figs folder.
Pretty nice, huh? Notice the automatic panel labeling. The result should be the actual size of the figure as it will appear in the paper. To prove this to yourself, take a sheet of printer paper (Letter or A4) and see if it covers your figure (open the figure in an external editor and make sure the zoom is set to 100%). If you have nothing else to add to the figure, congrats, you're done: you have a paper-ready figure! See Step 9 for how to add it to a Word document.
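If you'd rather check the physical size programmatically, matplotlib's SVG backend records the figure size in points (1/72 inch) in the width and height attributes of the root svg element. A sketch, assuming a 7.5 × 5 inch figure like the one above, saved without bbox_inches='tight' (which would crop the figure and change its size), just as save_figure does:

```python
import io
import xml.etree.ElementTree as ET
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(7.5, 5), constrained_layout=True)
ax.plot([0, 1], [0, 1])

# Save to memory instead of disk and parse the resulting SVG
buf = io.BytesIO()
fig.savefig(buf, format='svg')
root = ET.fromstring(buf.getvalue())

# The attributes look like '540pt'; 72 points = 1 inch
width_in = float(root.get('width').replace('pt', '')) / 72
height_in = float(root.get('height').replace('pt', '')) / 72
print(width_in, height_in)  # 7.5 5.0
```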
However, if you want to hand-draw a panel or otherwise modify the figure, see the next step.
Step 8 (Optional): Add a panel by hand in Inkscape
Earlier, we left panel C blank, so we can draw a diagram by hand in Inkscape. Open a new file in Inkscape and drag the .svg file you created with the code onto the document. You should get a popup like this.
I usually choose the default option checked in the screenshot. This imports the SVG file as a link, which means you can't directly edit it (for that, choose the top option). But the advantage of a link is that when you update the code for the figure (you will, I promise you), you can re-run the code and the image in the Inkscape file will update automatically.
Now the figure is on the page, but the page is too big.
To fix this, go to File → Document Properties. In the dialog box that pops up, click the button that says 'Resize to content':
Much better:
Make sure that the Z (zoom) value in the bottom-right corner is at 100%, which shows the figure at its real size on screen (again, you can test with a sheet of paper).
Now I’ll add a hand-drawn schematic to panel C using Inkscape’s shape tools and my finely-honed graphic design abilities.
Truly a work of art. Because we linked the Python figure instead of embedding it, we'll need to export the figure (we'll use TIFF; journals like it for some reason). Go to File → Export, and the export window should pop up on the right.
Note that the DPI is set to 300. Select the export location, file name, and file type (TIFF) at the bottom and click Export; this might take a few seconds.
Step 9: Add your figure to a paper
Now we need to put our figure in a paper. Open your paper or create an MS Word document (sorry, LaTeX fans). The easiest way to place the figure is to make a 2×1 table (two rows, one column): the figure goes in the top row, and the caption goes in the bottom row.
Click on the top row of the table and go to Insert → Picture → From this device. Navigate to the figure file (either the .svg from Step 7 or the .tiff from Step 8). Now just add your caption in the bottom row and publish your amazing science!
(Again, the full code for this project can be found here.)