diff --git a/.gitignore b/.gitignore index 550a371..995bb6b 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ cython_debug/ datasets/ lightning_logs/ logs/ +wandb/ diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index ef2412a..f308fc5 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -4,14 +4,16 @@ LEARNING_RATE = 10e-5 def run_test(loss_func, version): from .model import main as test_model from lightning.pytorch.loggers import TensorBoardLogger + from lightning.pytorch.loggers import WandbLogger - logger = TensorBoardLogger( + tb_logger = TensorBoardLogger( save_dir=".", name="logs/comparison", version=version, ) - test_model(lr=LEARNING_RATE) - # test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) + wandb_logger = WandbLogger(project="MNIST") + logger = [tb_logger, wandb_logger] + test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) def main(): @@ -19,10 +21,10 @@ def main(): from torch import nn run_test(nn.functional.cross_entropy, "cross_entropy") - # run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") - # run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") - # run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") - # run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy") + run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") + run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") + run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") + run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy") if __name__ == "__main__":