diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index fcf65ad..a798059 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -1,54 +1,25 @@ -LEARNING_RATE = 10e-5 +from typing import Optional, Iterable +import experiment1 -def qmnist_test(loss_func, version, tensorboard=True, wandb=True): - from .experiment_1.model import main as test_model - - logger = [] - - if tensorboard: - from lightning.pytorch.loggers import TensorBoardLogger - - tb_logger = TensorBoardLogger( - save_dir=".", - name="logs/comparison", - version=version, - ) - logger.append(tb_logger) - - if wandb: - import wandb as _wandb - from lightning.pytorch.loggers import WandbLogger - - wandb_logger = WandbLogger( - project="Symbolic_NN_Tests", - name=version, - dir="wandb", - ) - logger.append(wandb_logger) - - test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) - - if wandb: - _wandb.finish() +EXPERIMENTS = (experiment1,) -def qmnist_experiment(): - from .experiment_1 import semantic_loss - from torch import nn +def main( + experiments: Optional[int | Iterable[int]] = None, + tensorboard: bool = True, + wandb: bool = True, +): + if experiments is None: + experiments = range(1, len(EXPERIMENTS) + 1) + elif not isinstance(experiments, Iterable): + experiments = (experiments,) - qmnist_test(nn.functional.cross_entropy, "cross_entropy") - qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") - qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") - qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") - qmnist_test( - semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy" - ) - qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy") + experiment_indeces = (i - 1 for i in experiments) + experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces] - -def main(): - qmnist_experiment() + for experiment in experiment_funcs: + experiment(tensorboard=tensorboard, wandb=wandb) if __name__ == "__main__":