diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py index 4ec6dcf..1f2ce0f 100644 --- a/symbolic_nn_tests/dataloader.py +++ b/symbolic_nn_tests/dataloader.py @@ -12,8 +12,11 @@ def get_dataset( split: (float, float, float) = (0.7, 0.1, 0.2), batch_size: int = 128, drop_last: bool = False, + **kwargs, ): - ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor()) + _kwargs = {"transform": ToTensor()} + _kwargs.update(kwargs) + ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, **_kwargs) train, test, val = ( BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split) )