Made loggers togglable from top level

Commit: 02b9e43f4d
Parent: 1be0fa1020
Date:   2024-05-23 09:21:48 +01:00


@@ -1,25 +1,36 @@
 LEARNING_RATE = 10e-5
 
 
-def qmnist_test(loss_func, version):
+def qmnist_test(loss_func, version, tensorboard=True, wandb=True):
     from .experiment_1.model import main as test_model
-    from lightning.pytorch.loggers import TensorBoardLogger
-    from lightning.pytorch.loggers import WandbLogger
-    import wandb
 
-    tb_logger = TensorBoardLogger(
-        save_dir=".",
-        name="logs/comparison",
-        version=version,
-    )
-    wandb_logger = WandbLogger(
-        project="Symbolic_NN_Tests",
-        name=version,
-        dir="wandb",
-    )
-    logger = [tb_logger, wandb_logger]
+    logger = []
+
+    if tensorboard:
+        from lightning.pytorch.loggers import TensorBoardLogger
+
+        tb_logger = TensorBoardLogger(
+            save_dir=".",
+            name="logs/comparison",
+            version=version,
+        )
+        logger.append(tb_logger)
+
+    if wandb:
+        import wandb as _wandb
+        from lightning.pytorch.loggers import WandbLogger
+
+        wandb_logger = WandbLogger(
+            project="Symbolic_NN_Tests",
+            name=version,
+            dir="wandb",
+        )
+        logger.append(wandb_logger)
+
     test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
-    wandb.finish()
+
+    if wandb:
+        _wandb.finish()
 
 
 def qmnist_experiment():
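With the toggles exposed as keyword arguments, callers can disable either logging backend per run without touching the function body. A minimal usage sketch, assuming a stock torch loss such as F.cross_entropy and illustrative version strings (neither appears in this commit):

    import torch.nn.functional as F

    # TensorBoard only: the wandb run is never created (or finished).
    qmnist_test(F.cross_entropy, version="baseline", wandb=False)

    # Both loggers stay enabled by default.
    qmnist_test(F.cross_entropy, version="baseline-full")

One detail worth noting: the new wandb boolean parameter shadows the wandb package name inside the function, which is why the import is aliased to _wandb so that _wandb.finish() can still close the run after training.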