Source code for train_autoencoder

"""Trains (or validates etc.) an AutoEncoder model against images using the PyTorch Lightning Command Line Interface.

Please see :class:`cnn.AutoEncoder` for details of the auto-encoder architecture.

Input Arguments
===============

Please see `PyTorch Lightning CLI <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html>`_ for
details on how to use the command line interface.

The help page prints all options::

    python -m anchor_python_training.train_autoencoder -h

As a suggestion, first create a configuration file by saving the output of the following command to e.g. `config.yaml`::

    python -m anchor_python_training.train_autoencoder fit --print_config

and then use this configuration file to perform the training::

    python -m anchor_python_training.train_autoencoder fit --config config.yaml

The outputted model is saved incrementally via checkpoints to a directory `lightning_logs` in the working directory
(unless otherwise configured).

The `predict` subcommand is not currently recommended for productive use.
"""
__author__ = "Owen Feehan"
__copyright__ = "Copyright 2021, Owen Feehan"


import torch
import torchinfo
import os

from anchor_python_training import visualize, data, cnn
from pytorch_lightning.utilities import cli


class _MyLightningCLI(cli.LightningCLI):
    """Command line interface tailored to training the auto-encoder.

    It links arguments shared between the data-module and the model, registers two additional boolean flags, and
    hooks into the moments immediately before and after fitting.
    """

    def add_arguments_to_parser(self, parser):
        """Overrides :class:`cli.LightningCLI`. Link shared arguments and register the additional flags."""
        # The model's channel count is derived from the RGB flag of the data.
        parser.link_arguments(
            "data.rgb", "model.number_channels", compute_fn=_number_channels
        )
        # The data's image size is taken from the model's input size.
        parser.link_arguments("model.input_size", "data.image_size")

        # Both flags are optional booleans that are enabled by default.
        for flag in ("--write.onnx.encoder", "--show.reconstruction"):
            parser.add_argument(flag, default=True, required=False, type=bool)

    def before_fit(self) -> None:
        """Overrides :class:`cli.LightningCLI`. Print a summary of the model, and the model itself, to the console."""
        fit_options = self.config["fit"]
        side_length = fit_options["model"]["input_size"]
        data_options = fit_options["data"]
        batch_size = data_options["batch_size"]
        channels = _number_channels(data_options["rgb"])

        # The summary is computed for an input of: batch, channels, height, width (square images).
        summary_shape = (batch_size, channels, side_length, side_length)

        # NOTE(review): torchinfo.summary may already print by itself when not run in a notebook, in which case
        # the summary appears twice on the console - confirm against the torchinfo `verbose` default.
        model_summary = torchinfo.summary(self.model, summary_shape)
        print(model_summary)
        print(self.model)

    def after_fit(self) -> None:
        """Overrides :class:`cli.LightningCLI`. Optionally export ONNX, and optionally plot reconstructions.

        Each of the two steps occurs only if its respective flag is enabled in the `fit` configuration.
        """
        fit_options = self.config["fit"]

        # Export the model as `encoder.onnx` into the trainer's log directory, using the first training batch as
        # the example input.
        if fit_options["write"]["onnx"]["encoder"]:
            _write_onnx(
                model=self.model,
                directory=self.trainer.log_dir,
                filename_without_extension="encoder",
                example_input=_extract_first_batch_from_loader(
                    self.datamodule.train_data
                ),
            )

        # Plot the first validation batch: the input alongside what the model reconstructs from it.
        if fit_options["show"]["reconstruction"]:
            visualize.plot_reconstruction_on_first_batch(
                self.datamodule.validation_data, self.model.forward_encode_decode
            )


def main():
    """Entry point: construct the command line interface for the auto-encoder.

    The interface is built from :class:`cnn.AutoEncoder` as the model and :class:`data.LoadImagesModule` as the
    data-module.
    """
    _MyLightningCLI(cnn.AutoEncoder, data.LoadImagesModule)


def _write_onnx(
    model: torch.nn.Module,
    directory: str,
    filename_without_extension: str,
    example_input: torch.Tensor,
) -> None:
    """Write a model to the file-system as ONNX.

    :param model: the model to write as ONNX.
    :param directory: the directory to write the model to.
    :param filename_without_extension: the name of the file (without the .onnx extension) to write the model to.
    :param example_input: an example input for the model, as needed to write ONNX.
    """
    path = os.path.join(directory, filename_without_extension + ".onnx")
    torch.onnx.export(model, example_input, path, export_params=True)


def _extract_first_batch_from_loader(
    loader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Extracts the first batch from a data-loader.

    Each item yielded by the loader is unpacked as a pair, of which only the first element is returned.

    :param loader: the data-loader to take the first batch from.
    :returns: the first element of the first pair yielded by the loader.
    """
    first_batch, _ = next(iter(loader))
    return first_batch


def _number_channels(rgb: bool) -> int:
    """Determines the number of channels corresponding to a RGB flag.

    :param rgb: whether the images are RGB.
    :returns: 3 if `rgb` is true, otherwise 1.
    """
    return 3 if rgb else 1


if __name__ == "__main__":
    main()