From 85c6ba42a135ebdbdf957ee1395911e9c1bc2910 Mon Sep 17 00:00:00 2001
From: Cian Hughes
Date: Mon, 17 Jun 2024 17:31:00 +0100
Subject: [PATCH] Tweaks after re-running expt1

---
 symbolic_nn_tests/dataloader.py                |  8 +++++++-
 symbolic_nn_tests/experiment1/model.py         | 14 +++++++++-----
 symbolic_nn_tests/experiment1/semantic_loss.py |  4 +++-
 symbolic_nn_tests/experiment2/model.py         |  6 +++++-
 4 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py
index 55ab1c4..22c16ec 100644
--- a/symbolic_nn_tests/dataloader.py
+++ b/symbolic_nn_tests/dataloader.py
@@ -14,5 +14,11 @@ def create_dataset(
     **kwargs,
 ):
     ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
-    train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
+    shuffle = kwargs.pop("shuffle", False)
+    shuffle_train = kwargs.pop("shuffle_train", False)
+    to_shuffle = (shuffle or shuffle_train, shuffle, shuffle)
+    train, val, test = (
+        DataLoader(i, shuffle=s, **kwargs)
+        for i, s in zip(random_split(ds, split), to_shuffle)
+    )
     return train, val, test
diff --git a/symbolic_nn_tests/experiment1/model.py b/symbolic_nn_tests/experiment1/model.py
index 5679e4b..45e9080 100644
--- a/symbolic_nn_tests/experiment1/model.py
+++ b/symbolic_nn_tests/experiment1/model.py
@@ -14,8 +14,8 @@ def collate(batch):
     x, y = zip(*batch)
     x = [i[0] for i in x]
     y = [torch.tensor(i) for i in y]
-    x = torch.stack(x).to("cuda")
-    y = torch.tensor(y).to("cuda")
+    x = torch.stack(x)
+    y = torch.tensor(y)
     return x, y
 
 
@@ -27,14 +27,18 @@ def get_singleton_dataset():
     from symbolic_nn_tests.dataloader import create_dataset
 
     return create_dataset(
-        dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
+        dataset=QMNIST,
+        collate_fn=collate,
+        batch_size=128,
+        shuffle_train=True,
+        num_workers=11,
     )
 
 
 def oh_vs_cat_cross_entropy(y_bin, y_cat):
     return nn.functional.cross_entropy(
         y_bin,
-        nn.functional.one_hot(y_cat),
+        nn.functional.one_hot(y_cat, num_classes=10).float(),
     )
 
 
@@ -56,7 +60,7 @@
     train, val, test = get_singleton_dataset()
 
     lmodel = TrainingWrapper(
-        model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
+        model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
     )
     lmodel.configure_optimizers(**kwargs)
     trainer = L.Trainer(max_epochs=20, logger=logger)
diff --git a/symbolic_nn_tests/experiment1/semantic_loss.py b/symbolic_nn_tests/experiment1/semantic_loss.py
index 5c4a633..636b2aa 100644
--- a/symbolic_nn_tests/experiment1/semantic_loss.py
+++ b/symbolic_nn_tests/experiment1/semantic_loss.py
@@ -11,7 +11,9 @@ def create_semantic_cross_entropy(semantic_matrix):
         return ce_loss * semantic_penalty
 
     def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
-        return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
+        return semantic_cross_entropy(
+            input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
+        )
 
     return oh_vs_cat_semantic_cross_entropy
diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py
index 841bf75..1277f1b 100644
--- a/symbolic_nn_tests/experiment2/model.py
+++ b/symbolic_nn_tests/experiment2/model.py
@@ -49,7 +49,11 @@ def get_singleton_dataset():
     from symbolic_nn_tests.experiment2.dataset import collate, pubchem
 
     return create_dataset(
-        dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
+        dataset=pubchem,
+        collate_fn=collate,
+        batch_size=256,
+        shuffle_train=True,
+        num_workers=11,
     )
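
Reviewer note, not part of the patch: a minimal sketch of the shuffle
semantics create_dataset gains here, assuming the module's default
three-way split. The new shuffle_train=True keyword shuffles only the
train loader, while the pre-existing shuffle=True shuffles all three
splits; the experiment1 call is mirrored below.

    # Flag-to-loader mapping after this patch:
    #   shuffle=True        -> train, val, and test all shuffled
    #   shuffle_train=True  -> train shuffled; val and test kept in order
    #   neither             -> nothing shuffled (DataLoader default)
    from torchvision.datasets import QMNIST
    from symbolic_nn_tests.dataloader import create_dataset

    train, val, test = create_dataset(
        dataset=QMNIST,
        batch_size=128,
        shuffle_train=True,  # epoch-to-epoch variation in train batches only
    )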
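Reviewer note, not part of the patch: dropping the hard-coded
.to("cuda") in collate goes together with num_workers=11, since collate
then runs inside DataLoader worker processes, where initializing CUDA is
a common source of errors; Lightning's Trainer moves each batch to its
chosen accelerator instead. A sketch under that assumption follows
(accelerator="auto" is the Trainer default; the commented fit() call is
hypothetical wiring mirroring main()).

    import lightning as L

    # The Trainer owns device placement: batches leave collate on the CPU
    # and are transferred to the selected device at each training step.
    trainer = L.Trainer(max_epochs=20, accelerator="auto")
    # trainer.fit(lmodel, train, val)  # lmodel/train/val as built in main()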
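Reviewer note, not part of the patch: why both loss helpers gain
num_classes=10 and a .float() cast. Without num_classes, one_hot infers
the class count from the largest label present in the batch, so a batch
missing the digit 9 yields fewer than 10 columns and a shape mismatch
against the model's 10 logits; cross_entropy with one-hot
(class-probability) targets also requires a floating-point tensor, not
the int64 tensor one_hot returns.

    import torch
    from torch import nn

    y_cat = torch.tensor([0, 3, 7])                       # batch with no 8 or 9
    inferred = nn.functional.one_hot(y_cat)               # shape (3, 8): classes = max + 1
    fixed = nn.functional.one_hot(y_cat, num_classes=10)  # shape (3, 10): matches the logits

    logits = torch.randn(3, 10)
    loss = nn.functional.cross_entropy(logits, fixed.float())  # float target required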