mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-23 06:32:05 +00:00
Added accuracy to expt1 metrics
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
LEARNING_RATE = 10e-5
|
LEARNING_RATE = 10e-5
|
||||||
|
|
||||||
|
|
||||||
def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
|
def test(
|
||||||
|
train_loss, val_loss, test_loss, accuracy, version, tensorboard=True, wandb=True
|
||||||
|
):
|
||||||
from .model import main as test_model
|
from .model import main as test_model
|
||||||
|
|
||||||
logger = []
|
logger = []
|
||||||
@@ -36,6 +38,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
|||||||
train_loss=train_loss,
|
train_loss=train_loss,
|
||||||
val_loss=val_loss,
|
val_loss=val_loss,
|
||||||
test_loss=test_loss,
|
test_loss=test_loss,
|
||||||
|
accuracy=accuracy,
|
||||||
lr=LEARNING_RATE,
|
lr=LEARNING_RATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,6 +47,8 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
|||||||
|
|
||||||
|
|
||||||
def run(tensorboard: bool = True, wandb: bool = True):
|
def run(tensorboard: bool = True, wandb: bool = True):
|
||||||
|
from torchmetrics import Accuracy
|
||||||
|
|
||||||
from . import semantic_loss
|
from . import semantic_loss
|
||||||
from .model import oh_vs_cat_cross_entropy
|
from .model import oh_vs_cat_cross_entropy
|
||||||
|
|
||||||
@@ -51,6 +56,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=oh_vs_cat_cross_entropy,
|
train_loss=oh_vs_cat_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="cross_entropy",
|
version="cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
@@ -59,6 +65,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=semantic_loss.similarity_cross_entropy,
|
train_loss=semantic_loss.similarity_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="similarity_cross_entropy",
|
version="similarity_cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
@@ -67,6 +74,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=semantic_loss.hasline_cross_entropy,
|
train_loss=semantic_loss.hasline_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="hasline_cross_entropy",
|
version="hasline_cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
@@ -75,6 +83,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=semantic_loss.hasloop_cross_entropy,
|
train_loss=semantic_loss.hasloop_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="hasloop_cross_entropy",
|
version="hasloop_cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
@@ -83,6 +92,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=semantic_loss.multisemantic_cross_entropy,
|
train_loss=semantic_loss.multisemantic_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="multisemantic_cross_entropy",
|
version="multisemantic_cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
@@ -91,6 +101,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
|
|||||||
train_loss=semantic_loss.garbage_cross_entropy,
|
train_loss=semantic_loss.garbage_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
version="garbage_cross_entropy",
|
version="garbage_cross_entropy",
|
||||||
tensorboard=tensorboard,
|
tensorboard=tensorboard,
|
||||||
wandb=wandb,
|
wandb=wandb,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torchmetrics import Accuracy
|
||||||
|
|
||||||
|
|
||||||
model = nn.Sequential(
|
model = nn.Sequential(
|
||||||
@@ -46,6 +47,7 @@ def main(
|
|||||||
train_loss=oh_vs_cat_cross_entropy,
|
train_loss=oh_vs_cat_cross_entropy,
|
||||||
val_loss=oh_vs_cat_cross_entropy,
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
test_loss=oh_vs_cat_cross_entropy,
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
accuracy=Accuracy(task="multiclass", num_classes=10),
|
||||||
logger=None,
|
logger=None,
|
||||||
trainer_callbacks=None,
|
trainer_callbacks=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -61,7 +63,11 @@ def main(
|
|||||||
|
|
||||||
train, val, test = get_singleton_dataset()
|
train, val, test = get_singleton_dataset()
|
||||||
lmodel = TrainingWrapper(
|
lmodel = TrainingWrapper(
|
||||||
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
model,
|
||||||
|
train_loss=train_loss,
|
||||||
|
val_loss=val_loss,
|
||||||
|
test_loss=test_loss,
|
||||||
|
accuracy=accuracy,
|
||||||
)
|
)
|
||||||
lmodel.configure_optimizers(**kwargs)
|
lmodel.configure_optimizers(**kwargs)
|
||||||
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
|
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
|
||||||
|
|||||||
Reference in New Issue
Block a user