Updated dataloader to default to tensor transform

This commit is contained in:
2024-05-13 15:55:54 +01:00
parent 02d70964e1
commit 63bea4d355

View File

@@ -12,8 +12,11 @@ def get_dataset(
split: (float, float, float) = (0.7, 0.1, 0.2),
batch_size: int = 128,
drop_last: bool = False,
**kwargs,
):
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor())
_kwargs = {"transform": ToTensor()}
_kwargs.update(kwargs)
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
train, test, val = (
BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
)