mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Updated dataloader to default to tensor transform
This commit is contained in:
@@ -12,8 +12,11 @@ def get_dataset(
|
|||||||
split: (float, float, float) = (0.7, 0.1, 0.2),
|
split: (float, float, float) = (0.7, 0.1, 0.2),
|
||||||
batch_size: int = 128,
|
batch_size: int = 128,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, transform=ToTensor())
|
_kwargs = {"transform": ToTensor()}
|
||||||
|
_kwargs.update(kwargs)
|
||||||
|
ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
|
||||||
train, test, val = (
|
train, test, val = (
|
||||||
BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
|
BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user