cnn
CNN models.
Package Contents
Classes
AutoEncoder – An AutoEncoder model, based upon incrementally downsampling CNNs to a flat code, and then upsampling CNNs.
- class cnn.AutoEncoder(number_channels: int = 3, input_size: int = 32, code_size: int = 16, number_filters: int = 4)
Bases: pytorch_lightning.LightningModule

An AutoEncoder model, based upon incrementally downsampling CNNs to a flat code, and then upsampling CNNs.
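The source does not show the layer configuration, but the shape arithmetic implied by the constructor arguments can be sketched. The following is a minimal illustration, assuming (this scheme is an assumption, not taken from the source) that each downsampling stage is a stride-2 convolution that halves the spatial size and doubles the channel count, starting from number_filters, until the feature map is flattened into a code of length code_size:

```python
def encoder_shapes(number_channels=3, input_size=32, number_filters=4, code_size=16):
    """Sketch the per-stage feature-map shapes of a hypothetical stride-2
    downsampling encoder ending in a flat code of length `code_size`.

    Assumption (not from the source): each stage halves the spatial size
    and doubles the channel count, starting from `number_filters`.
    """
    shapes = [(number_channels, input_size, input_size)]  # input image
    channels, size = number_filters, input_size
    while size > 1:
        size //= 2
        shapes.append((channels, size, size))  # after one conv stage
        channels *= 2
    shapes.append((code_size,))  # flattened code
    return shapes

for shape in encoder_shapes():
    print(shape)
```

With the default arguments this walks from (3, 32, 32) down through (4, 16, 16), (8, 8, 8), and so on to a (16,) code; the decoder would mirror these stages with upsampling layers.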
- forward(self, activation: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor
Overrides pl.LightningModule.forward().
- forward_encode_decode(self, input: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor
Performs both the encode and decode steps on a (batched) input.
- Parameters
input – the input tensor.
- Returns
the tensor after encoding and decoding.
- predict_step(self, batch, batch_idx: int, dataloader_idx: int = None)
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multiple devices.

To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or on epoch end.

The BasePredictionWriter should be used while using a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(tpu_cores=8), as predictions won't be returned.

Example
class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(gpus=2)
predictions = trainer.predict(model, dm)
- Args:
batch: Current batch.
batch_idx: Index of the current batch.
dataloader_idx: Index of the current dataloader.
- Return:
Predicted output.
- configure_optimizers(self)
Overrides pl.LightningModule.configure_optimizers(). This docstring replaces the parent docstring, which contains an error.