cnn.autoencoder
A CNN-based autoencoder model.
Module Contents
Classes
An AutoEncoder model that incrementally downsamples the input with CNNs to a flat code, then reconstructs it with upsampling CNNs.
Attributes
The width and height that images are downsampled to before being flattened into a code.
- cnn.autoencoder.BASE_IMAGE_SIZE : int = 16[source]
The width and height that images are downsampled to before being flattened into a code.
- class cnn.autoencoder.AutoEncoder(number_channels: int = 3, input_size: int = 32, code_size: int = 16, number_filters: int = 4)[source]
Bases:
pytorch_lightning.LightningModule
An AutoEncoder model that incrementally downsamples the input with CNNs to a flat code, then reconstructs it with upsampling CNNs.
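Given the constructor defaults above and BASE_IMAGE_SIZE = 16, the encoder's shape arithmetic can be sketched in plain Python. This is illustrative only: the halving-per-stage behaviour and the single flat projection are assumptions about the architecture, not details taken from the source.

```python
import math

# Defaults from the AutoEncoder constructor.
number_channels = 3    # e.g. RGB input
input_size = 32        # input images are input_size x input_size
code_size = 16         # length of the flat code
number_filters = 4     # filters produced by the CNN stages (assumed constant)
BASE_IMAGE_SIZE = 16   # images are downsampled to this width/height

# Assuming each downsampling stage halves the spatial resolution, the
# encoder needs this many stages to go from input_size to BASE_IMAGE_SIZE:
num_stages = int(math.log2(input_size // BASE_IMAGE_SIZE))

# The flattened feature map that would be projected down to the code:
flat_size = number_filters * BASE_IMAGE_SIZE * BASE_IMAGE_SIZE
```

With the defaults, a single downsampling stage (32 → 16) suffices, and the flattened 4 × 16 × 16 feature map is projected to the 16-element code.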
- forward(self, activation: Union[torch.Tensor, List[torch.Tensor]]) torch.Tensor [source]
Overrides pl.LightningModule.forward().
- forward_encode_decode(self, input: Union[torch.Tensor, List[torch.Tensor]]) torch.Tensor [source]
Performs both the encode and decode step on an input (batched).
- Parameters
input – the input tensor.
- Returns
the tensor after encoding and decoding.
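The key contract here is that the reconstruction has the same shape as the input. A toy, torch-free stand-in makes that concrete; the encode/decode bodies below are hypothetical placeholders for illustration, not the model's actual CNN layers:

```python
def encode(image, code_size=16):
    """Toy stand-in: flatten the image and truncate it to a flat code."""
    flat = [px for row in image for px in row]
    return flat[:code_size]

def decode(code, height, width):
    """Toy stand-in: expand a flat code back into a height x width grid."""
    repeats = (height * width) // len(code) + 1
    padded = (code * repeats)[: height * width]
    return [padded[r * width:(r + 1) * width] for r in range(height)]

# Round trip: a 32x32 "image" in, a 32x32 reconstruction out.
image = [[float(c) for c in range(32)] for _ in range(32)]
reconstruction = decode(encode(image), height=32, width=32)
```

In the real model, `forward_encode_decode` performs the same round trip on a batched tensor, with the encoder and decoder implemented as CNN stages.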
- training_step(self, batch: List[torch.Tensor], batch_idx: int)[source]
Overrides pl.LightningModule.training_step().
- validation_step(self, batch: List[torch.Tensor], batch_idx: int)[source]
Overrides pl.LightningModule.validation_step().
- predict_step(self, batch, batch_idx: int, dataloader_idx: Optional[int] = None)[source]
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multiple devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at epoch end.
The BasePredictionWriter should be used while using a spawn-based accelerator. This happens with Trainer(strategy="ddp_spawn") or when training on 8 TPU cores with Trainer(tpu_cores=8), as predictions won't be returned.
Example
    class MyModel(LightningModule):

        def predict_step(self, batch, batch_idx, dataloader_idx):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(model, dm)
- Args:
batch – Current batch
batch_idx – Index of current batch
dataloader_idx – Index of the current dataloader
- Return:
Predicted output