Refactored slightly to allow other experiments

This commit is contained in:
2024-05-22 10:35:14 +01:00
parent ae14a1d7c0
commit c3dba79a0b
3 changed files with 26 additions and 18 deletions

View File

@@ -1,8 +1,8 @@
LEARNING_RATE = 10e-5 LEARNING_RATE = 10e-5
def run_test(loss_func, version): def qmnist_test(loss_func, version):
from .model import main as test_model from .qmnist.model import main as test_model
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loggers import WandbLogger
import wandb import wandb
@@ -12,26 +12,34 @@ def run_test(loss_func, version):
name="logs/comparison", name="logs/comparison",
version=version, version=version,
) )
wandb_logger = WandbLogger( # wandb_logger = WandbLogger(
project="Symbolic_NN_Tests", # project="Symbolic_NN_Tests",
name=version, # name=version,
dir="wandb", # dir="wandb",
) # )
logger = [tb_logger, wandb_logger] logger = [
tb_logger,
] # wandb_logger]
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
wandb.finish() wandb.finish()
def main(): def qmnist_experiment():
from . import semantic_loss from .qmnist import semantic_loss
from torch import nn from torch import nn
run_test(nn.functional.cross_entropy, "cross_entropy") qmnist_test(nn.functional.cross_entropy, "cross_entropy")
run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy") qmnist_test(
run_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy") semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy"
)
qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy")
def main():
qmnist_experiment()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -14,7 +14,7 @@ model = nn.Sequential(
def get_singleton_dataset(): def get_singleton_dataset():
from torchvision.datasets import QMNIST from torchvision.datasets import QMNIST
from .dataloader import get_dataset from symbolic_nn_tests.dataloader import get_dataset
return get_dataset(dataset=QMNIST) return get_dataset(dataset=QMNIST)
@@ -22,7 +22,7 @@ def get_singleton_dataset():
def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs): def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
import lightning as L import lightning as L
from .train import TrainingWrapper from symbolic_nn_tests.train import TrainingWrapper
if logger is None: if logger is None:
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger