Refactored to lay groundwork for TUI

2024-05-23 09:50:01 +01:00
parent 02b9e43f4d
commit 9536beea04

@@ -1,54 +1,25 @@
-LEARNING_RATE = 10e-5
+from typing import Optional, Iterable
+import experiment1
-def qmnist_test(loss_func, version, tensorboard=True, wandb=True):
-    from .experiment_1.model import main as test_model
-    logger = []
-    if tensorboard:
-        from lightning.pytorch.loggers import TensorBoardLogger
-        tb_logger = TensorBoardLogger(
-            save_dir=".",
-            name="logs/comparison",
-            version=version,
-        )
-        logger.append(tb_logger)
-    if wandb:
-        import wandb as _wandb
-        from lightning.pytorch.loggers import WandbLogger
-        wandb_logger = WandbLogger(
-            project="Symbolic_NN_Tests",
-            name=version,
-            dir="wandb",
-        )
-        logger.append(wandb_logger)
-    test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
-    if wandb:
-        _wandb.finish()
+EXPERIMENTS = (experiment1,)
-def qmnist_experiment():
-    from .experiment_1 import semantic_loss
-    from torch import nn
+def main(
+    experiments: Optional[int | Iterable[int]] = None,
+    tensorboard: bool = True,
+    wandb: bool = True,
+):
+    if experiments is None:
+        experiments = range(1, len(EXPERIMENTS) + 1)
+    elif not isinstance(experiments, Iterable):
+        experiments = (experiments,)
-    qmnist_test(nn.functional.cross_entropy, "cross_entropy")
-    qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
-    qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
-    qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
-    qmnist_test(
-        semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy"
-    )
-    qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy")
+    experiment_indeces = (i - 1 for i in experiments)
+    experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
-def main():
-    qmnist_experiment()
+    for experiment in experiment_funcs:
+        experiment(tensorboard=tensorboard, wandb=wandb)
 if __name__ == "__main__":
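
After this change, experiment selection is data-driven: main() looks experiments up in the EXPERIMENTS tuple and calls each module's run(tensorboard=..., wandb=...), which is the hook a future TUI (or plain CLI) can drive. A minimal sketch of such a caller, assuming the refactored file is importable as a module named runner (the module path is an assumption; only the main() signature comes from the diff above):

# Hypothetical caller for the refactored entry point.
# Assumption: the file changed above is importable as `runner`; the real
# package/module name is not shown in this diff.
import runner

# Run only experiment 1, keep TensorBoard logging, skip wandb.
# main() wraps the bare int into a one-element tuple before dispatching.
runner.main(experiments=1, wandb=False)

# Run every registered experiment: experiments=None expands to
# range(1, len(EXPERIMENTS) + 1) inside main().
runner.main()

Registering a second experiment would then only mean appending its module to EXPERIMENTS; the dispatch loop itself does not change.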