Mirror of https://github.com/Cian-H/symbolic_nn_tests.git (synced 2025-12-22 14:11:59 +00:00)
Tweaks after re-running expt1
@@ -14,5 +14,11 @@ def create_dataset(
     **kwargs,
 ):
     ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
-    train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
+    shuffle = kwargs.pop("shuffle", False)
+    shuffle_train = kwargs.pop("shuffle_train", False)
+    to_shuffle = (shuffle or shuffle_train, shuffle, shuffle)
+    train, val, test = (
+        DataLoader(i, shuffle=s, **kwargs)
+        for i, s in zip(random_split(ds, split), to_shuffle)
+    )
     return train, val, test
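The hunk above replaces the single all-or-nothing DataLoader expression with per-split shuffle control: `shuffle=True` shuffles every split, while the new `shuffle_train=True` shuffles only the training split, keeping validation and test iteration order deterministic. A minimal sketch of that truth table (the helper name `resolve_shuffling` is hypothetical, not part of the repo):

# Hypothetical helper mirroring `to_shuffle = (shuffle or shuffle_train, shuffle, shuffle)`.
def resolve_shuffling(shuffle=False, shuffle_train=False):
    return {"train": shuffle or shuffle_train, "val": shuffle, "test": shuffle}

assert resolve_shuffling(shuffle_train=True) == {"train": True, "val": False, "test": False}
assert resolve_shuffling(shuffle=True) == {"train": True, "val": True, "test": True}
assert resolve_shuffling() == {"train": False, "val": False, "test": False}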
@@ -14,8 +14,8 @@ def collate(batch):
     x, y = zip(*batch)
     x = [i[0] for i in x]
     y = [torch.tensor(i) for i in y]
-    x = torch.stack(x).to("cuda")
-    y = torch.tensor(y).to("cuda")
+    x = torch.stack(x)
+    y = torch.tensor(y)
     return x, y
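Dropping the hard-coded `.to("cuda")` calls goes hand in hand with the new `num_workers=11`: the collate function now runs in DataLoader worker subprocesses, and PyTorch discourages returning CUDA tensors from workers because sharing CUDA tensors across processes is fragile; Lightning's Trainer also moves batches to the model's device automatically. A hedged sketch of the resulting division of labour, with stand-in tensors in place of a real batch:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for one (x, y) pair as the CPU-only collate now returns it.
x = torch.randn(128, 1, 28, 28)   # hypothetical QMNIST-shaped image batch
y = torch.randint(0, 10, (128,))  # hypothetical integer labels

# The device transfer happens once, in the consuming process, not in a worker.
x, y = x.to(device), y.to(device)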
@@ -27,14 +27,18 @@ def get_singleton_dataset():
     from symbolic_nn_tests.dataloader import create_dataset

     return create_dataset(
-        dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
+        dataset=QMNIST,
+        collate_fn=collate,
+        batch_size=128,
+        shuffle_train=True,
+        num_workers=11,
     )


 def oh_vs_cat_cross_entropy(y_bin, y_cat):
     return nn.functional.cross_entropy(
         y_bin,
-        nn.functional.one_hot(y_cat),
+        nn.functional.one_hot(y_cat, num_classes=10).float(),
     )
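The `one_hot` change fixes two latent bugs. Without `num_classes`, `torch.nn.functional.one_hot` infers the width from the largest label present, so a batch that happens to contain no 9s produces targets narrower than the model's ten logits; and `one_hot` returns integer tensors, while `cross_entropy` expects floating-point targets when they are given as class probabilities. A small self-contained demonstration:

import torch
import torch.nn.functional as F

labels = torch.tensor([0, 3, 7])  # a batch with no label 9
logits = torch.randn(3, 10)       # the model always emits 10 logits

# Inferred width is max(label) + 1 = 8, which would not match the logits.
assert F.one_hot(labels).shape == (3, 8)

# Pinning num_classes and casting to float keeps shape and dtype consistent.
target = F.one_hot(labels, num_classes=10).float()
loss = F.cross_entropy(logits, target)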
@@ -56,7 +60,7 @@ def main(

     train, val, test = get_singleton_dataset()
     lmodel = TrainingWrapper(
-        model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
+        model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
     )
     lmodel.configure_optimizers(**kwargs)
     trainer = L.Trainer(max_epochs=20, logger=logger)
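This one-token hunk fixes a copy-paste slip: the wrapper's test metric had been bound to `val_loss`, so any `test_loss` passed to `main` was silently ignored and test results were computed with the validation criterion.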
@@ -11,7 +11,9 @@ def create_semantic_cross_entropy(semantic_matrix):
         return ce_loss * semantic_penalty

     def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
-        return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
+        return semantic_cross_entropy(
+            input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
+        )

     return oh_vs_cat_semantic_cross_entropy
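This is the same `one_hot` hardening as in `oh_vs_cat_cross_entropy` above, applied to the semantic-loss closure: `num_classes=10` pins the target width to the logit count regardless of which labels a batch contains, and `.float()` gives the wrapped loss the floating-point targets it expects.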
@@ -49,7 +49,11 @@ def get_singleton_dataset():
     from symbolic_nn_tests.experiment2.dataset import collate, pubchem

     return create_dataset(
-        dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
+        dataset=pubchem,
+        collate_fn=collate,
+        batch_size=256,
+        shuffle_train=True,
+        num_workers=11,
     )
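With this final hunk, experiment 2 adopts the same call pattern as experiment 1: shuffle only the training split and let the remaining keyword arguments flow through `**kwargs` into each `DataLoader`. A usage sketch, assuming the package is importable exactly as the diff's own imports suggest:

from symbolic_nn_tests.dataloader import create_dataset
from symbolic_nn_tests.experiment2.dataset import collate, pubchem

train, val, test = create_dataset(
    dataset=pubchem,
    collate_fn=collate,
    batch_size=256,
    shuffle_train=True,  # popped by create_dataset; only train is shuffled
    num_workers=11,      # forwarded through **kwargs to each DataLoader
)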