mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-23 06:32:05 +00:00
Added wandb logging
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -164,3 +164,4 @@ cython_debug/
|
|||||||
datasets/
|
datasets/
|
||||||
lightning_logs/
|
lightning_logs/
|
||||||
logs/
|
logs/
|
||||||
|
wandb/
|
||||||
|
|||||||
@@ -4,14 +4,16 @@ LEARNING_RATE = 10e-5
|
|||||||
def run_test(loss_func, version):
|
def run_test(loss_func, version):
|
||||||
from .model import main as test_model
|
from .model import main as test_model
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
|
||||||
logger = TensorBoardLogger(
|
tb_logger = TensorBoardLogger(
|
||||||
save_dir=".",
|
save_dir=".",
|
||||||
name="logs/comparison",
|
name="logs/comparison",
|
||||||
version=version,
|
version=version,
|
||||||
)
|
)
|
||||||
test_model(lr=LEARNING_RATE)
|
wandb_logger = WandbLogger(project="MNIST")
|
||||||
# test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
logger = [tb_logger, wandb_logger]
|
||||||
|
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -19,10 +21,10 @@ def main():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
run_test(nn.functional.cross_entropy, "cross_entropy")
|
run_test(nn.functional.cross_entropy, "cross_entropy")
|
||||||
# run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
|
run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
|
||||||
# run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
|
run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
|
||||||
# run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
|
run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
|
||||||
# run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
|
run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user