"""An CNN-based `autoencoder <https://en.wikipedia.org/wiki/Autoencoder>`_ model."""
import torch.nn as nn
import math
from typing import Iterable, Callable, List, Union
import pytorch_lightning as pl
import torch.optim
from torch.nn import functional as F
BASE_IMAGE_SIZE: int = 16
"""The width and height that images are downsampled to before being flattened into a code."""
_KERNEL_SIZE: int = 3
"""The size of each convolutional step."""
_PADDING: int = 1
"""The padding, both input and output for filters."""
_STRIDE: int = 2
"""The stride in each convolutional filter."""
class AutoEncoder(pl.LightningModule):
"""An AutoEncoder model, based upon incrementally downsampling CNNs to a flat code, and then upsampling CNNs."""
def __init__(
self,
number_channels: int = 3,
input_size: int = 32,
code_size: int = 16,
number_filters: int = 4,
):
"""Creates an AutoEncoder Model, with an architecture as described in the class-level docstring.
The architecture is summarized as follows:
- The input is incrementally downsampled by two until it becomes :code:`[BASE_IMAGE_SIZE, BASE_IMAGE_SIZE]`.
- It is flattened, and there's a linear layer mapping it to a coding vector of size `code_size`.
- It is then unflattened, and there's a linear layer mapping it back to
:code:`[BASE_IMAGE_SIZE, BASE_IMAGE_SIZE]`.
- Is then upsampled incrementally by two until it becomes the original input-size again.
:param number_channels: the number of channels in the input image e.g. 1 for grayscale, 3 for RGB.
:param input_size: the width and height of images fed into the autoencoder. It should be a power of 2 and
:code:`>=` :const:`BASE_IMAGE_SIZE`.
:param code_size: how many variables exist in the 'bottleneck' latent space part of the autoencoder. Each
image will be encoded into a vector of this size.
:param number_filters: the number of filters to use in each convolutional operation.
"""
super().__init__()
self._number_channels = number_channels
self._number_filters = number_filters
rounds_downsampling = _calculate_rounds_downsampling(input_size)
# The size before being linear mapped into the code vector.
penultimate_size = self._number_filters * BASE_IMAGE_SIZE * BASE_IMAGE_SIZE
self.encoder = nn.Sequential(
*_repeat_with_relu(rounds_downsampling, self._downsampling_conv_filter),
nn.Flatten(),
nn.Linear(penultimate_size, code_size),
nn.ReLU(True),
)
self.decoder = nn.Sequential(
nn.Linear(code_size, penultimate_size),
nn.ReLU(True),
nn.Unflatten(1, (self._number_filters, BASE_IMAGE_SIZE, BASE_IMAGE_SIZE)),
*_repeat_with_relu(rounds_downsampling, self._upsampling_conv_filter),
)
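# For the default arguments (3 channels, input_size=32, code_size=16, number_filters=4) this amounts to
# a single round of downsampling, i.e. roughly:
#   encoder: Conv2d(3, 4, 3, stride=2, padding=1) -> ReLU -> Flatten -> Linear(1024, 16) -> ReLU
#   decoder: Linear(16, 1024) -> ReLU -> Unflatten -> ConvTranspose2d(4, 3, 3, stride=2, padding=1, output_padding=1) -> ReLU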
def forward(
self, activation: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""Overrides :mod:`pl.LightningModule`."""
# Perform only the encode step when predicting.
# When predicting, activation is both the input tensor and another.
# This ensures only the input tensor is used
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/1137
if isinstance(activation, list):
activation = activation[0]
return self.encoder(activation)
def forward_encode_decode(
self, input: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""Performs both the encode and decode step on an input (batched).
:param input: the input tensor.
:returns: the tensor after encoding and decoding.
"""
temp = self.forward(input)
return self.decoder(temp)
def training_step(self, batch: List[torch.Tensor], batch_idx: int):
"""Overrides :class:`pl.LightningModule`."""
return self._common_step(batch, batch_idx, "train")
def validation_step(self, batch: List[torch.Tensor], batch_idx: int):
"""Overrides :class:`pl.LightningModule`."""
self._common_step(batch, batch_idx, "val")
def test_step(self, batch: List[torch.Tensor], batch_idx: int):
"""Overrides :class:`pl.LightningModule`."""
self._common_step(batch, batch_idx, "test")
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = None):
"""Overrides :class:`pl.LightningModule`."""
return self(batch)
def _common_step(self, batch: List[torch.Tensor], batch_idx: int, stage: str):
x, _ = batch
x_hat = self.encoder(x)
x_hat = self.decoder(x_hat)
loss = F.mse_loss(x, x_hat)
self.log(f"loss/{stage}", loss, on_step=True)
return loss
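# Lightning also needs an optimizer to train this module. No configure_optimizers appears in the module
# as shown, even though torch.optim is imported; a minimal sketch follows, where the choice of Adam and
# the 1e-3 learning rate are assumptions rather than part of the original code.
def configure_optimizers(self):
"""A minimal sketch of the optimizer hook required by :class:`pl.LightningModule`; Adam is an assumed choice."""
return torch.optim.Adam(self.parameters(), lr=1e-3)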
def _downsampling_conv_filter(self, first: bool, last: bool) -> nn.Module:
"""A single downsampling convolutional filter.
This is an approximately opposite counterpart to :meth:`_upsampling_conv_filter`.
The number of output channels is always :code:`self._number_filters`.
The number of input channels is :code:`self._number_filters`, apart from the first layer which is
:code:`self._number_channels`.
"""
number_filters_in = self._number_channels if first else self._number_filters
return nn.Conv2d(
number_filters_in,
self._number_filters,
_KERNEL_SIZE,
stride=_STRIDE,
padding=_PADDING,
)
def _upsampling_conv_filter(self, first: bool, last: bool) -> nn.Module:
"""A single upsampling filter.
This is an approximately opposite counterpart to :meth:`_downsampling_conv_filter`.
The number of input channels is always :code:`self._number_filters`.
The number of output channels is :code:`self._number_filters`, apart from the final layer which is
:code:`self._number_channels`.
"""
number_filters_out = self._number_channels if last else self._number_filters
return nn.ConvTranspose2d(
self._number_filters,
number_filters_out,
_KERNEL_SIZE,
stride=_STRIDE,
padding=_PADDING,
output_padding=_PADDING,
)
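# The matching transposed convolution exactly doubles the spatial dimension: for nn.ConvTranspose2d the
# output size is (in - 1) * _STRIDE - 2 * _PADDING + _KERNEL_SIZE + output_padding
# = (in - 1) * 2 - 2 + 3 + 1 = 2 * in.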
def _repeat_with_relu(
times: int, create_filter: Callable[[int, int], nn.Module]
) -> Iterable[nn.Module]:
"""Repeats the following: a filter, followed by a ReLU for a certain number of times.
:param times: the number of times to repeat the filter and ReLU.
:param create_filter: creates the filter to use, passing boolean flags indicating if it is the first and last
iteration, respectively.
:returns: the successive filter and ReLU modules, repeated :code:`times` times.
"""
for i in range(times):
yield create_filter(i == 0, i == (times - 1))
yield nn.ReLU(True)
def _calculate_rounds_downsampling(input_size: int) -> int:
"""Calculates how many rounds downsampling is needed so that the input becomes :const:`BASE_IMAGE_SIZE`.
A check occurs that :code:`input_size` fulfills the necessary criteria.
"""
rounds = math.log2(input_size / BASE_IMAGE_SIZE)
if rounds >= 0 and rounds.is_integer():
return int(rounds)
else:
raise ValueError(
f"input_size must be a power of 2 (>= {BASE_IMAGE_SIZE}) but is {input_size}."
)
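# A minimal usage sketch (the batch of random images and the shapes shown here are illustrative, not
# part of the original module):
#
#   model = AutoEncoder(number_channels=3, input_size=32, code_size=16, number_filters=4)
#   images = torch.rand(8, 3, 32, 32)                       # a batch of eight 32x32 RGB images
#   codes = model(images)                                    # encode only -> shape [8, 16]
#   reconstructions = model.forward_encode_decode(images)    # encode + decode -> shape [8, 3, 32, 32]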