From c3dba79a0bd2c275cd67cf240bc57753b2d63c0a Mon Sep 17 00:00:00 2001 From: Cian Hughes Date: Wed, 22 May 2024 10:35:14 +0100 Subject: [PATCH] Refactored slightly to allow other experiments --- symbolic_nn_tests/__main__.py | 40 +++++++++++-------- symbolic_nn_tests/{ => qmnist}/model.py | 4 +- .../{ => qmnist}/semantic_loss.py | 0 3 files changed, 26 insertions(+), 18 deletions(-) rename symbolic_nn_tests/{ => qmnist}/model.py (89%) rename symbolic_nn_tests/{ => qmnist}/semantic_loss.py (100%) diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index 3be6a0b..5d28fbe 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -1,8 +1,8 @@ LEARNING_RATE = 10e-5 -def run_test(loss_func, version): - from .model import main as test_model +def qmnist_test(loss_func, version): + from .qmnist.model import main as test_model from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import WandbLogger import wandb @@ -12,26 +12,34 @@ def run_test(loss_func, version): name="logs/comparison", version=version, ) - wandb_logger = WandbLogger( - project="Symbolic_NN_Tests", - name=version, - dir="wandb", - ) - logger = [tb_logger, wandb_logger] + # wandb_logger = WandbLogger( + # project="Symbolic_NN_Tests", + # name=version, + # dir="wandb", + # ) + logger = [ + tb_logger, + ] # wandb_logger] test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) wandb.finish() -def main(): - from . import semantic_loss +def qmnist_experiment(): + from .qmnist import semantic_loss from torch import nn - run_test(nn.functional.cross_entropy, "cross_entropy") - run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") - run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") - run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") - run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy") - run_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy") + qmnist_test(nn.functional.cross_entropy, "cross_entropy") + qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") + qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") + qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") + qmnist_test( + semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy" + ) + qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy") + + +def main(): + qmnist_experiment() if __name__ == "__main__": diff --git a/symbolic_nn_tests/model.py b/symbolic_nn_tests/qmnist/model.py similarity index 89% rename from symbolic_nn_tests/model.py rename to symbolic_nn_tests/qmnist/model.py index 285c33d..b164471 100644 --- a/symbolic_nn_tests/model.py +++ b/symbolic_nn_tests/qmnist/model.py @@ -14,7 +14,7 @@ model = nn.Sequential( def get_singleton_dataset(): from torchvision.datasets import QMNIST - from .dataloader import get_dataset + from symbolic_nn_tests.dataloader import get_dataset return get_dataset(dataset=QMNIST) @@ -22,7 +22,7 @@ def get_singleton_dataset(): def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs): import lightning as L - from .train import TrainingWrapper + from symbolic_nn_tests.train import TrainingWrapper if logger is None: from lightning.pytorch.loggers import TensorBoardLogger diff --git a/symbolic_nn_tests/semantic_loss.py b/symbolic_nn_tests/qmnist/semantic_loss.py similarity index 100% rename from symbolic_nn_tests/semantic_loss.py rename to symbolic_nn_tests/qmnist/semantic_loss.py