Source code for visualize.reconstruction

"""Plots reconstruction of images against their original forms."""
import torch
import torch.nn as nn
from typing import Iterable
import numpy as np


from .comparison import plot_images_two_rows


[docs]def plot_reconstruction_on_first_batch( loader: torch.utils.data.DataLoader, model_reconstructing: nn.Module ) -> None: """Plots the first batch of reconstructed images against the original images. :param loader: a loader from which the first batch of images is loaded. :param model_reconstructing: a model which reconstructs an image (produces an image of identical structure to the input image, when passed an input-image). """ first_batch = _retrieve_first_input_batch_from_loader(loader) with torch.no_grad(): reconstruction = model_reconstructing(first_batch) plot_images_two_rows( _images_from_tensor(first_batch), _images_from_tensor(reconstruction), len(first_batch),
) def _retrieve_first_input_batch_from_loader( loader: torch.utils.data.DataLoader, ) -> torch.Tensor: """Retrieves the input-images from the first batch in a loader.""" for batch_features in loader: return batch_features[0] raise ValueError def _images_from_tensor(images: torch.Tensor) -> Iterable[np.ndarray]: """Converts a tensor to an iterable of images, converting each image to an appropriately sized numpy array.""" for index in range(images.size()[0]): # Convert from PyTorch RGB format (3, y, x) to Numpy expected RGB format (y, x, 3) yield images[index].permute(1, 2, 0).numpy()