cnn

CNN models.

Package Contents

Classes

AutoEncoder

An AutoEncoder model that incrementally downsamples its input with CNN layers to a flat code, then upsamples the code with CNN layers to reconstruct the input.

class cnn.AutoEncoder(number_channels: int = 3, input_size: int = 32, code_size: int = 16, number_filters: int = 4)

Bases: pytorch_lightning.LightningModule

An AutoEncoder model that incrementally downsamples its input with CNN layers to a flat code, then upsamples the code with CNN layers to reconstruct the input.
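As a rough illustration of what "incrementally downsampling" means for the default input_size=32, the sketch below computes the spatial size after each convolution. It assumes, hypothetically, that every downsampling step is a 3x3 convolution with stride 2 and padding 1, and that downsampling stops at a hypothetical min_size; the actual kernel sizes, strides, and number of steps used by AutoEncoder are not documented here.

```python
from typing import List


def conv_output_size(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def downsampling_sizes(input_size: int, min_size: int = 4) -> List[int]:
    """Spatial sizes after each (assumed) stride-2 convolution, stopping at min_size.

    min_size is a hypothetical stopping point chosen for illustration.
    """
    if min_size < 1:
        raise ValueError("min_size must be at least 1")
    sizes = [input_size]
    while sizes[-1] > min_size:
        sizes.append(conv_output_size(sizes[-1]))
    return sizes


# Default input_size=32: each step halves the height and width.
print(downsampling_sizes(32))  # [32, 16, 8, 4]
```

Under these assumptions the last feature map is what would be flattened before being mapped to a code of length code_size.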

forward(self, activation: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor

Overrides pl.LightningModule.forward().

forward_encode_decode(self, input: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor

Encodes and then decodes a batched input.

Parameters

input – the batched input tensor, or a list of tensors.

Returns

the tensor after encoding and decoding.
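A minimal PyTorch sketch of the shape contract of an encode/decode round trip with the default constructor arguments. TinyAutoEncoder, its two stride-2 Conv2d layers, and the mirrored ConvTranspose2d layers are hypothetical choices for illustration, not the actual layers of cnn.AutoEncoder. The sketch accepts only a single tensor (the documented signature also allows a list of tensors) and requires input_size to be divisible by 4.

```python
import torch
from torch import nn


class TinyAutoEncoder(nn.Module):
    """Hypothetical stand-in for cnn.AutoEncoder; not its actual layers."""

    def __init__(self, number_channels: int = 3, input_size: int = 32,
                 code_size: int = 16, number_filters: int = 4) -> None:
        super().__init__()
        reduced = input_size // 4  # two stride-2 convolutions halve H and W twice
        flat = number_filters * 2 * reduced * reduced
        self.encoder = nn.Sequential(
            nn.Conv2d(number_channels, number_filters, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(number_filters, number_filters * 2, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),               # (N, C, H, W) -> (N, C * H * W)
            nn.Linear(flat, code_size),  # flat code
        )
        self.decoder = nn.Sequential(
            nn.Linear(code_size, flat),
            nn.Unflatten(1, (number_filters * 2, reduced, reduced)),
            # Each transposed convolution (kernel 4, stride 2, padding 1) doubles H and W.
            nn.ConvTranspose2d(number_filters * 2, number_filters, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(number_filters, number_channels, 4, stride=2, padding=1),
        )

    def forward_encode_decode(self, input: torch.Tensor) -> torch.Tensor:
        code = self.encoder(input)  # (N, code_size)
        return self.decoder(code)   # (N, number_channels, input_size, input_size)


x = torch.randn(2, 3, 32, 32)
print(TinyAutoEncoder().forward_encode_decode(x).shape)  # torch.Size([2, 3, 32, 32])
```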

training_step(self, batch: List[torch.Tensor], batch_idx: int)

Overrides pl.LightningModule.training_step().

validation_step(self, batch: List[torch.Tensor], batch_idx: int)

Overrides pl.LightningModule.validation_step().

test_step(self, batch: List[torch.Tensor], batch_idx: int)

Overrides pl.LightningModule.test_step().
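The training, validation, and test steps of an autoencoder commonly share one body that compares the reconstruction with the input. The sketch below is hypothetical: it assumes batch[0] holds the images (for example an (image, label) pair from a dataset) and that the loss is the mean squared reconstruction error; the loss actually used by AutoEncoder is not documented here.

```python
from typing import Callable, List

import torch
import torch.nn.functional as F


def reconstruction_loss(encode_decode: Callable[[torch.Tensor], torch.Tensor],
                        batch: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical shared body for training/validation/test steps.

    Assumes batch[0] holds the images and uses mean squared error.
    """
    images = batch[0]
    reconstructed = encode_decode(images)
    return F.mse_loss(reconstructed, images)


# With an identity "model" the reconstruction is perfect, so the loss is zero.
images = torch.rand(4, 3, 32, 32)
labels = torch.zeros(4, dtype=torch.long)
print(reconstruction_loss(lambda x: x, [images, labels]).item())  # 0.0
```

Inside a LightningModule, training_step could return this loss, and validation_step could record it with self.log("val_loss", loss); this is an assumed pattern, not the documented behaviour of AutoEncoder.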

predict_step(self, batch, batch_idx: int, dataloader_idx: int = None)

Step function called during Trainer.predict(). By default, it calls forward(). Override it to add any processing logic.

predict_step() is used to scale inference across multiple devices.

To prevent an out-of-memory (OOM) error, use the BasePredictionWriter callback to write predictions to disk or a database after each batch or at the end of each epoch.

BasePredictionWriter should be used with spawn-based accelerators, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(tpu_cores=8), because predictions won't be returned in those cases.

Example

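from pytorch_lightning import LightningModule, Trainer


class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # Run the model's forward pass on the batch
        return self(batch)


dm = ...  # a LightningDataModule that provides a predict dataloader
model = MyModel()
trainer = Trainer(gpus=2)
predictions = trainer.predict(model, dm)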

Parameters

batch – the current batch.

batch_idx – index of the current batch.

dataloader_idx – index of the current dataloader.

Returns

the predicted output.

configure_optimizers(self)

Overrides pl.LightningModule.configure_optimizers(). This docstring replaces the parent docstring, which contains errors.
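configure_optimizers returns the optimizer that Lightning uses during training. The sketch below is written as a plain function so it runs without a LightningModule; it assumes Adam with a learning rate of 1e-3, and the optimizer actually configured by AutoEncoder is not documented here.

```python
import torch
from torch import nn


def configure_optimizers_sketch(model: nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    """Hypothetical configure_optimizers body: Adam over every model parameter."""
    # Inside a LightningModule this would be `return torch.optim.Adam(self.parameters(), lr=lr)`.
    return torch.optim.Adam(model.parameters(), lr=lr)


optimizer = configure_optimizers_sketch(nn.Linear(16, 16))
print(type(optimizer).__name__)  # Adam
```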