diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index f35f926..fcf65ad 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -1,25 +1,36 @@ LEARNING_RATE = 10e-5 -def qmnist_test(loss_func, version): +def qmnist_test(loss_func, version, tensorboard=True, wandb=True): from .experiment_1.model import main as test_model - from lightning.pytorch.loggers import TensorBoardLogger - from lightning.pytorch.loggers import WandbLogger - import wandb - tb_logger = TensorBoardLogger( - save_dir=".", - name="logs/comparison", - version=version, - ) - wandb_logger = WandbLogger( - project="Symbolic_NN_Tests", - name=version, - dir="wandb", - ) - logger = [tb_logger, wandb_logger] + logger = [] + + if tensorboard: + from lightning.pytorch.loggers import TensorBoardLogger + + tb_logger = TensorBoardLogger( + save_dir=".", + name="logs/comparison", + version=version, + ) + logger.append(tb_logger) + + if wandb: + import wandb as _wandb + from lightning.pytorch.loggers import WandbLogger + + wandb_logger = WandbLogger( + project="Symbolic_NN_Tests", + name=version, + dir="wandb", + ) + logger.append(wandb_logger) + test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) - wandb.finish() + + if wandb: + _wandb.finish() def qmnist_experiment():