from __future__ import division
import matplotlib
import matplotlib.colors
import matplotlib.markers
import numpy as np
import six
import warnings
try:
_named_colors = matplotlib.colors.ColorConverter.colors.copy()
for colorname, hexcode in matplotlib.colors.cnames.items():
_named_colors[colorname] = matplotlib.colors.hex2color(hexcode)
except: # pragma: no cover
warnings.warn("Could not get matplotlib colors, named colors will not be available")
_named_colors = {}
class InvalidPlotError(Exception):
pass
[docs]class PlotChecker(object):
"""A generic object to test plots.
Parameters
----------
axis : ``matplotlib.axes.Axes`` object
A set of matplotlib axes (e.g. obtained through ``plt.gca()``)
"""
_named_colors = _named_colors
def __init__(self, axis):
"""Initialize the PlotChecker object."""
self.axis = axis
@classmethod
def _color2rgb(cls, color):
"""Converts the given color to a 3-tuple RGB color.
Parameters
----------
color :
Either a matplotlib color name (e.g. ``'r'`` or ``'red'``), a
hexcode (e.g. ``"#FF0000"``), a 3-tuple RGB color, or a 4-tuple RGBA
color.
Returns
-------
rgb : 3-tuple RGB color
"""
if isinstance(color, six.string_types):
if color in cls._named_colors:
return tuple(cls._named_colors[color])
else:
return tuple(matplotlib.colors.hex2color(color))
elif hasattr(color, '__iter__') and len(color) == 3:
return tuple(float(x) for x in color)
elif hasattr(color, '__iter__') and len(color) == 4:
return tuple(float(x) for x in color[:3])
else:
raise ValueError("Invalid color: {}".format(color))
@classmethod
def _color2alpha(cls, color):
"""Converts the given color to an alpha value. For all cases except
RGBA colors, this value will be 1.0.
Parameters
----------
color :
Either a matplotlib color name (e.g. ``'r'`` or ``'red'``), a
hexcode (e.g. ``"#FF0000"``), a 3-tuple RGB color, or a 4-tuple RGBA
color.
Returns
-------
alpha : float
"""
if isinstance(color, six.string_types):
return 1.0
elif hasattr(color, '__iter__') and len(color) == 3:
return 1.0
elif hasattr(color, '__iter__') and len(color) == 4:
return float(color[3])
else:
raise ValueError("Invalid color: {}".format(color))
@classmethod
def _parse_marker(cls, marker):
"""Converts the given marker to a consistent marker type. In practice,
this is basically just making sure all null markers (``''``, ``'None'``,
``None``) get converted to empty strings.
Parameters
----------
marker : string
The marker type
Returns
-------
marker : string
"""
if marker is None or marker == 'None':
return ''
return marker
@classmethod
def _tile_or_trim(cls, x, y):
"""Tiles or trims the first dimension of ``y`` so that ``x.shape[0]`` ==
``y.shape[0]``.
Parameters
----------
x : array-like
A numpy array with any number of dimensions.
y : array-like
A numpy array with any number of dimensions.
"""
xn = x.shape[0]
yn = y.shape[0]
if xn > yn:
numrep = int(np.ceil(xn / yn))
y = np.tile(y, (numrep,) + (1,) * (y.ndim - 1))
yn = y.shape[0]
if xn < yn:
y = y[:xn]
return y
@property
def title(self):
"""The title of the matplotlib plot, stripped of whitespace."""
return self.axis.get_title().strip()
[docs] def assert_title_equal(self, title):
"""Asserts that the given title is the same as the plotted
:attr:`~plotchecker.PlotChecker.title`.
Parameters
----------
title : string
The expected title
"""
title = title.strip()
if self.title != title:
raise AssertionError(
"title is incorrect: '{}'' (expected '{}')".format(
self.title, title))
[docs] def assert_title_exists(self):
"""Asserts that the plotted :attr:`~plotchecker.PlotChecker.title` is
non-empty.
"""
if self.title == '':
raise AssertionError("no title")
@property
def xlabel(self):
"""The xlabel of the matplotlib plot, stripped of whitespace."""
return self.axis.get_xlabel().strip()
[docs] def assert_xlabel_equal(self, xlabel):
"""Asserts that the given xlabel is the same as the plotted
:attr:`~plotchecker.PlotChecker.xlabel`.
Parameters
----------
xlabel : string
The expected xlabel
"""
xlabel = xlabel.strip()
if self.xlabel != xlabel:
raise AssertionError(
"xlabel is incorrect: '{}'' (expected '{}')".format(
self.xlabel, xlabel))
[docs] def assert_xlabel_exists(self):
"""Asserts that the plotted :attr:`~plotchecker.PlotChecker.xlabel` is
non-empty.
"""
if self.xlabel == '':
raise AssertionError("no xlabel")
@property
def ylabel(self):
"""The ylabel of the matplotlib plot, stripped of whitespace."""
return self.axis.get_ylabel().strip()
[docs] def assert_ylabel_equal(self, ylabel):
"""Asserts that the given ylabel is the same as the plotted
:attr:`~plotchecker.PlotChecker.ylabel`.
Parameters
----------
ylabel : string
The expected ylabel
"""
ylabel = ylabel.strip()
if self.ylabel != ylabel:
raise AssertionError(
"ylabel is incorrect: '{}'' (expected '{}')".format(
self.ylabel, ylabel))
[docs] def assert_ylabel_exists(self):
"""Asserts that the plotted :attr:`~plotchecker.PlotChecker.ylabel` is
non-empty.
"""
if self.ylabel == '':
raise AssertionError("no ylabel")
@property
def xlim(self):
"""The x-axis limits of the matplotlib plot."""
return self.axis.get_xlim()
[docs] def assert_xlim_equal(self, xlim):
"""Asserts that the given xlim is the same as the plot's
:attr:`~plotchecker.PlotChecker.xlim`.
Parameters
----------
xlim : 2-tuple
The expected xlim
"""
if self.xlim != xlim:
raise AssertionError(
"xlim is incorrect: {} (expected {})".format(
self.xlim, xlim))
@property
def ylim(self):
"""The y-axis limits of the matplotlib plot."""
return self.axis.get_ylim()
[docs] def assert_ylim_equal(self, ylim):
"""Asserts that the given ylim is the same as the plot's
:attr:`~plotchecker.PlotChecker.ylim`.
Parameters
----------
ylim : 2-tuple
The expected ylim
"""
if self.ylim != ylim:
raise AssertionError(
"ylim is incorrect: {} (expected {})".format(
self.ylim, ylim))
@property
def xticks(self):
"""The tick locations along the plot's x-axis."""
return self.axis.get_xticks()
[docs] def assert_xticks_equal(self, xticks):
"""Asserts that the given xticks are the same as the plot's
:attr:`~plotchecker.PlotChecker.xticks`.
Parameters
----------
xticks : list
The expected tick locations on the x-axis
"""
np.testing.assert_equal(self.xticks, xticks)
@property
def yticks(self):
"""The tick locations along the plot's y-axis."""
return self.axis.get_yticks()
[docs] def assert_yticks_equal(self, yticks):
"""Asserts that the given yticks are the same as the plot's
:attr:`~plotchecker.PlotChecker.yticks`.
Parameters
----------
yticks : list
The expected tick locations on the y-axis
"""
np.testing.assert_equal(self.yticks, yticks)
@property
def xticklabels(self):
"""The tick labels along the plot's x-axis, stripped of whitespace."""
return [x.get_text().strip() for x in self.axis.get_xticklabels()]
[docs] def assert_xticklabels_equal(self, xticklabels):
"""Asserts that the given xticklabels are the same as the plot's
:attr:`~plotchecker.PlotChecker.xticklabels`.
Parameters
----------
xticklabels : list
The expected tick labels on the x-axis
"""
xticklabels = [x.strip() for x in xticklabels]
np.testing.assert_equal(self.xticklabels, xticklabels)
@property
def yticklabels(self):
"""The tick labels along the plot's y-axis, stripped of whitespace."""
return [x.get_text().strip() for x in self.axis.get_yticklabels()]
[docs] def assert_yticklabels_equal(self, yticklabels):
"""Asserts that the given yticklabels are the same as the plot's
:attr:`~plotchecker.PlotChecker.yticklabels`.
Parameters
----------
yticklabels : list
The expected tick labels on the y-axis
"""
yticklabels = [y.strip() for y in yticklabels]
np.testing.assert_equal(self.yticklabels, yticklabels)
@property
def _texts(self):
"""All ``matplotlib.text.Text`` objects in the plot, excluding titles."""
texts = []
for x in self.axis.get_children():
if not isinstance(x, matplotlib.text.Text):
continue
if x == self.axis.title:
continue
if x == getattr(self.axis, '_left_title', None):
continue
if x == getattr(self.axis, '_right_title', None):
continue
texts.append(x)
return texts
@property
def textlabels(self):
"""The labels of all ``matplotlib.text.Text`` objects in the plot, excluding titles."""
return [x.get_text().strip() for x in self._texts]
[docs] def assert_textlabels_equal(self, textlabels):
"""Asserts that the given textlabels are the same as the plot's
:attr:`~plotchecker.PlotChecker.textlabels`.
Parameters
----------
textlabels : list
The expected text labels on the plot
"""
textlabels = [x.strip() for x in textlabels]
np.testing.assert_equal(self.textlabels, textlabels)
@property
def textpoints(self):
"""The locations of all ``matplotlib.text.Text`` objects in the plot, excluding titles."""
return np.vstack([x.get_position() for x in self._texts])
[docs] def assert_textpoints_equal(self, textpoints):
"""Asserts that the given locations of the text objects are the same as
the plot's :attr:`~plotchecker.PlotChecker.textpoints`.
Parameters
----------
textpoints : array-like, N-by-2
The expected text locations on the plot, where the first column
corresponds to the x-values, and the second column corresponds to
the y-values.
"""
np.testing.assert_equal(self.textpoints, textpoints)
[docs] def assert_textpoints_allclose(self, textpoints, **kwargs):
"""Asserts that the given locations of the text objects are almost the
same as the plot's :attr:`~plotchecker.PlotChecker.textpoints`.
Parameters
----------
textpoints : array-like, N-by-2
The expected text locations on the plot, where the first column
corresponds to the x-values, and the second column corresponds to
the y-values.
kwargs :
Additional keyword arguments to pass to
``numpy.testing.assert_allclose``
"""
np.testing.assert_allclose(self.textpoints, textpoints, **kwargs)