From 02d70964e1f6b1a0a29b0c1a49571d47ad72a58a Mon Sep 17 00:00:00 2001 From: Cian-H Date: Mon, 13 May 2024 15:38:50 +0100 Subject: [PATCH] Added image transformations to dataset fetching --- symbolic_nn_tests/dataloader.py | 11 ++++++++--- symbolic_nn_tests/ffnn.py | 9 +++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py index 0c78f14..4ec6dcf 100644 --- a/symbolic_nn_tests/dataloader.py +++ b/symbolic_nn_tests/dataloader.py @@ -1,5 +1,6 @@ from pathlib import Path from torchvision.datasets import Caltech256 +from torchvision.transforms import ToTensor from torch.utils.data import random_split from torch.utils.data import BatchSampler @@ -7,9 +8,13 @@ from torch.utils.data import BatchSampler PROJECT_ROOT = Path(__file__).parent.parent -def get_dataset(split: (float, float, float) = (0.7, 0.1, 0.2), *args, **kwargs): - ds = Caltech256(PROJECT_ROOT / "datasets/", download=True) +def get_dataset( + split: (float, float, float) = (0.7, 0.1, 0.2), + batch_size: int = 128, + drop_last: bool = False, +): + ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor()) train, test, val = ( - BatchSampler(i, *args, **kwargs) for i in random_split(ds, split) + BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split) ) return train, test, val diff --git a/symbolic_nn_tests/ffnn.py b/symbolic_nn_tests/ffnn.py index e69de29..b4da3d7 100644 --- a/symbolic_nn_tests/ffnn.py +++ b/symbolic_nn_tests/ffnn.py @@ -0,0 +1,9 @@ +from .dataloader import get_dataset + + +def main(): + pass + + +if __name__ == "__main__": + main()