Learning Matplotlib - Lessons from a ggplot user

Python
Matplotlib
Author

Benjamin Wee

Published

April 6, 2020

You never forget your first

I love R. I wouldn’t be where I am in my career without the tools of the Tidyverse and kindness of the Rstats and Rladies communities. It provided me a platform to entering data science, exposing me to the world of software engineering and development along the way. Choosing to use R over Stata for my Economics Honours thesis was probably the best choice I made (hides behind shield).

I picked up new programming languages over the past two years working as a data scientist. The choice of programming languages at work are determined by the design of the tech stack and the analytics teams. We mostly use PySpark at work for exploratory data analysis (EDA) and “feature engineering”. So far, I’ve managed. I picked up standard python libraries like Pandas and NumPy, machine learning libaries (MLlib) in Pyspark, and wrangled many tables in SQL.

But the one thing I couldn’t put myself through is learning how to plot in Python. I couldn’t get over the ease and beauty of building visualisations using ggplot2. I would do all my work in SQL and pyspark, but I would always finish off my EDA using ggplot.

Getting over the curve

I’ve been finally pushing myself to get acquainted with Python plotting libraries. I figured that it wouldn’t hurt being bilingual in the two main data science languages. Plus all the developers at work are all python useres (pythonistas?), so it would be jointly beneficial to communicate using the programming lingua franca.

I found it difficult navigating the vast array of packages Python has for plotting. Matplotlib, seaborn, altair, bokeh, etc. But I realised that most of the higher level modules are built on top of matplotlib. So while you may plot using a higher level package like seaborn, a lot of the tweaking is done using matplotlib functions.

Learning matplotlib is difficult. The first thing which took me a while to wrap my head around was how I needed the data to be structured to build the plots I wanted. Matplotlib requires data in a wide format, as opposed to a long (tidy) format when reading in components of the data. Seaborn was much easier to grasp as it requires long/tidy data structures when plotting. This feels more intuitive to me as a ggplot user.

Replicating KH’s COVID Small Multiples in Matplotlib

I came across Kieran Healy’s blog post a couple of weeks ago where he plotted cumulative reported cases of the COVID-19 virus for the top 50 countries. I thought it would be a good exercise for me to try to replicate this in Matplotlib.

I’ve always admired this style of visualisation It preserves the variation of all the countries while highlighting the key country you want to focus on for comparison and trend analysis. I use similar plots at work all the time to assist in presentations - a clean and compact way of presenting rich datasets.

I saved myself the effort of recreating the dataset from scratch by using Kieran’s R code found in his COVID Github Repository. The COVID data comes from the European Centre for Disease Prevention and Control (ECDC). Note that I have not checked the validity or accuracy of the data source or scraping process. This is just an exercise in creating a similar figure in Python.

Processing the data

Most of the data preprocessing has been completed in the R script. However, there is still a bit of work requierd to get the necessary datasets and objects prepared for plotting. The exact data and scripts I used here

First, load in the data, libraries and clean up some of the country labels.

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as tick

# Setting some Matplotlib parameters
mpl.rcParams['figure.dpi']= 300
# plt.style.use('seaborn-talk')

# Replace the some of the country names with shorter labels
data = pd.read_csv('cov_case_curve2.csv')\
         .replace({'cname' : 
                   {'United States' : 'USA',
                   'Iran, Islamic Republic of' : 'Iran',
                   'Korea, Republic of' : 'South Korea', 
                   'United Kingdom' : 'UK'}})\
        .drop('Unnamed: 0', axis=1)

# Show first 5 rows of the main dataset
data.sort_values("date", ascending = False)\
    .head()\
    .style
  date cname iso3 cases deaths cu_cases cu_deaths days_elapsed end_label
2205 2020-04-06 Viet Nam VNM 1 0 241 0 14 Viet Nam
2146 2020-04-06 Jordan JOR 22 0 345 5 13 Jordan
2118 2020-04-06 Denmark DNK 292 18 4369 179 27 Denmark
2119 2020-04-06 Dominican Republic DOM 167 5 1745 82 15 Dominican Republic
2120 2020-04-06 Ecuador ECU 181 8 3646 180 19 Ecuador

Next, we subset our data so it contains the top 50 cumulative cases as of April 6th (df_50) and pivot them into a ‘wide’ format. The index of the dataframes will be the days_elapsed variable (x-axis) and the columns will contain cumulative cases for each country (y-axis, presented on log10 scale).

For ggplot and seaborn, these country columns would be their own variable to determine the facets of the plot. This makes creating the grid of a facet plot much simpler (one line of code).

## Countries - top 50 cumulative cases as of 2020-03-27
## We will use this to create a dataset with only the top 50 countries
## that we will highlight in all the subplots
top_50 = data.loc[data.groupby(["cname"])["cu_cases"].idxmax()]\
             .sort_values('cu_cases', ascending = False)\
             .head(50)\
             .loc[:, ['iso3', 'cname', 'cu_cases']]

## Filter countries in top 50
_df = data.loc[data['iso3'].isin(top_50['iso3'])]\

## Restructure data into wide format
## Top 50
df_50 = _df.pivot(index = 'days_elapsed', 
                  values = 'cu_cases',
                  columns = 'cname')\
          .reset_index(inplace=False)

# Days elapsed as index (x axis)
df_50.index = df_50['days_elapsed']

# Drop unwanted columns
#df = df.drop('days_elapsed', axis = 1)
df_50 = df_50.drop('days_elapsed', axis = 1)

# Display one of the dataframes
df_50.head().style
cname Algeria Argentina Australia Austria Belgium Brazil Canada Chile China Colombia Czech Republic Denmark Dominican Republic Ecuador Finland France Germany Iceland India Indonesia Iran Ireland Israel Italy Japan Luxembourg Malaysia Mexico Netherlands Norway Pakistan Panama Peru Philippines Poland Portugal Qatar Romania Russian Federation Saudi Arabia Serbia South Africa South Korea Spain Sweden Switzerland Thailand Turkey USA United Arab Emirates
days_elapsed                                                                                                    
0 102.000000 128.000000 100.000000 102.000000 109.000000 121.000000 103.000000 156.000000 216.000000 102.000000 116.000000 113.000000 112.000000 111.000000 155.000000 100.000000 111.000000 117.000000 125.000000 117.000000 139.000000 129.000000 116.000000 132.000000 105.000000 140.000000 117.000000 118.000000 128.000000 113.000000 187.000000 109.000000 117.000000 111.000000 104.000000 112.000000 262.000000 113.000000 114.000000 118.000000 126.000000 116.000000 155.000000 114.000000 137.000000 209.000000 114.000000 191.000000 103.000000 113.000000
1 189.000000 158.000000 112.000000 131.000000 169.000000 200.000000 138.000000 201.000000 235.000000 128.000000 150.000000 264.000000 202.000000 168.000000 nan 130.000000 129.000000 117.000000 137.000000 134.000000 245.000000 169.000000 178.000000 229.000000 132.000000 210.000000 129.000000 164.000000 188.000000 147.000000 187.000000 137.000000 145.000000 140.000000 125.000000 169.000000 nan 139.000000 147.000000 133.000000 135.000000 150.000000 345.000000 151.000000 161.000000 264.000000 177.000000 359.000000 125.000000 113.000000
2 231.000000 225.000000 126.000000 182.000000 200.000000 234.000000 176.000000 238.000000 386.000000 158.000000 214.000000 516.000000 245.000000 199.000000 210.000000 178.000000 157.000000 138.000000 165.000000 172.000000 388.000000 223.000000 250.000000 322.000000 144.000000 345.000000 149.000000 203.000000 265.000000 169.000000 302.000000 200.000000 234.000000 142.000000 177.000000 245.000000 320.000000 184.000000 199.000000 133.000000 149.000000 205.000000 601.000000 200.000000 203.000000 332.000000 177.000000 670.000000 159.000000 140.000000
3 264.000000 266.000000 156.000000 246.000000 239.000000 291.000000 244.000000 342.000000 526.000000 210.000000 298.000000 676.000000 312.000000 426.000000 267.000000 212.000000 196.000000 178.000000 191.000000 172.000000 593.000000 292.000000 260.000000 400.000000 144.000000 484.000000 158.000000 251.000000 321.000000 192.000000 478.000000 245.000000 263.000000 187.000000 238.000000 331.000000 337.000000 217.000000 253.000000 171.000000 188.000000 240.000000 762.000000 261.000000 248.000000 374.000000 177.000000 947.000000 233.000000 140.000000
4 305.000000 301.000000 197.000000 361.000000 267.000000 428.000000 304.000000 434.000000 623.000000 235.000000 344.000000 804.000000 392.000000 532.000000 272.000000 285.000000 262.000000 199.000000 231.000000 227.000000 978.000000 366.000000 427.000000 650.000000 164.000000 670.000000 197.000000 316.000000 382.000000 277.000000 495.000000 313.000000 318.000000 202.000000 287.000000 448.000000 401.000000 260.000000 306.000000 238.000000 222.000000 274.000000 892.000000 374.000000 326.000000 490.000000 212.000000 1236.000000 338.000000 153.000000

Plotting single country

I started by plotting a single country before adding the other components of the final chart. Starting simple helped me get the basics right first. The below chart the cumulative growth of cases in China by days since the 100th case.

You can plot this using matplotlib or the plot function of the pandas dataframe. All the customisation uses matplotlib functions. The difference between using matplotlib and pandas.plot is minor in a simple plot. However, it is much easier to start off with pandas.plot then add customisation via matplotlib when it comes to more complicated visualisations (thanks to Chris Moffitt’s blog post on effective plotting in Python for this advice).

## Plot China
fig, ax = plt.subplots(figsize = (8, 13))

# Matplotlib
plt.plot(df_50.index, df_50['China'], color = 'grey')

# Pandas
# df_50['China'].plot(color='grey')

# Plot customisation
plt.title('Cumulative Growth rate of China')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
ax.set_xticks(np.arange(0, 80, 20), minor=False)
ax.set_yscale('log', base=10)
ax.yaxis.set_major_formatter(tick.ScalarFormatter())
ax.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
plt.plot();

Plotting all countries

Next is to display the values of all coutries in a single figure. Matplotlib requires the code to explicitly identify which columns of the dataframe it is plotting. To do this I created a list of all countries which I loop over to plot each individual line. This will also be used to ‘highlight’ the relevant countries in each subplot later on.

The plot function in pandas is more forgiving. It will plot all columns for you automatically. No for loop needed.

# Create list of countries to loop over
countries = top_50['cname'].drop_duplicates().tolist()

# Print first 5 countries
countries[0:5]
['USA', 'Spain', 'Italy', 'Germany', 'China']

Pandas

fig, axes= plt.subplots(figsize=(5,8))

# Pandas
df_50.plot(color='grey', alpha = 0.6, linewidth=0.5, legend = False, ax = axes)

# Matplotlib customisation
plt.title('Cumulative reported cases for all countries')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
axes.set_xticks(np.arange(0, 80, 20), minor=False)
axes.set_yscale('log', base=10)
axes.grid(alpha=0.2)
axes.yaxis.set_major_formatter(tick.ScalarFormatter())
axes.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))

plt.plot();

Matplotlib

fig, axes = plt.subplots(figsize = (5, 8))

# Matplotlib
for idx, count in enumerate(countries):
    plt.plot(df_50.index, df_50[str(count)], color = 'grey', alpha = 0.6, linewidth=0.5)

# Matplotlib Customisation
plt.title('Cumulative reported cases for all countries')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
axes.set_xticks(np.arange(0, 80, 20), minor=False)
axes.set_yscale('log', base=10)
axes.yaxis.set_major_formatter(tick.ScalarFormatter())
axes.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))

plt.plot();

Plotting all countries in 50 subplots

This is where things get complicated. We need to loop over the countries list for both matplotlib and pandas to create subplots for each of the top 50 countries. This means that matplotlib has another for loop to compute all the cumulative cases as well as doing it across subplots. I spent quite a few hours trying not to do this (I would’ve got this post out a lot earlier).

But it turns out the best way to do this is just plot the 50 countries using pandas.plot and use the for loop to plot across all subplots. I’ve kept my attempted matplotlib version in the tab below – please get in touch if you know a better way of doing this.

Pandas

fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)

for idx, count in enumerate(countries):
    # Get grey lines for all subplots
    df_50.plot(ax = axes[idx//5][idx%5],
                legend = False, 
                color='grey',  
                alpha = 0.6, 
                linewidth=0.5)

    axes[idx//5][idx%5].set_title(str(count), size = 9)
    axes[idx//5][idx%5].set_xlabel('')
    axes[idx//5][idx%5].set_ylabel('')
    axes[idx//5][idx%5].set_yscale('log', base=10)
    axes[idx//5][idx%5].yaxis.set_major_formatter(tick.ScalarFormatter())
    axes[idx//5][idx%5].get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    axes[idx//5][idx%5].grid(alpha=0.1)
    axes[idx//5][idx%5].set_xticks(np.arange(0, 80, 20), minor=False)
    
fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
            x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16) 
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)
plt.figure();

Matplotlib

fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)

for idx, count in enumerate(countries):
    for country in enumerate(countries):
        axes[idx//5][idx%5].plot(df_50.index, df_50[str(country[1])],
                                 color='grey',  
                                 alpha = 0.6, 
                                 linewidth=0.5)
   
        axes[idx//5][idx%5].title.set_text(str(count))
        axes[idx//5][idx%5].set_title(str(count), size = 9)
        axes[idx//5][idx%5].set_xlabel('')
        axes[idx//5][idx%5].set_ylabel('')
        axes[idx//5][idx%5].set_yscale('log', base=10)
        axes[idx//5][idx%5].yaxis.set_major_formatter(tick.ScalarFormatter())
        axes[idx//5][idx%5].get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
        axes[idx//5][idx%5].grid(alpha=0.2)
        axes[idx//5][idx%5].set_xticks(np.arange(0, 80, 20), minor=False)

fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
            x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16) 
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)
        
plt.plot();

Highlighting countries and adding points

Finally all that’s left to do is to highlight the line corresponding to the country in the subplot and add a point for the end of each line. To do plot the points at the end of each red line, I create and loop over a dataframe which contains the top 50 countries with the corresponding latest day elapsed and cumulative case. This is completed using just pandas.plot and customised with matplotlib functions.

# Subset dataframe with top 50 countries and the latest cumulative case value
markers = data.loc[data.groupby(["cname"])["cu_cases"].idxmax()]\
             .sort_values('cu_cases', ascending = False)\
             .head(50)\
             .loc[:, ['days_elapsed', 'cname', 'cu_cases']]\
             .reset_index(drop = True)

markers.head().style
  days_elapsed cname cu_cases
0 34 USA 337635
1 34 Spain 130759
2 42 Italy 128948
3 36 Germany 95391
4 78 China 82642
fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)

for idx, count in enumerate(countries):
    # Get grey lines for all subplots
    df_50.plot(ax = axes[idx//5][idx%5],
                legend = False, 
                color='grey',  
                alpha = 0.6, 
                linewidth=0.5)
   
    # Highlight relevant countries for each subplot
    df_50[str(count)].plot(ax = axes[idx//5][idx%5],
                            legend = False, 
                            color='red',
                            linewidth=0.9)
    
    # Add markers at the end of each line
    markers.query('cname == "{}"'.format(count))\
           .plot.scatter(ax = axes[idx//5][idx%5],
                         x='days_elapsed', 
                         y='cu_cases', 
                         color = 'red')
    
    axes[idx//5][idx%5].set_title(str(count), size = 9)
    axes[idx//5][idx%5].set_xlabel('')
    axes[idx//5][idx%5].set_ylabel('')
    axes[idx//5][idx%5].set_yscale('log', base=10)
    axes[idx//5][idx%5].yaxis.set_major_formatter(tick.ScalarFormatter())
    axes[idx//5][idx%5].get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    axes[idx//5][idx%5].grid(alpha=0.1)
    axes[idx//5][idx%5].set_xticks(np.arange(0, 80, 20), minor=False)
    
fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
            x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16) 
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)

plt.figure();

Resources

Resources I found useful as an R user learning Python:

Effectively using Matplotlib - great for outlining a set of principles and steps for plotting

Pandas comparison with R - Method chaining preserves the main functionality of the tidyverse %>% pipe, just need to look up the corresponding functions

Pandas plotting - useful summary of pandas plotting tools