Added accuracy to expt1 metrics

This commit is contained in:
2024-09-11 17:13:51 +01:00
parent 1454a701e0
commit 9a32c8fd04
2 changed files with 19 additions and 2 deletions

View File

@@ -1,7 +1,9 @@
LEARNING_RATE = 10e-5
def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
def test(
train_loss, val_loss, test_loss, accuracy, version, tensorboard=True, wandb=True
):
from .model import main as test_model
logger = []
@@ -36,6 +38,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,
accuracy=accuracy,
lr=LEARNING_RATE,
)
@@ -44,6 +47,8 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
def run(tensorboard: bool = True, wandb: bool = True):
from torchmetrics import Accuracy
from . import semantic_loss
from .model import oh_vs_cat_cross_entropy
@@ -51,6 +56,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=oh_vs_cat_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
@@ -59,6 +65,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=semantic_loss.similarity_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="similarity_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
@@ -67,6 +74,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=semantic_loss.hasline_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="hasline_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
@@ -75,6 +83,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=semantic_loss.hasloop_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="hasloop_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
@@ -83,6 +92,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=semantic_loss.multisemantic_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="multisemantic_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
@@ -91,6 +101,7 @@ def run(tensorboard: bool = True, wandb: bool = True):
train_loss=semantic_loss.garbage_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
version="garbage_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,

View File

@@ -1,6 +1,7 @@
from functools import lru_cache
import torch
from torch import nn
from torchmetrics import Accuracy
model = nn.Sequential(
@@ -46,6 +47,7 @@ def main(
train_loss=oh_vs_cat_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
accuracy=Accuracy(task="multiclass", num_classes=10),
logger=None,
trainer_callbacks=None,
**kwargs,
@@ -61,7 +63,11 @@ def main(
train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
model,
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,
accuracy=accuracy,
)
lmodel.configure_optimizers(**kwargs)
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)