Using Seaborn to Plot Distributions

Sources and inspiration:

If you are running this in Google Colab, uncomment the pip install line in the cell below before running it. Otherwise, leave that line commented out.

# !pip install watermark
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

Brief Data Exploration with pandas

We will work with Penguins, one of seaborn's example datasets.

penguins = sns.load_dataset("penguins")

The pandas describe() method gives a quick statistical overview of the numeric columns.

       bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g
count      342.000000     342.000000         342.000000   342.000000
mean        43.921930      17.151170         200.915205  4201.754386
std          5.459584       1.974793          14.061714   801.954536
min         32.100000      13.100000         172.000000  2700.000000
25%         39.225000      15.600000         190.000000  3550.000000
50%         44.450000      17.300000         197.000000  4050.000000
75%         48.500000      18.700000         213.000000  4750.000000
max         59.600000      21.500000         231.000000  6300.000000

By default, describe() summarizes only the numeric columns, so we also look at the first few rows with head() and list the unique values of the text columns with unique().

   species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex
0   Adelie  Torgersen            39.1           18.7              181.0        3750.0    Male
1   Adelie  Torgersen            39.5           17.4              186.0        3800.0  Female
2   Adelie  Torgersen            40.3           18.0              195.0        3250.0  Female
3   Adelie  Torgersen             NaN            NaN                NaN           NaN     NaN
4   Adelie  Torgersen            36.7           19.3              193.0        3450.0  Female
array(['Adelie', 'Chinstrap', 'Gentoo'], dtype=object)
array(['Torgersen', 'Biscoe', 'Dream'], dtype=object)
array(['Male', 'Female', nan], dtype=object)
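The calls that produce these outputs are not shown above. Here is a minimal sketch of them that runs offline. Instead of calling sns.load_dataset (which downloads the data), it rebuilds by hand the five rows from the head() output. The variable name penguins_head is just for illustration.

```python
import numpy as np
import pandas as pd

# The first five rows of the penguins dataset, copied from the head() output above
penguins_head = pd.DataFrame({
    "species": ["Adelie"] * 5,
    "island": ["Torgersen"] * 5,
    "bill_length_mm": [39.1, 39.5, 40.3, np.nan, 36.7],
    "bill_depth_mm": [18.7, 17.4, 18.0, np.nan, 19.3],
    "flipper_length_mm": [181.0, 186.0, 195.0, np.nan, 193.0],
    "body_mass_g": [3750.0, 3800.0, 3250.0, np.nan, 3450.0],
    "sex": ["Male", "Female", "Female", np.nan, "Female"],
})

# describe() summarizes only the numeric columns by default
summary = penguins_head.describe()
print(summary)

# head() shows the first rows, text columns included
print(penguins_head.head())

# unique() lists the distinct values of one column, NaN included
print(penguins_head["sex"].unique())
```

With the full dataset, the same methods are simply called on penguins.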

Let's count the missing data (NaNs) in each column with the isnull() method.

species               0
island                0
bill_length_mm        2
bill_depth_mm         2
flipper_length_mm     2
body_mass_g           2
sex                  11
dtype: int64

We drop every row that has a missing entry with the dropna() method, and then check that no NaNs remain.

penguins_cleaned = penguins.dropna()
species              0
island               0
bill_length_mm       0
bill_depth_mm        0
flipper_length_mm    0
body_mass_g          0
sex                  0
dtype: int64
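Here is a small sketch of the same isnull()/dropna() workflow. It uses a made-up four-row table, not the real dataset, so that the counts are easy to check by hand. Note that dropna() removes a row if any of its cells is missing, not only rows that are entirely empty.

```python
import numpy as np
import pandas as pd

# Made-up mini table: row 1 is missing two values, row 2 is missing only "sex"
df = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Gentoo", "Gentoo"],
    "bill_length_mm": [39.1, np.nan, 46.1, 50.0],
    "sex": ["Male", np.nan, np.nan, "Male"],
})

# isnull() marks each missing cell as True; sum() counts them per column
print(df.isnull().sum())

# dropna() drops every row that has at least one missing value
df_cleaned = df.dropna()
print(df_cleaned.isnull().sum())
print(len(df), "->", len(df_cleaned))
```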

Of course, we can make a quick visualisation with pandas, but it is more useful for exploring the data than for sharing the resulting charts.

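The cell behind the quick pandas plot is not shown. One possible version counts penguins per species and uses the pandas plot() accessor. The species values below are made up, and the bar chart is just an assumed choice of plot type.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up species column standing in for penguins_cleaned["species"]
species = pd.Series(["Adelie", "Adelie", "Gentoo", "Chinstrap", "Gentoo", "Adelie"], name="species")

# value_counts() returns the number of rows per species, largest first
counts = species.value_counts()

# Series.plot draws with matplotlib and returns the Axes it used
ax = counts.plot(kind="bar", title="Penguins per species")
plt.show()
```

This is quick for exploration, but fine-tuning the look of such charts for sharing takes more work than with seaborn's themes.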

One of my favourite things about seaborn is that its documentation includes an Example Gallery, where you can simply copy and paste the code for charts. But there is a catch! What is missing?

Preparing charts for any occasion

Setting the Theme and Color Palette

The set_theme() function changes matplotlib's global defaults for all later plots, so we can set up the style and color palette once for the rest of the notebook.

You can explore more in Controlling figure aesthetics or Choosing color palettes.

We can control the figure size and some axes parameters by creating the figure with matplotlib's plt.subplots, which gives us access to both the figure and the axes objects.

# Apply the theme
sns.set_theme(style="ticks", palette="colorblind")
# sns.set_theme(style="white")
# sns.set_theme(style="dark")

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(5, 5))
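A small sketch to check what the cell above changes. After set_theme(), calling sns.color_palette() with no argument returns the palette currently in the matplotlib color cycle. plt.subplots returns the Figure and the Axes that we later pass to seaborn's ax= parameter. The variable name current is only illustrative.

```python
import matplotlib.pyplot as plt
import seaborn as sns

# Set the global style and palette once; later plots pick them up automatically
sns.set_theme(style="ticks", palette="colorblind")

# With no argument, color_palette() returns the colors of the active color cycle
current = sns.color_palette()
print(len(current), current[0])

# figsize is given in inches as (width, height)
figure, axes = plt.subplots(figsize=(5, 5))
print(figure.get_size_inches())
```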

Axes-level plots

Seaborn has two levels of plotting functions: figure-level and axes-level.

A figure-level function, like relplot, creates the whole figure and sets up the axes inside it based on the parameters you pass. An axes-level function draws a single plot onto one matplotlib Axes object, which we can pass through the ax parameter. This lets us combine different types of plots in the same figure, because each plot can go into a different axes.


Pivot tables and Heat maps

As shown before, we can generate heatmaps from a pivoted table and from a correlation matrix.

# penguins_cleaned.groupby(['species', 'island'])['body_mass_g'].aggregate('mean').unstack()
pivot_table = penguins_cleaned.pivot_table("body_mass_g", index=["island", "species"]).unstack()

species         Adelie    Chinstrap       Gentoo
Biscoe     3709.659091          NaN  5092.436975
Dream      3701.363636  3733.088235          NaN
Torgersen  3708.510638          NaN          NaN
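This sketch checks that the pivot_table call and the commented-out groupby line produce the same numbers. It uses a handful of made-up body masses. pivot_table averages by default, and unstack() moves species into the columns. Island/species pairs that never occur become NaN, as in the table above. The pivot_table version keeps an extra "body_mass_g" level on top of the columns, which the groupby version does not have.

```python
import pandas as pd

# Made-up body masses for a few penguins (not the real dataset)
df = pd.DataFrame({
    "island": ["Biscoe", "Biscoe", "Dream", "Dream", "Torgersen"],
    "species": ["Adelie", "Gentoo", "Adelie", "Chinstrap", "Adelie"],
    "body_mass_g": [3700.0, 5100.0, 3650.0, 3750.0, 3700.0],
})

# pivot_table uses aggfunc="mean" by default
via_pivot = df.pivot_table("body_mass_g", index=["island", "species"]).unstack()

# Equivalent groupby version, as in the commented-out line above
via_groupby = df.groupby(["island", "species"])["body_mass_g"].aggregate("mean").unstack()

print(via_groupby)
```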

Below we create an empty canvas: a figure object (figure) and an axes object (axes).

We draw the heatmap with the heatmap() function and pass the axes object to its ax parameter, so that the plot lands on that canvas.

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(5, 10))

# Draw the heatmap
sns.heatmap(
    pivot_table,
    linewidths=.5, cbar_kws={"shrink": .5},
    ax=axes,
    # cmap="coolwarm",  # or "Spectral"
)

axes.set(title="Mean body mass of penguins, g")

Some charts keep a fixed size and shape. If you force a different figsize, the chart itself does not stretch, and the extra space is left empty in the image. Look at the saved output of this figure.

figure.savefig('heatmap_uneven_PNG.png', dpi=300)
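The second kind of heatmap, one built from a correlation matrix, can be sketched on made-up data. DataFrame.corr() computes pairwise Pearson correlations between the numeric columns. square=True keeps every cell square, so the heatmap keeps its shape whatever the figsize. When saving, bbox_inches="tight" crops the image to the drawn content, which trims empty margins like the ones described above. The column relationships in the data are invented for illustration, and the file name is arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data: flipper length is built to correlate with bill length, body mass is independent
rng = np.random.default_rng(1)
length = rng.normal(44, 5, 200)
df = pd.DataFrame({
    "bill_length_mm": length,
    "flipper_length_mm": 4 * length + rng.normal(0, 8, 200),
    "body_mass_g": rng.normal(4200, 800, 200),
})

# Pairwise Pearson correlations: a square matrix with values in [-1, 1]
corr = df.corr()

figure, axes = plt.subplots(figsize=(5, 4))
sns.heatmap(corr, annot=True, fmt=".2f", vmin=-1, vmax=1, cmap="coolwarm", square=True, ax=axes)
axes.set(title="Correlation matrix (made-up data)")

# Crop the saved image to the drawn content
figure.savefig("heatmap_corr_tight.png", dpi=150, bbox_inches="tight")
```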

Scatter plot

The scatter plot is used to display the relationship between variables. Let’s see the scatter plot of culmen lengths and depths by penguin species.

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(10, 5))

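# Make a scatterplot
sns.scatterplot(
    data=penguins_cleaned,
    x="bill_length_mm",
    y="bill_depth_mm",
    hue="species",
    ax=axes,
)
# Give the plot a title
plt.title("Bill Length vs Bill Depth", size=20, color="red")  # matplotlib way to define a title

# Improve the legend
sns.move_legend(axes, loc="lower right")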

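Here is a self-contained version of this scatter plot that runs on made-up values, so it works without downloading the dataset. With hue="species", seaborn colors the points by species and adds one legend entry per species. sns.move_legend then redraws that legend in a new position. The numbers are random, not real measurements.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up values standing in for penguins_cleaned
rng = np.random.default_rng(2)
df = pd.DataFrame({
    "species": np.repeat(["Adelie", "Chinstrap", "Gentoo"], 30),
    "bill_length_mm": rng.normal(45, 5, 90),
    "bill_depth_mm": rng.normal(17, 2, 90),
})

figure, axes = plt.subplots(figsize=(10, 5))
# hue= maps each species to a color and adds a legend entry for it
sns.scatterplot(data=df, x="bill_length_mm", y="bill_depth_mm", hue="species", ax=axes)
axes.set_title("Bill Length vs Bill Depth")

# Redraw the existing legend in the lower right corner
sns.move_legend(axes, loc="lower right")
labels = [text.get_text() for text in axes.get_legend().get_texts()]
print(labels)
```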

Histogram

A histogram shows the distribution of the data, for one or more variables. Let's look at the distribution of bill length with the histplot function.

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(10, 5))

sns.histplot(
    data=penguins_cleaned,
    x="bill_length_mm",
    ax=axes,
)
plt.title("Bill Length", size=20, color="red")

The same function can draw histograms of subsets in different colors. We only have to pass a few more parameters, for example assigning the dataframe's 'species' column to the hue parameter.

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(10, 5))

sns.histplot(
    data=penguins_cleaned,
    x="bill_length_mm",
    binwidth=1,
    hue="species",
    kde=True,
    ax=axes,
)
axes.set(title="Bill Length")

# Move the legend to the top center of the axes
sns.move_legend(
    axes, "upper center",
    bbox_to_anchor=(.5, 1),
)

Bar plot

A bar plot represents an estimate of the central tendency for a numeric variable with the height of each rectangle. Let's look at a bar plot of the bill lengths of the penguin species, split by sex.

# Set up the matplotlib figure
figure, axes = plt.subplots(figsize=(10, 5))

sns.barplot(
    data=penguins_cleaned,
    x="species",
    y="bill_length_mm",
    hue="sex",
    ax=axes,
)
axes.set(title="Bill Length for 3 Penguin Species by Sex")
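The "estimate of central tendency" is the mean by default. This sketch uses made-up bill lengths for two species to check that each bar's height equals the group mean computed with pandas. The line drawn on each bar is an error bar. By default it shows a bootstrapped 95% confidence interval, so its exact length varies slightly from run to run, but the bar heights do not.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Made-up bill lengths for two species
df = pd.DataFrame({
    "species": ["Adelie"] * 4 + ["Gentoo"] * 4,
    "bill_length_mm": [38.0, 39.0, 40.0, 41.0, 46.0, 47.0, 48.0, 51.0],
})

figure, axes = plt.subplots(figsize=(6, 4))
# Each bar's height is the mean of its group (the default estimator)
sns.barplot(data=df, x="species", y="bill_length_mm", ax=axes)

bar_heights = sorted(patch.get_height() for patch in axes.patches)
group_means = sorted(df.groupby("species")["bill_length_mm"].mean())
print(bar_heights, group_means)
```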

Box plot

The box plot is used to compare the distribution of numerical data between levels of a categorical variable. Let’s see the distribution of flipper length by species.

figure, axes = plt.subplots(figsize=(10, 5))

sns.boxplot(
    data=penguins_cleaned,
    x="species",
    y="flipper_length_mm",
    ax=axes,
)
axes.set(title="Flipper Length for 3 Penguin Species")
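To read a box plot, it helps to compute its parts by hand. This sketch uses made-up flipper lengths for a single species. The box spans the first to the third quartile, and the line inside it is the median. By default (whis=1.5), the whiskers reach the furthest points within 1.5 times the interquartile range (IQR) of the box. Points beyond that are drawn individually as outliers. This assumes the default linear quantile method, which pandas and matplotlib both use.

```python
import pandas as pd

# Made-up flipper lengths for one species
flippers = pd.Series([181.0, 186.0, 190.0, 193.0, 195.0, 197.0, 230.0])

# Box edges (Q1, Q3) and the median line
q1, median, q3 = flippers.quantile([0.25, 0.5, 0.75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are treated as outliers
lower_fence, upper_fence = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = flippers[(flippers < lower_fence) | (flippers > upper_fence)]
print(q1, median, q3, list(outliers))
```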

You can use the hue parameter to draw box plots of flipper length for each species, split by sex.

figure, axes = plt.subplots(figsize=(8, 5))

sns.boxplot(
    data=penguins_cleaned,
    x="species",
    y="flipper_length_mm",
    hue="sex",
    ax=axes,
)
axes.set(title="Flipper Length for 3 Penguin Species by Sex")

But does a box plot represent the distribution in the best way? Source