Tweaaks after re-running expt1

This commit is contained in:
2024-06-17 17:31:00 +01:00
parent c6a1b7207a
commit 85c6ba42a1
4 changed files with 24 additions and 8 deletions

View File

@@ -14,5 +14,11 @@ def create_dataset(
**kwargs,
):
ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
shuffle = kwargs.pop("shuffle", False)
shuffle_train = kwargs.pop("shuffle_train", False)
to_shuffle = (shuffle or shuffle_train, shuffle, shuffle)
train, val, test = (
DataLoader(i, shuffle=s, **kwargs)
for i, s in zip(random_split(ds, split), to_shuffle)
)
return train, val, test

View File

@@ -14,8 +14,8 @@ def collate(batch):
x, y = zip(*batch)
x = [i[0] for i in x]
y = [torch.tensor(i) for i in y]
x = torch.stack(x).to("cuda")
y = torch.tensor(y).to("cuda")
x = torch.stack(x)
y = torch.tensor(y)
return x, y
@@ -27,14 +27,18 @@ def get_singleton_dataset():
from symbolic_nn_tests.dataloader import create_dataset
return create_dataset(
dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
dataset=QMNIST,
collate_fn=collate,
batch_size=128,
shuffle_train=True,
num_workers=11,
)
def oh_vs_cat_cross_entropy(y_bin, y_cat):
return nn.functional.cross_entropy(
y_bin,
nn.functional.one_hot(y_cat),
nn.functional.one_hot(y_cat, num_classes=10).float(),
)
@@ -56,7 +60,7 @@ def main(
train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(
model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
)
lmodel.configure_optimizers(**kwargs)
trainer = L.Trainer(max_epochs=20, logger=logger)

View File

@@ -11,7 +11,9 @@ def create_semantic_cross_entropy(semantic_matrix):
return ce_loss * semantic_penalty
def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
return semantic_cross_entropy(
input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
)
return oh_vs_cat_semantic_cross_entropy

View File

@@ -49,7 +49,11 @@ def get_singleton_dataset():
from symbolic_nn_tests.experiment2.dataset import collate, pubchem
return create_dataset(
dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
dataset=pubchem,
collate_fn=collate,
batch_size=256,
shuffle_train=True,
num_workers=11,
)