Added wandb logging

This commit is contained in:
2024-05-15 13:37:29 +01:00
parent 6600a79f71
commit 15a57cc229
2 changed files with 10 additions and 7 deletions

1
.gitignore vendored
View File

@@ -164,3 +164,4 @@ cython_debug/
datasets/ datasets/
lightning_logs/ lightning_logs/
logs/ logs/
wandb/

View File

@@ -4,14 +4,16 @@ LEARNING_RATE = 10e-5
def run_test(loss_func, version): def run_test(loss_func, version):
from .model import main as test_model from .model import main as test_model
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loggers import WandbLogger
logger = TensorBoardLogger( tb_logger = TensorBoardLogger(
save_dir=".", save_dir=".",
name="logs/comparison", name="logs/comparison",
version=version, version=version,
) )
test_model(lr=LEARNING_RATE) wandb_logger = WandbLogger(project="MNIST")
# test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) logger = [tb_logger, wandb_logger]
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
def main(): def main():
@@ -19,10 +21,10 @@ def main():
from torch import nn from torch import nn
run_test(nn.functional.cross_entropy, "cross_entropy") run_test(nn.functional.cross_entropy, "cross_entropy")
# run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy") run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
# run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy") run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
# run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy") run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
# run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy") run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
if __name__ == "__main__": if __name__ == "__main__":