Source code for data.load_images_module

"""A Lightning data-module for loading images recursively from a file-system directory."""
from .load_images import ImageLoader
import pytorch_lightning as pl
from typing import Optional


class LoadImagesModule(pl.LightningDataModule):
    """A Lightning data-module for loading images recursively from a file-system directory."""

    def __init__(
        self,
        image_directory: str,
        image_size: int,
        extension: str = "jpg",
        rgb: bool = True,
        batch_size: int = 16,
        num_workers: int = 1,
        ratio_validation: float = 0.3,
        ratio_test: float = 0.2,
    ):
        """Creates the module with necessary parameterization.

        :param image_directory: the image directory to load images recursively from.
        :param image_size: the size (height and width) to resize all images to.
        :param extension: the extension (without a leading period) that all image files must match.
        :param rgb: when true, images are always loaded as RGB. when false, they are loaded as grayscale.
        :param batch_size: how many images should be in a batch.
        :param num_workers: the number of workers for the data-loader.
        :param ratio_validation: a number between 0 and 1 determining linearly how many elements belong in
            the validation set e.g. 0.4 would try and place approximately 40% of elements into the
            validation set.
        :param ratio_test: a number between 0 and 1 determining linearly how many elements belong in the
            test set e.g. 0.2 would try and place approximately 20% of elements into the test set.
        """
        super().__init__()
        # Only the loader is configured here; no images are read until setup() is called.
        self._loader = ImageLoader(
            image_directory,
            [image_size, image_size],
            extension,
            rgb,
            batch_size,
            num_workers,
        )
        self._ratio_validation = ratio_validation
        self._ratio_test = ratio_test
        # Populated by setup(); each stays None until the corresponding stage has been set up.
        self.predict_data = None
        self.train_data = None
        self.validation_data = None
        self.test_data = None

    def prepare_data(self) -> None:
        """Overrides :class:`pl.LightningDataModule`."""
        pass

    def setup(self, stage: Optional[pl.trainer.states.TrainerFn] = None):
        """Overrides :class:`pl.LightningDataModule`.

        Loads images through the :class:`ImageLoader` configured in the constructor.

        :param stage: the trainer stage being set up. For the predicting stage all images are loaded as a
            single set; for any other stage (or ``None``) images are split into training, validation and
            test sets. The split is only performed once and reused by later stages.
        """
        if stage == pl.trainer.states.TrainerFn.PREDICTING:
            # Prediction uses every image as one set.
            self.predict_data = self._loader.load_images()
        elif self.train_data is None:
            # Lightning calls setup() once per stage (fit, validate, test, ...). Splitting only once means
            # later stages reuse the same partition instead of re-reading the directory and -- if the split
            # is randomised (NOTE(review): depends on ImageLoader, confirm) -- producing a test set that
            # overlaps the images already used for training.
            (
                self.train_data,
                self.validation_data,
                self.test_data,
            ) = self._loader.load_images_split_three(
                self._ratio_validation,
                self._ratio_test,
            )

    def train_dataloader(self):
        """Overrides :class:`pl.LightningDataModule`."""
        return self.train_data

    def val_dataloader(self):
        """Overrides :class:`pl.LightningDataModule`."""
        return self.validation_data

    def test_dataloader(self):
        """Overrides :class:`pl.LightningDataModule`."""
        return self.test_data

    def predict_dataloader(self):
        """Overrides :class:`pl.LightningDataModule`."""
        return self.predict_data