From 8a87e864fb71bb1733c5acc394a5c2ab819590c8 Mon Sep 17 00:00:00 2001
From: Cian Hughes
Date: Wed, 15 May 2024 18:13:28 +0100
Subject: [PATCH] Fixed wandb logging mistake

---
 symbolic_nn_tests/__main__.py | 8 +++++++-
 symbolic_nn_tests/model.py    | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py
index 1a88eb1..ac6d79b 100644
--- a/symbolic_nn_tests/__main__.py
+++ b/symbolic_nn_tests/__main__.py
@@ -5,15 +5,21 @@ def run_test(loss_func, version):
     from .model import main as test_model
     from lightning.pytorch.loggers import TensorBoardLogger
     from lightning.pytorch.loggers import WandbLogger
+    import wandb
 
     tb_logger = TensorBoardLogger(
         save_dir=".",
         name="logs/comparison",
         version=version,
     )
-    wandb_logger = WandbLogger(project="Semantic_Loss_Tests")
+    wandb_logger = WandbLogger(
+        project="Symbolic_NN_Tests",
+        name=version,
+        dir="wandb",
+    )
     logger = [tb_logger, wandb_logger]
     test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+    wandb.finish()
 
 
 def main():
diff --git a/symbolic_nn_tests/model.py b/symbolic_nn_tests/model.py
index 285c33d..4849bc8 100644
--- a/symbolic_nn_tests/model.py
+++ b/symbolic_nn_tests/model.py
@@ -32,7 +32,7 @@ def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
     train, val, test = get_singleton_dataset()
     lmodel = TrainingWrapper(model, loss_func=loss_func)
     lmodel.configure_optimizers(**kwargs)
-    trainer = L.Trainer(max_epochs=20, logger=logger)
+    trainer = L.Trainer(max_epochs=1, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)
 
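
Note on the wandb.finish() call: run_test() appears to be invoked once per loss
function in the same Python process, and W&B will typically attach a new
WandbLogger to the still-open run from the previous call unless that run is
explicitly finished. A minimal sketch of the pattern this patch establishes
(the loop and loss-function names here are illustrative assumptions, not code
from the repository):

    import wandb
    from lightning.pytorch.loggers import WandbLogger

    for version in ("cross_entropy", "semantic_loss"):
        logger = WandbLogger(project="Symbolic_NN_Tests", name=version, dir="wandb")
        # ... fit and test with this logger ...
        wandb.finish()  # close the active W&B run so the next iteration
                        # starts a fresh run instead of reusing this one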