Fixed wandb logging mistake

commit 8a87e864fb
parent 4212c543f8
2024-05-15 18:13:28 +01:00

2 changed files with 8 additions and 2 deletions


@@ -5,15 +5,21 @@ def run_test(loss_func, version):
     from .model import main as test_model
     from lightning.pytorch.loggers import TensorBoardLogger
     from lightning.pytorch.loggers import WandbLogger
+    import wandb

     tb_logger = TensorBoardLogger(
         save_dir=".",
         name="logs/comparison",
         version=version,
     )
-    wandb_logger = WandbLogger(project="Semantic_Loss_Tests")
+    wandb_logger = WandbLogger(
+        project="Symbolic_NN_Tests",
+        name=version,
+        dir="wandb",
+    )
     logger = [tb_logger, wandb_logger]
     test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+    wandb.finish()


 def main():

@@ -32,7 +32,7 @@ def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
     train, val, test = get_singleton_dataset()
     lmodel = TrainingWrapper(model, loss_func=loss_func)
     lmodel.configure_optimizers(**kwargs)
-    trainer = L.Trainer(max_epochs=20, logger=logger)
+    trainer = L.Trainer(max_epochs=1, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)

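
The max_epochs drop from 20 to 1 is separate from the logging fix itself; plausibly it keeps these comparison runs short while the logging setup is verified. Note also how lr=LEARNING_RATE from run_test() reaches the optimizer: main() declares no lr parameter, so it lands in **kwargs and is forwarded to configure_optimizers(). A hedged usage sketch, with illustrative argument values:

    import torch.nn as nn

    # logger may be None, a single logger, or a list as built in run_test()
    main(loss_func=nn.functional.cross_entropy, logger=None, lr=1e-3)
    # lr is not a named parameter of main(); it arrives via **kwargs and is
    # passed straight through to lmodel.configure_optimizers(**kwargs)
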