mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added accuracy to expt1 metrics
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
LEARNING_RATE = 10e-5
|
||||
|
||||
|
||||
def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
|
||||
def test(
|
||||
train_loss, val_loss, test_loss, accuracy, version, tensorboard=True, wandb=True
|
||||
):
|
||||
from .model import main as test_model
|
||||
|
||||
logger = []
|
||||
@@ -36,6 +38,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
accuracy=accuracy,
|
||||
lr=LEARNING_RATE,
|
||||
)
|
||||
|
||||
@@ -44,6 +47,8 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
||||
|
||||
|
||||
def run(tensorboard: bool = True, wandb: bool = True):
|
||||
from torchmetrics import Accuracy
|
||||
|
||||
from . import semantic_loss
|
||||
from .model import oh_vs_cat_cross_entropy
|
||||
|
||||
@@ -51,6 +56,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=oh_vs_cat_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
@@ -59,6 +65,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=semantic_loss.similarity_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="similarity_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
@@ -67,6 +74,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=semantic_loss.hasline_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="hasline_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
@@ -75,6 +83,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=semantic_loss.hasloop_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="hasloop_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
@@ -83,6 +92,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=semantic_loss.multisemantic_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="multisemantic_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
@@ -91,6 +101,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
||||
train_loss=semantic_loss.garbage_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
version="garbage_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchmetrics import Accuracy
|
||||
|
||||
|
||||
model = nn.Sequential(
|
||||
@@ -46,6 +47,7 @@ def main(
|
||||
train_loss=oh_vs_cat_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||
logger=None,
|
||||
trainer_callbacks=None,
|
||||
**kwargs,
|
||||
@@ -61,7 +63,11 @@ def main(
|
||||
|
||||
train, val, test = get_singleton_dataset()
|
||||
lmodel = TrainingWrapper(
|
||||
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
||||
model,
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
accuracy=accuracy,
|
||||
)
|
||||
lmodel.configure_optimizers(**kwargs)
|
||||
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
|
||||
|
||||
Reference in New Issue
Block a user