From 63bea4d35577cc5c85c80034abf4b9ca87baac95 Mon Sep 17 00:00:00 2001 From: Cian-H Date: Mon, 13 May 2024 15:55:54 +0100 Subject: [PATCH] Updated dataloader to default to tensor transform --- symbolic_nn_tests/dataloader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py index 4ec6dcf..1f2ce0f 100644 --- a/symbolic_nn_tests/dataloader.py +++ b/symbolic_nn_tests/dataloader.py @@ -12,8 +12,11 @@ def get_dataset( split: (float, float, float) = (0.7, 0.1, 0.2), batch_size: int = 128, drop_last: bool = False, + **kwargs, ): - ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor()) + _kwargs = {"transform": ToTensor()} + _kwargs.update(kwargs) + ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, **_kwargs) train, test, val = ( BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split) )