Added image transformations to dataset fetching

This commit is contained in:
2024-05-13 15:38:50 +01:00
parent 223b064564
commit 02d70964e1
2 changed files with 17 additions and 3 deletions

View File

@@ -1,5 +1,6 @@
from pathlib import Path
from torchvision.datasets import Caltech256
from torchvision.transforms import ToTensor
from torch.utils.data import random_split
from torch.utils.data import BatchSampler
@@ -7,9 +8,13 @@ from torch.utils.data import BatchSampler
PROJECT_ROOT = Path(__file__).parent.parent
def get_dataset(split: (float, float, float) = (0.7, 0.1, 0.2), *args, **kwargs):
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True)
def get_dataset(
split: (float, float, float) = (0.7, 0.1, 0.2),
batch_size: int = 128,
drop_last: bool = False,
):
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor())
train, test, val = (
BatchSampler(i, *args, **kwargs) for i in random_split(ds, split)
BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
)
return train, test, val

View File

@@ -0,0 +1,9 @@
from .dataloader import get_dataset
def main():
pass
if __name__ == "__main__":
main()