Files
symbolic_nn_tests/symbolic_nn_tests/dataloader.py

12 lines
396 B
Python

from typing import Tuple

from torch.utils.data import BatchSampler
from torch.utils.data import random_split
from torchvision.datasets import Caltech256


def get_dataset(
    split: Tuple[float, float, float] = (0.7, 0.1, 0.2),
    *args,
    root: str = "../datasets/",
    download: bool = True,
    generator=None,
    **kwargs,
):
    """Load Caltech256, split it randomly in three, and wrap each split in a BatchSampler.

    Args:
        split: Fractions of the dataset assigned to the train, test and
            validation subsets, in that order. Passed straight to
            ``torch.utils.data.random_split``, so the fractions must sum to 1.
        *args: Extra positional arguments forwarded to ``BatchSampler`` after
            the subset (e.g. ``batch_size``, ``drop_last``). Because ``split``
            is the first positional parameter, it must be given explicitly
            before these can be passed positionally.
        root: Directory Caltech256 is read from / downloaded into. A relative
            path resolves against the current working directory.
        download: Forwarded to ``Caltech256``; when true the dataset is
            fetched if it is not already present under ``root``.
        generator: Optional ``torch.Generator`` used by ``random_split``.
            Pass a seeded generator for reproducible splits; ``None`` uses
            torch's default global RNG.
        **kwargs: Extra keyword arguments forwarded to ``BatchSampler``.

    Returns:
        A ``(train, test, val)`` tuple of ``BatchSampler`` objects.

    Raises:
        ValueError: If ``split`` does not contain exactly three fractions
            (checked before any download happens), or if ``random_split``
            rejects the fractions.
    """
    # Fail fast: a wrong-length split could otherwise only be detected after
    # the dataset has been downloaded and split.
    if len(split) != 3:
        raise ValueError(
            "split must contain exactly three fractions (train, test, val), "
            f"got {len(split)}"
        )

    ds = Caltech256(root, download=download)

    if generator is None:
        subsets = random_split(ds, split)
    else:
        subsets = random_split(ds, split, generator=generator)

    # NOTE(review): BatchSampler receives the Subset itself as its "sampler",
    # so iterating each result yields lists of dataset samples (images are
    # loaded during iteration), not lists of indices as BatchSampler is
    # normally used. If a batched loader was intended, DataLoader(subset,
    # batch_size=...) is the usual tool; confirm with callers before changing.
    # NOTE(review): results are ordered train/test/val, so the default split
    # gives test 10% and val 20%. The conventional order is train/val/test;
    # confirm which was meant.
    train, test, val = (BatchSampler(subset, *args, **kwargs) for subset in subsets)
    return train, test, val