mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-23 06:32:05 +00:00
Tweaaks after re-running expt1
This commit is contained in:
@@ -14,5 +14,11 @@ def create_dataset(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
|
ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
|
||||||
train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
|
shuffle = kwargs.pop("shuffle", False)
|
||||||
|
shuffle_train = kwargs.pop("shuffle_train", False)
|
||||||
|
to_shuffle = (shuffle or shuffle_train, shuffle, shuffle)
|
||||||
|
train, val, test = (
|
||||||
|
DataLoader(i, shuffle=s, **kwargs)
|
||||||
|
for i, s in zip(random_split(ds, split), to_shuffle)
|
||||||
|
)
|
||||||
return train, val, test
|
return train, val, test
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ def collate(batch):
|
|||||||
x, y = zip(*batch)
|
x, y = zip(*batch)
|
||||||
x = [i[0] for i in x]
|
x = [i[0] for i in x]
|
||||||
y = [torch.tensor(i) for i in y]
|
y = [torch.tensor(i) for i in y]
|
||||||
x = torch.stack(x).to("cuda")
|
x = torch.stack(x)
|
||||||
y = torch.tensor(y).to("cuda")
|
y = torch.tensor(y)
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
@@ -27,14 +27,18 @@ def get_singleton_dataset():
|
|||||||
from symbolic_nn_tests.dataloader import create_dataset
|
from symbolic_nn_tests.dataloader import create_dataset
|
||||||
|
|
||||||
return create_dataset(
|
return create_dataset(
|
||||||
dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
|
dataset=QMNIST,
|
||||||
|
collate_fn=collate,
|
||||||
|
batch_size=128,
|
||||||
|
shuffle_train=True,
|
||||||
|
num_workers=11,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def oh_vs_cat_cross_entropy(y_bin, y_cat):
|
def oh_vs_cat_cross_entropy(y_bin, y_cat):
|
||||||
return nn.functional.cross_entropy(
|
return nn.functional.cross_entropy(
|
||||||
y_bin,
|
y_bin,
|
||||||
nn.functional.one_hot(y_cat),
|
nn.functional.one_hot(y_cat, num_classes=10).float(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,7 +60,7 @@ def main(
|
|||||||
|
|
||||||
train, val, test = get_singleton_dataset()
|
train, val, test = get_singleton_dataset()
|
||||||
lmodel = TrainingWrapper(
|
lmodel = TrainingWrapper(
|
||||||
model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
|
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
||||||
)
|
)
|
||||||
lmodel.configure_optimizers(**kwargs)
|
lmodel.configure_optimizers(**kwargs)
|
||||||
trainer = L.Trainer(max_epochs=20, logger=logger)
|
trainer = L.Trainer(max_epochs=20, logger=logger)
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ def create_semantic_cross_entropy(semantic_matrix):
|
|||||||
return ce_loss * semantic_penalty
|
return ce_loss * semantic_penalty
|
||||||
|
|
||||||
def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
|
def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
|
||||||
return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
|
return semantic_cross_entropy(
|
||||||
|
input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
|
||||||
|
)
|
||||||
|
|
||||||
return oh_vs_cat_semantic_cross_entropy
|
return oh_vs_cat_semantic_cross_entropy
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,11 @@ def get_singleton_dataset():
|
|||||||
from symbolic_nn_tests.experiment2.dataset import collate, pubchem
|
from symbolic_nn_tests.experiment2.dataset import collate, pubchem
|
||||||
|
|
||||||
return create_dataset(
|
return create_dataset(
|
||||||
dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
|
dataset=pubchem,
|
||||||
|
collate_fn=collate,
|
||||||
|
batch_size=256,
|
||||||
|
shuffle_train=True,
|
||||||
|
num_workers=11,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user