mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-23 06:32:05 +00:00
12 lines
396 B
Python
12 lines
396 B
Python
from torchvision.datasets import Caltech256
|
|
from torch.utils.data import random_split
|
|
from torch.utils.data import BatchSampler
|
|
|
|
|
|
def get_dataset(
    split: "tuple[float, float, float]" = (0.7, 0.1, 0.2),
    *args,
    root: str = "../datasets/",
    generator=None,
    **kwargs,
):
    """Load Caltech256, randomly split it in three, and wrap each part in a BatchSampler.

    Args:
        split: Fractions for the ``(train, test, val)`` subsets, in that order
            (note: the second entry is *test*, the third is *val*). Passed
            straight to ``random_split``, which requires fractions to sum to 1.
        *args: Forwarded positionally to ``BatchSampler`` after the subset
            (e.g. ``batch_size``, ``drop_last``).
        root: Directory Caltech256 is downloaded to / loaded from. Defaults to
            ``"../datasets/"``, which is resolved relative to the current
            working directory.
        generator: Optional ``torch.Generator`` forwarded to ``random_split``
            so the split can be reproduced. ``None`` uses torch's default RNG.
        **kwargs: Forwarded as keyword arguments to ``BatchSampler``.

    Returns:
        A ``(train, test, val)`` tuple of ``BatchSampler`` objects.

    Raises:
        ValueError: If ``split`` does not have exactly three entries (checked
            before any download is attempted).
    """
    # Fail before the (large) download rather than after splitting.
    if len(split) != 3:
        raise ValueError(
            f"split must have exactly 3 entries (train, test, val), got {len(split)}"
        )

    ds = Caltech256(root, download=True)

    # Only pass `generator` when given, so the default-RNG behaviour is unchanged.
    split_kwargs = {} if generator is None else {"generator": generator}
    subsets = random_split(ds, split, **split_kwargs)

    # NOTE(review): BatchSampler is documented as wrapping a sampler of indices;
    # here it wraps a Subset, so iterating it yields lists of dataset samples
    # rather than index batches. A DataLoader may have been intended — confirm
    # with callers before changing the return type.
    train, test, val = (BatchSampler(subset, *args, **kwargs) for subset in subsets)
    return train, test, val
|