Data science professionals often need to present several plots in one image, for example to compare past and current data. Placing the plots side by side makes such differences easier to see. Matplotlib supports this through multiple subplots. This tutorial teaches you how to create multiple subplots using Matplotlib.

What Are Multiple Subplots in Matplotlib?

A subplot in Matplotlib is one of several plots rendered within a single figure. Subplots help data analysts compare different views of data side by side, so they are a fundamental concept to understand when plotting multiple graphs in one figure. Matplotlib provides the subplot() and subplots() functions in the pyplot module. Their syntax looks like this:

Syntax
matplotlib.pyplot.subplot(nrows, ncols, index)

and

matplotlib.pyplot.subplots(nrows = 1, ncols = 1, sharex = False, sharey = False, squeeze = True, subplot_kw = None, gridspec_kw = None, **fig_kw)

where,

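  • nrows and ncols determine the number of rows and columns in the subplot grid.
  • index (subplot() only) selects the position in the grid. It starts at 1 in the upper-left corner and increases from left to right, then from top to bottom.
  • sharex and sharey control whether the subplots share their x-axis and y-axis. They accept True, False, or one of the strings 'none', 'all', 'row', and 'col'.
  • squeeze is an optional Boolean parameter with a default value of True. When it is True, extra dimensions are removed from the returned array of Axes, so a single subplot is returned as one Axes object and a single row or column as a 1-D array.
  • subplot_kw is a dict of keywords passed to the add_subplot() call that creates each subplot.
  • gridspec_kw is a dict of keywords passed to the GridSpec constructor, which creates the grid the subplots are placed on.
  • Any remaining keyword arguments (for example figsize) are passed on to the pyplot.figure() call.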
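To make the squeeze and gridspec_kw parameters more concrete, here is a minimal sketch. It only inspects the objects that subplots() returns; the variable names are illustrative, and the 2:1 height ratio is an arbitrary choice.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# With squeeze=True (the default), a single subplot comes back as one Axes object
fig1, single = plt.subplots()
print(isinstance(single, Axes))  # True

# A 3x1 layout is squeezed to a 1-D array of three Axes
fig2, column = plt.subplots(3, 1)
print(column.shape)  # (3,)

# squeeze=False always returns a 2-D array
fig3, grid = plt.subplots(3, 1, squeeze=False)
print(grid.shape)  # (3, 1)

# gridspec_kw is forwarded to GridSpec; height_ratios makes the top row
# twice as tall as the bottom row
fig4, (top, bottom) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]})
ratio = top.get_position().height / bottom.get_position().height
print(f'{ratio:.2f}')  # 2.00
```

Passing squeeze=False can be convenient when the number of rows or columns is computed at run time, because the result can then always be indexed as grid[row, column].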

There are two common ways of arranging subplots in Matplotlib:

  1. The Stacked Plots
  2. The Grid Plots

Stacked Plots

In a stacked plot, the plots are placed one above the other in a single column, much like the items in a stack.

Example:

import matplotlib.pyplot as plt
import numpy as np

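x = np.arange(0, 20, 0.02)

# Three rows, one column: select the first (top) subplot
plt.subplot(3, 1, 1)
plt.plot(x, np.sin(x))

# Second (middle) subplot
plt.subplot(3, 1, 2)
plt.plot(x, x)

# Third (bottom) subplot
plt.subplot(3, 1, 3)
plt.plot(x, np.cos(x))

# Display the chart
plt.show()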
Program Output:

Matplotlib Stacked plots example

Grid Plots

In grid plots, the subplots are arranged in rows and columns, like tiles. For example, to show four plots, a data analyst can create a 2×2 grid by setting nrows = 2 and ncols = 2. Each subplot is identified by its index, which counts from 1 in row-major order: left to right along the first row, then along the second. For three plots, a single row of three columns (1×3) places them side by side, while a single column of three rows (3×1) stacks them vertically.
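The row-major numbering can be sketched with a small helper. The function name index_to_cell is hypothetical and not part of Matplotlib; it only illustrates how the 1-based index passed to subplot() corresponds to a zero-based row and column.

```python
def index_to_cell(index, ncols):
    """Map a 1-based subplot index to a zero-based (row, column) pair."""
    return divmod(index - 1, ncols)

# In a 2x2 grid, the indices run left to right, then top to bottom
for index in range(1, 5):
    print(index, index_to_cell(index, ncols=2))
# 1 (0, 0)
# 2 (0, 1)
# 3 (1, 0)
# 4 (1, 1)
```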

Example:

import matplotlib.pyplot as plt
import numpy as np

g = np.arange(0, 20, 0.02)

plt.subplot(2, 2, 1)
plt.plot(g, np.sin(g))

plt.subplot(2, 2, 2)
plt.plot(g, np.cos(g))

plt.subplot(2, 2, 3)
plt.plot(g, 5-g)

plt.subplot(2, 2, 4)
plt.plot(g, g)

# Display the chart
plt.show()
Program Output:

Matplotlib Grid plots example
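The same 2×2 layout can also be built with subplots(), which returns all four Axes at once in a 2-D array indexed as axes[row, column]. The following is a sketch of that alternative style under the same data; the subplot titles are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

g = np.arange(0, 20, 0.02)

# axes is a 2x2 NumPy array of Axes objects
fig, axes = plt.subplots(2, 2)

axes[0, 0].plot(g, np.sin(g))   # same position as plt.subplot(2, 2, 1)
axes[0, 1].plot(g, np.cos(g))   # same position as plt.subplot(2, 2, 2)
axes[1, 0].plot(g, 5 - g)       # same position as plt.subplot(2, 2, 3)
axes[1, 1].plot(g, g)           # same position as plt.subplot(2, 2, 4)

# Iterating over axes.flat visits the subplots in row-major order
for number, ax in enumerate(axes.flat, start=1):
    ax.set_title(f'Subplot {number}')

# Reduce overlap between titles and tick labels
fig.tight_layout()
```

Call plt.show() afterwards to display the figure.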

Example of a Single Plot Using the subplots() Function

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2.5*np.pi, 500)
y = np.sin(x**2) + np.cos(x)

# With no arguments, subplots() creates a figure with a single Axes
fig, ax = plt.subplots()

ax.plot(x, y)
ax.set_title('Matplotlib plot example')

# Display the chart
plt.show()
Program Output:

Matplotlib Example of a single plot using the subplots() function

Example of Multiple Plots Using the subplots() Function

import numpy as np
import matplotlib.pyplot as plt

def f(g):
    return np.sin(g) - g * np.cos(g)

def fp(g):
    # Derivative of f
    return g * np.sin(g)

G = np.arange(-5, 25.0, 0.025)

# Create one 6x4-inch figure with two stacked subplots.
# figsize is passed to subplots() so that no extra empty figure is created.
fig, ax = plt.subplots(2, figsize=(6, 4), sharex = 'col', sharey = 'row')

ax[0].plot(G, f(G), 'yo', G, f(G), 'r')
ax[0].set(title='Multiplots using subplots()')

ax[1].plot(G, fp(G), 'co', G, fp(G), 'b')
ax[1].set(xlabel='Values at X', ylabel='Values at Y', title='Derivative Function through f')

# Display the chart
plt.show()
Program Output:

Matplotlib Example of multiple plots using the subplots() function
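The effect of sharex can be checked directly. In this minimal sketch, which uses illustrative sine and cosine data rather than the functions above, changing the x-limits of one subplot also changes them on the other.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)

# sharex=True links the x-axis of every subplot in the figure
fig, (top, bottom) = plt.subplots(2, 1, sharex=True)
top.plot(x, np.sin(x))
bottom.plot(x, np.cos(x))

# Changing the limits on one subplot updates the other as well
top.set_xlim(2, 6)
print(tuple(float(v) for v in bottom.get_xlim()))  # (2.0, 6.0)
```

The same linking applies when zooming or panning in an interactive window opened with plt.show().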

Adding a Grid to One or Both Subplots

Data science professionals can also add grid lines to a plot, which makes it easier to read values off the axes. These grid lines are drawn inside a subplot by calling the grid() method of its Axes object.

Example:

import numpy as np
import matplotlib.pyplot as plt

def f(g):
    return np.sin(g) - g * np.cos(g)

def fp(g):
    # Derivative of f
    return g * np.sin(g)

G = np.arange(-5, 25.0, 0.025)

# Create one 6x4-inch figure with two stacked subplots.
# figsize is passed to subplots() so that no extra empty figure is created.
fig, ax = plt.subplots(2, figsize=(6, 4), sharex = 'col', sharey = 'row')

ax[0].plot(G, f(G), 'yo', G, f(G), 'r')
ax[0].set(title='Multiplots using subplots()')

ax[1].plot(G, fp(G), 'co', G, fp(G), 'b')
ax[1].set(xlabel='Values at X', ylabel='Values at Y', title='Derivative Function through f')

# Turn on the grid for both subplots
ax[0].grid()
ax[1].grid()

# Display the chart
plt.show()
Program Output:

Matplotlib Adding a grid to a specific or both subplots
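The example above turns the grid on for both subplots. To add a grid to one specific subplot, call grid() on that Axes only. The following is a minimal sketch with illustrative data; the line style and transparency values are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)
fig, ax = plt.subplots(2)

ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x))

# Draw dashed, semi-transparent grid lines on the first subplot only;
# the second subplot keeps the default (no grid)
ax[0].grid(True, linestyle='--', alpha=0.5)
```

Call plt.show() afterwards to display the figure.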