How can I write unit tests against code that uses matplotlib?
You can also use unittest.mock to mock matplotlib.pyplot and check that the expected calls are made to it with the expected arguments. Let's say you have a plot_data(x, y, title) function inside module.py (say it lives in package/src/) that you want to test, and it looks like this:
import matplotlib.pyplot as plt

def plot_data(x, y, title):
    plt.figure()
    plt.title(title)
    plt.plot(x, y)
    plt.show()
To test this function, your test_module.py file could look like this:
import numpy as np
from unittest import mock
import package.src.module as my_module  # Specify path to your module.py

@mock.patch("%s.my_module.plt" % __name__)
def test_module(mock_plt):
    x = np.arange(0, 5, 0.1)
    y = np.sin(x)
    my_module.plot_data(x, y, "my title")

    # Assert plt.title has been called with expected arg
    mock_plt.title.assert_called_once_with("my title")

    # Assert plt.figure got called
    assert mock_plt.figure.called
This checks that the title method is called with the argument "my title" and that the figure method is invoked on the plt object inside plot_data.
More detailed explanation:
The @mock.patch decorator "patches" the plt module imported inside module.py and injects the replacement into test_module as a mock object, the mock_plt parameter. This mock object can now be used inside the test to record everything that plot_data (the function we're testing) does to plt, because all the calls plot_data makes to plt now go to the mock object instead.
Apart from assert_called_once_with, you might also want to use similar methods such as assert_not_called and assert_called_once.
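As a rough illustration of those other assertion methods, here is a minimal sketch. The function plot_if_nonempty is hypothetical (it is not part of the answer above), and it takes the pyplot-like object as a parameter only so that the example runs without matplotlib installed. With mock.patch you would receive the same kind of mock object in your test.

```python
from unittest import mock


def plot_if_nonempty(plt, x, y):
    """Hypothetical function under test: draws only when there is data.

    `plt` is passed in as an argument purely to keep this sketch
    self-contained; it stands in for matplotlib.pyplot.
    """
    if len(x) == 0:
        return
    plt.figure()
    plt.plot(x, y)
    plt.show()


def test_empty_input_draws_nothing():
    mock_plt = mock.MagicMock()
    plot_if_nonempty(mock_plt, [], [])
    # Nothing should have been drawn for empty input
    mock_plt.figure.assert_not_called()
    mock_plt.show.assert_not_called()


def test_nonempty_input_draws_once():
    mock_plt = mock.MagicMock()
    plot_if_nonempty(mock_plt, [0, 1], [2, 3])
    # Called exactly once, regardless of arguments
    mock_plt.figure.assert_called_once()
    # Called exactly once, with exactly these arguments
    mock_plt.plot.assert_called_once_with([0, 1], [2, 3])
    # call_args holds the positional and keyword arguments of the last call
    args, kwargs = mock_plt.plot.call_args
    assert args == ([0, 1], [2, 3])
    assert kwargs == {}
    # mock_calls records every call made on the mock, in order
    assert mock_plt.mock_calls == [
        mock.call.figure(),
        mock.call.plot([0, 1], [2, 3]),
        mock.call.show(),
    ]
```

Inspecting call_args directly can be handy when the arguments are not easily compared with ==, since you can then apply whatever comparison suits the data.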
In my experience, image comparison tests end up bringing more trouble than they are worth. This is especially the case if you want to run continuous integration (with something like TravisCI) across multiple systems that may have slightly different fonts or available drawing backends. It can be a lot of work to keep the tests passing even when the functions work perfectly correctly. Furthermore, testing this way requires keeping images in your git repository, which can quickly lead to repository bloat if you're changing the code often.
A better approach in my opinion is to (1) assume matplotlib is going to actually draw the figure correctly, and (2) run numerical tests against the data returned by the plotting functions. (You can also always find this data inside the Axes object if you know where to look.)
For example, say you want to test a simple function like this:
import numpy as np
import matplotlib.pyplot as plt

def plot_square(x, y):
    y_squared = np.square(y)
    return plt.plot(x, y_squared)
Your unit test might then look like
def test_plot_square1():
    x, y = [0, 1, 2], [0, 1, 2]
    line, = plot_square(x, y)
    x_plot, y_plot = line.get_xydata().T
    np.testing.assert_array_equal(y_plot, np.square(y))
Or, equivalently,
def test_plot_square2():
    f, ax = plt.subplots()
    x, y = [0, 1, 2], [0, 1, 2]
    plot_square(x, y)
    x_plot, y_plot = ax.lines[0].get_xydata().T
    np.testing.assert_array_equal(y_plot, np.square(y))
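To give a rough idea of where else to look inside the Axes object, here is a sketch covering a few other artist types. The plot_summary function is hypothetical and exists only for this illustration. The sketch assumes the non-interactive Agg backend, so that it can run on a machine without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np


def plot_summary(ax, x, y):
    """Hypothetical plotting function, used only for this illustration."""
    ax.plot(x, y)                # adds a Line2D to ax.lines
    ax.scatter(x, np.square(y))  # adds a PathCollection to ax.collections
    ax.bar(x, y)                 # adds one Rectangle per bar to ax.patches
    ax.set_title("summary")
    ax.set_xlabel("x")


def test_plot_summary():
    fig, ax = plt.subplots()
    x, y = [0, 1, 2], [1, 2, 3]
    plot_summary(ax, x, y)

    # Data of the line drawn by ax.plot
    np.testing.assert_array_equal(ax.lines[0].get_ydata(), y)

    # (x, y) positions of the points drawn by ax.scatter
    offsets = ax.collections[0].get_offsets()
    np.testing.assert_allclose(offsets[:, 1], np.square(y))

    # Heights of the bars drawn by ax.bar
    heights = [patch.get_height() for patch in ax.patches]
    np.testing.assert_allclose(heights, y)

    # Text attached to the axes
    assert ax.get_title() == "summary"
    assert ax.get_xlabel() == "x"

    plt.close(fig)  # free the figure so tests do not accumulate open figures
```

Passing the Axes in explicitly, as this sketch does, also keeps each test independent of whatever figure happens to be current.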