# Source code for visualize.comparison

"""Plots comparisons of a single image or set of images against another image or set."""


import torch
import matplotlib.pyplot as plt
import numpy as np
from typing import Iterable


def plot_images_two_rows(
    top: Iterable[np.ndarray],
    bottom: Iterable[np.ndarray],
    row_size: int = 10,
    figure_size=(20, 4),
) -> None:
    """Plot a top-row of images, alongside a bottom-row of images.

    Images are drawn pairwise. The ``i``-th image of ``top`` goes into column
    ``i`` of the first row, and the ``i``-th image of ``bottom`` goes directly
    beneath it. At most ``row_size`` columns are drawn.

    Plotting is best-effort. If either iterable runs out before ``row_size``
    columns are filled, plotting stops there and the partially filled figure
    is still shown. If ``bottom`` runs out first, the last top image is left
    without a partner.

    :param top: the images on the top-row.
    :param bottom: the images on the bottom-row.
    :param row_size: the maximum number of images on a row.
    :param figure_size: the size of figure in inches, as passed to
        :func:`matplotlib.pyplot.figure`.
    """
    # NOTE(review): the inputs are annotated as numpy arrays, so disabling
    # autograd presumably only matters when callers pass torch tensors — confirm.
    with torch.no_grad():
        plt.figure(figsize=figure_size)
        iter_top = iter(top)
        iter_bottom = iter(bottom)
        # Unique sentinel: compared with `is`, so it is safe for numpy arrays
        # (whose `==` is elementwise).
        exhausted = object()
        for index in range(row_size):
            top_image = next(iter_top, exhausted)
            if top_image is exhausted:
                break
            _plot_image(top_image, index, row_size)
            bottom_image = next(iter_bottom, exhausted)
            if bottom_image is exhausted:
                break
            # Bottom-row cells are numbered after the whole top row.
            _plot_image(bottom_image, index + row_size, row_size)
        plt.show()
def _plot_image(image: np.ndarray, index: int, row_size: int) -> None:
    """Display a single image in one cell of a two-row subplot grid.

    :param image: the image to display, passed to
        :func:`matplotlib.pyplot.imshow`.
    :param index: zero-based cell index, counted left-to-right across the
        first row and then across the second row.
    :param row_size: the number of columns in the grid.
    """
    # plt.subplot numbers cells starting from 1, hence the +1.
    ax = plt.subplot(2, row_size, index + 1)
    # Pass the colormap to imshow instead of calling plt.gray(). plt.gray()
    # would also change matplotlib's global default colormap for every later
    # plot. The colormap only affects single-channel data; RGB(A) images are
    # drawn in their own colors either way.
    plt.imshow(image, cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)