"""Trains (or validates, etc.) an AutoEncoder model on images using the PyTorch Lightning Command Line Interface.

Please see :class:`cnn.AutoEncoder` for details of the auto-encoder architecture.

Input Arguments
===============

Please see `PyTorch Lightning CLI <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html>`_ for
details on how to use the command line interface.

The help page prints all options::

    python -m anchor_python_training.train_autoencoder -h

As a suggestion, first create a configuration file by saving the printed configuration to e.g. `config.yaml`::

    python -m anchor_python_training.train_autoencoder fit --print_config > config.yaml

and then use this configuration file to perform the training::

    python -m anchor_python_training.train_autoencoder fit --config config.yaml

The trained model is saved incrementally via checkpoints to a directory `lightning_logs` in the working directory
(unless otherwise configured).

The `predict` subcommand is not currently recommended for production use.
"""
# Module authorship and copyright metadata.
__author__ = "Owen Feehan"
__copyright__ = "Copyright 2021, Owen Feehan"
import torch
import torchinfo
import os
from anchor_python_training import visualize, data, cnn
from pytorch_lightning.utilities import cli
class _MyLightningCLI(cli.LightningCLI):
    """Customized CLI implementation for training an auto-encoder.

    Compared to :class:`cli.LightningCLI` it:

    * derives ``model.number_channels`` from ``data.rgb`` and ``data.image_size`` from
      ``model.input_size``, so each value only needs to be specified once.
    * adds the ``--write.onnx.encoder`` and ``--show.reconstruction`` flags, which control
      the post-training actions performed in :meth:`after_fit`.
    """

    def add_arguments_to_parser(self, parser):
        """Overrides :class:`cli.LightningCLI`. Link related arguments and add extra flags."""
        # The number of model input channels follows from whether the data is RGB.
        parser.link_arguments(
            "data.rgb", "model.number_channels", compute_fn=_number_channels
        )
        # Derive the data-module's image size from the model's input size so they cannot disagree.
        parser.link_arguments("model.input_size", "data.image_size")
        parser.add_argument(
            "--write.onnx.encoder", default=True, required=False, type=bool
        )
        parser.add_argument(
            "--show.reconstruction", default=True, required=False, type=bool
        )

    def before_fit(self) -> None:
        """Overrides :class:`cli.LightningCLI`. Print model architecture details to the console."""
        config_fit = self.config["fit"]
        input_size = config_fit["model"]["input_size"]
        batch_size = config_fit["data"]["batch_size"]
        rgb = config_fit["data"]["rgb"]
        # verbose=0 stops torchinfo from printing the summary itself (its default when not in an
        # interactive session), so the summary is printed exactly once, here.
        model_summary = torchinfo.summary(
            self.model,
            (batch_size, _number_channels(rgb), input_size, input_size),
            verbose=0,
        )
        print(model_summary)
        print(self.model)

    def after_fit(self) -> None:
        """Overrides :class:`cli.LightningCLI`. Optionally export the model as ONNX and plot reconstructions.

        * If ``--write.onnx.encoder`` is set, the model is written as ``encoder.onnx`` into the
          trainer's log directory, using the first training batch as the example input.
        * If ``--show.reconstruction`` is set, the first validation batch is plotted:
          reconstructed against original.
        """
        config_fit = self.config["fit"]

        # Write the encoder model as ONNX.
        if config_fit["write"]["onnx"]["encoder"]:
            # NOTE(review): the whole ``self.model`` is exported under the name "encoder";
            # confirm that the AutoEncoder's ``forward`` performs only the encoding step.
            _write_onnx(
                self.model,
                self.trainer.log_dir,
                "encoder",
                _extract_first_batch_from_loader(self.datamodule.train_data),
            )

        # Show examples of the first batch, input versus reconstructed.
        if config_fit["show"]["reconstruction"]:
            visualize.plot_reconstruction_on_first_batch(
                self.datamodule.validation_data, self.model.forward_encode_decode
            )
def main():
    """Entry point: run the customized Lightning CLI for :class:`cnn.AutoEncoder` with :class:`data.LoadImagesModule`."""
    _MyLightningCLI(cnn.AutoEncoder, data.LoadImagesModule)
def _write_onnx(
    model: torch.nn.Module,
    directory: str,
    filename_without_extension: str,
    example_input: torch.Tensor,
) -> None:
    """Write a model to the file-system as ONNX.

    :param model: the model to write as ONNX.
    :param directory: the directory to write the model to. It is created if it does not already exist.
    :param filename_without_extension: the name of the file (without the .onnx extension) to write the model to.
    :param example_input: an example input to the model, as needed by :func:`torch.onnx.export` to run the
        model during export.
    """
    # Ensure the destination exists, so the export does not fail on a missing directory.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename_without_extension + ".onnx")
    torch.onnx.export(model, example_input, path, export_params=True)
def _extract_first_batch_from_loader(
    loader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Extracts the inputs of the first batch from a data-loader.

    :param loader: any iterable of batches (typically a data-loader). Each batch is either a
        ``(inputs, labels, ...)`` tuple/list, or the inputs themselves.
    :returns: the inputs of the first batch, i.e. the first element of a tuple/list batch, or
        the batch itself otherwise.
    :raises ValueError: if the loader yields no batches.
    """
    try:
        batch = next(iter(loader))
    except StopIteration:
        # Convert to a clear error rather than leaking StopIteration to the caller.
        raise ValueError("Cannot extract a batch from an empty data-loader.") from None
    # Only unpack container batches; unpacking a tensor would split it along the batch dimension.
    if isinstance(batch, (tuple, list)):
        return batch[0]
    return batch
def _number_channels(rgb: bool) -> int:
    """Map the RGB flag to a channel count: three channels when ``rgb`` is truthy, otherwise one."""
    return 3 if rgb else 1
# Run the CLI when executed as a script, e.g. `python -m anchor_python_training.train_autoencoder`.
if __name__ == "__main__":
    main()