mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-23 06:32:05 +00:00
Added image transformations to dataset fetching
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torchvision.datasets import Caltech256
|
from torchvision.datasets import Caltech256
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
from torch.utils.data import random_split
|
from torch.utils.data import random_split
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
|
|
||||||
@@ -7,9 +8,13 @@ from torch.utils.data import BatchSampler
|
|||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(split: (float, float, float) = (0.7, 0.1, 0.2), *args, **kwargs):
|
def get_dataset(
|
||||||
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True)
|
split: (float, float, float) = (0.7, 0.1, 0.2),
|
||||||
|
batch_size: int = 128,
|
||||||
|
drop_last: bool = False,
|
||||||
|
):
|
||||||
|
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor())
|
||||||
train, test, val = (
|
train, test, val = (
|
||||||
BatchSampler(i, *args, **kwargs) for i in random_split(ds, split)
|
BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
|
||||||
)
|
)
|
||||||
return train, test, val
|
return train, test, val
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
from .dataloader import get_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user