cnn
CNN models.
Package Contents
Classes
AutoEncoder – An AutoEncoder model based on CNNs that incrementally downsample the input to a flat code, and CNNs that then upsample it.
- class cnn.AutoEncoder(number_channels: int = 3, input_size: int = 32, code_size: int = 16, number_filters: int = 4)
Bases: pytorch_lightning.LightningModule
An AutoEncoder model based on CNNs that incrementally downsample the input to a flat code, and CNNs that then upsample it.
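To make "incrementally downsampling to a flat code" concrete, the framework-free sketch below computes the feature-map shapes such an encoder and decoder could produce with the default constructor arguments. The layer choices are assumptions and are not taken from cnn.AutoEncoder: stride-2 convolutions with kernel 3 and padding 1 (halving the spatial size), a filter count that starts at number_filters and doubles per layer, stride-2 transposed convolutions with kernel 4 and padding 1 (doubling it), and an input_size that is a power of two. The helper names conv_out, deconv_out, encoder_shapes and decoder_sizes are hypothetical; the size formulas are the standard ones for torch.nn.Conv2d and torch.nn.ConvTranspose2d with dilation 1 and no output padding.

```python
# Hypothetical shape arithmetic for a CNN autoencoder; the layer
# hyperparameters below are assumptions, not cnn.AutoEncoder's actual layers.


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2-D transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def encoder_shapes(number_channels: int = 3, input_size: int = 32, number_filters: int = 4):
    """(channels, height, width) after each assumed downsampling layer, down to 1x1."""
    shapes = [(number_channels, input_size, input_size)]
    channels, size = number_filters, input_size
    while size > 1:
        size = conv_out(size)
        shapes.append((channels, size, size))
        channels *= 2  # assumption: the filter count doubles at every layer
    return shapes


def decoder_sizes(start: int = 1, target: int = 32):
    """Spatial sizes produced by the assumed upsampling layers until `target` is reached."""
    sizes = [start]
    while sizes[-1] < target:
        sizes.append(deconv_out(sizes[-1]))
    return sizes


print(encoder_shapes())
# [(3, 32, 32), (4, 16, 16), (8, 8, 8), (16, 4, 4), (32, 2, 2), (64, 1, 1)]
print(decoder_sizes())
# [1, 2, 4, 8, 16, 32]
```

Flattening the final (64, 1, 1) map gives 64 values; under these assumptions a linear layer would then project them to code_size=16, and the decoder would reverse that projection before the 1 to 32 upsampling traced by decoder_sizes().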
- forward(self, activation: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor
Overrides the forward() method of pl.LightningModule.
- forward_encode_decode(self, input: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor
Performs both the encode and the decode step on a (batched) input.
- Parameters
input – the input tensor.
- Returns
the tensor after encoding and decoding.
- predict_step(self, batch, batch_idx: int, dataloader_idx: int = None)
Step function called during predict(). By default, it calls forward(). Override it to add any processing logic.
The predict_step() method is used to scale inference across multiple devices.
To prevent an out-of-memory (OOM) error, you can use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of each epoch.
Use BasePredictionWriter when running with a spawn-based accelerator, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(tpu_cores=8), because predictions won't be returned in that case.
Example:

    from pytorch_lightning import LightningModule, Trainer

    class MyModel(LightningModule):
        # Called by trainer.predict() once per batch; assumes forward() is also defined.
        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            return self(batch)  # same as the default: run forward() on the batch

    dm = ...  # data module supplying the prediction batches
    model = MyModel()
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(model, dm)
- Parameters
batch – the current batch.
batch_idx – index of the current batch.
dataloader_idx – index of the current dataloader.
- Returns
the predicted output.
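The predict flow described for predict_step() can be sketched without any framework. ToyModel, run_predict and make_json_writer below are hypothetical names, not Lightning APIs. The sketch shows two things from the description: predict_step() defaulting to a call of forward(), and the choice between collecting every batch's output in memory (roughly what trainer.predict() returns for a single dataloader) or handing each output to a per-batch writer so nothing accumulates, which is the idea behind BasePredictionWriter.

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical stand-in for a LightningModule; not cnn.AutoEncoder."""

    def forward(self, batch):
        # Placeholder "inference": double every value in the batch.
        return [2 * x for x in batch]

    def __call__(self, batch):
        return self.forward(batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # Mirrors the documented default: predict_step() just calls forward().
        return self(batch)


def run_predict(model, batches, writer=None):
    """Sketch of a single-dataloader predict loop.

    Without a writer, every batch output is kept in memory and returned.
    With a writer, each output is passed on immediately and not kept,
    so memory use stays flat (an empty list is returned).
    """
    predictions = []
    for batch_idx, batch in enumerate(batches):
        output = model.predict_step(batch, batch_idx)
        if writer is None:
            predictions.append(output)
        else:
            writer(output, batch_idx)
    return predictions


def make_json_writer(output_dir):
    """Hypothetical per-batch writer that stores each output as a JSON file."""

    def write(output, batch_idx):
        path = os.path.join(output_dir, f"predictions_{batch_idx}.json")
        with open(path, "w") as f:
            json.dump(output, f)

    return write


batches = [[1, 2], [3, 4]]
print(run_predict(ToyModel(), batches))  # [[2, 4], [6, 8]]

with tempfile.TemporaryDirectory() as out_dir:
    run_predict(ToyModel(), batches, writer=make_json_writer(out_dir))
    print(sorted(os.listdir(out_dir)))  # ['predictions_0.json', 'predictions_1.json']
```

In real Lightning code, the writer role is played by a BasePredictionWriter subclass passed to the Trainer's callbacks; check the Lightning documentation for the exact hook signatures of the installed version.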
- configure_optimizers(self)
Overrides the configure_optimizers() method of pl.LightningModule. This docstring replaces the parent's docstring, which contains errors.