diff --git a/src/symbolic_nn_tests/experiment1/__init__.py b/src/symbolic_nn_tests/experiment1/__init__.py
index dccaf1a..45c7a2b 100644
--- a/src/symbolic_nn_tests/experiment1/__init__.py
+++ b/src/symbolic_nn_tests/experiment1/__init__.py
@@ -1,7 +1,9 @@
 LEARNING_RATE = 10e-5
 
 
-def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
+def test(
+    train_loss, val_loss, test_loss, accuracy, version, tensorboard=True, wandb=True
+):
     from .model import main as test_model
 
     logger = []
@@ -36,6 +38,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
         train_loss=train_loss,
         val_loss=val_loss,
         test_loss=test_loss,
+        accuracy=accuracy,
         lr=LEARNING_RATE,
     )
 
@@ -44,6 +47,8 @@
 
 
 def run(tensorboard: bool = True, wandb: bool = True):
+    from torchmetrics import Accuracy
+
     from . import semantic_loss
     from .model import oh_vs_cat_cross_entropy
 
@@ -51,6 +56,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=oh_vs_cat_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
@@ -59,6 +65,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=semantic_loss.similarity_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="similarity_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
@@ -67,6 +74,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=semantic_loss.hasline_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="hasline_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
@@ -75,6 +83,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=semantic_loss.hasloop_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="hasloop_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
@@ -83,6 +92,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=semantic_loss.multisemantic_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="multisemantic_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
@@ -91,6 +101,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
         train_loss=semantic_loss.garbage_cross_entropy,
         val_loss=oh_vs_cat_cross_entropy,
         test_loss=oh_vs_cat_cross_entropy,
+        accuracy=Accuracy(task="multiclass", num_classes=10),
         version="garbage_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
diff --git a/src/symbolic_nn_tests/experiment1/model.py b/src/symbolic_nn_tests/experiment1/model.py
index de9aea2..3d41c7b 100644
--- a/src/symbolic_nn_tests/experiment1/model.py
+++ b/src/symbolic_nn_tests/experiment1/model.py
@@ -1,6 +1,7 @@
 from functools import lru_cache
 
 import torch
 from torch import nn
+from torchmetrics import Accuracy
 
 model = nn.Sequential(
@@ -46,6 +47,7 @@ def main(
     train_loss=oh_vs_cat_cross_entropy,
     val_loss=oh_vs_cat_cross_entropy,
     test_loss=oh_vs_cat_cross_entropy,
+    accuracy=Accuracy(task="multiclass", num_classes=10),
     logger=None,
     trainer_callbacks=None,
     **kwargs,
@@ -61,7 +63,11 @@ def main(
     train, val, test = get_singleton_dataset()
 
     lmodel = TrainingWrapper(
-        model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
+        model,
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+        accuracy=accuracy,
     )
     lmodel.configure_optimizers(**kwargs)
     trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)