Added confusion matrices to experiments

This commit is contained in:
2024-09-11 16:52:04 +01:00
parent 546e235b09
commit 1454a701e0
7 changed files with 62 additions and 2 deletions

View File

@@ -0,0 +1,9 @@
from lightning.pytorch.callbacks import Callback
class MyPrintingCallback(Callback):
    """Minimal Lightning callback that announces the fit lifecycle on stdout."""

    def on_train_start(self, trainer, pl_module):
        """Called by the Trainer once, right before the first training batch."""
        self._announce("starting")

    def on_train_end(self, trainer, pl_module):
        """Called by the Trainer once, after training has finished."""
        self._announce("ending")

    @staticmethod
    def _announce(phase):
        # Single formatting point so both hooks print an identical message shape.
        print(f"Training is {phase}")

View File

@@ -0,0 +1,35 @@
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.wandb import WandbLogger
import torch
import wandb
def get_wandb_logger(loggers: list) -> WandbLogger:
    """Return the first WandbLogger in *loggers*.

    Args:
        loggers: Iterable of Lightning logger instances (e.g. ``trainer.loggers``).

    Returns:
        The first entry that is a ``WandbLogger``.

    Raises:
        ValueError: If no ``WandbLogger`` is present. The original
            ``for``/``break`` version silently returned the LAST logger of the
            wrong type in that case (or raised ``NameError`` on an empty list);
            failing fast here surfaces the misconfiguration immediately.
    """
    for logger in loggers:
        if isinstance(logger, WandbLogger):
            return logger
    raise ValueError("No WandbLogger found in the provided loggers")
class ConfusionMatrixCallback(Callback):
    """Log a wandb confusion matrix to W&B at the end of each validation epoch."""

    def __init__(self, class_names=None, *args, **kwargs):
        """Remember the display names used for the matrix axes; pass the rest through."""
        super().__init__(*args, **kwargs)
        self.class_names = class_names

    def on_validation_epoch_end(self, trainer, pl_module):
        """Assemble predictions vs. labels for the finished epoch and log the plot."""
        # Skip Lightning's sanity-check pass — it only runs a couple of batches.
        if trainer.state.stage == "sanity_check":
            return
        # Predictions were accumulated per-step by the module (on CPU already).
        preds = torch.argmax(torch.cat(pl_module.epoch_step_preds), dim=1)
        # Ground truth: second element of every validation batch.
        # NOTE(review): assumes each batch is an (x, y) pair — confirm against the dataloader.
        targets = torch.cat(tuple(batch[1] for batch in trainer.val_dataloaders))
        wandb_logger = get_wandb_logger(trainer.loggers)
        key = f"confusion_matrix_epoch_{trainer.current_epoch}"
        matrix = wandb.plot.confusion_matrix(
            probs=None,
            y_true=targets.numpy(),
            preds=preds.numpy(),
            class_names=self.class_names,
            title=key,
        )
        wandb_logger.experiment.log({key: matrix})

View File

@@ -5,6 +5,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
from .model import main as test_model
logger = []
callbacks = []
if tensorboard:
from lightning.pytorch.loggers import TensorBoardLogger
@@ -19,6 +20,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
if wandb:
import wandb as _wandb
from lightning.pytorch.loggers import WandbLogger
from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
wandb_logger = WandbLogger(
project="Symbolic_NN_Tests",
@@ -26,9 +28,11 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
dir="wandb",
)
logger.append(wandb_logger)
callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
test_model(
logger=logger,
trainer_callbacks=callbacks,
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,

View File

@@ -47,6 +47,7 @@ def main(
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
logger=None,
trainer_callbacks=None,
**kwargs,
):
import lightning as L
@@ -63,7 +64,7 @@ def main(
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
)
lmodel.configure_optimizers(**kwargs)
trainer = L.Trainer(max_epochs=20, logger=logger)
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
trainer.test(dataloaders=test)

View File

@@ -13,6 +13,7 @@ def test(
from .model import main as test_model
logger = []
callbacks = []
if tensorboard:
from lightning.pytorch.loggers import TensorBoardLogger
@@ -27,6 +28,7 @@ def test(
if wandb:
import wandb as _wandb
from lightning.pytorch.loggers import WandbLogger
from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
if isinstance(wandb, WandbLogger):
wandb_logger = wandb
@@ -38,9 +40,11 @@ def test(
log_model="all",
)
logger.append(wandb_logger)
callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
test_model(
logger=logger,
trainer_callbacks=callbacks,
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,

View File

@@ -67,6 +67,7 @@ def main(
val_loss=unpacking_smooth_l1_loss,
test_loss=unpacking_smooth_l1_loss,
logger=None,
trainer_callbacks=None,
semantic_trainer=False,
**kwargs,
):
@@ -90,7 +91,7 @@ def main(
test_loss=test_loss,
)
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
trainer = L.Trainer(max_epochs=5, logger=logger)
trainer = L.Trainer(max_epochs=5, logger=logger, callbacks=trainer_callbacks)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
trainer.test(dataloaders=test)

View File

@@ -17,11 +17,17 @@ class TrainingWrapper(L.LightningModule):
self.val_loss = val_loss
self.test_loss = val_loss
self.accuracy = accuracy
self.epoch_step_preds = []
def _forward_step(self, batch, batch_idx, loss_func, label=""):
x, y = batch
y_pred = self.model(x)
loss = loss_func(y_pred, y)
# Add tracking of y_pred for each step in RAM (for more advanced plots)
if batch_idx == 0:
self.epoch_step_preds = []
self.epoch_step_preds.append(y_pred.cpu())
# Add enhanced logging for more granularity
self.log(f"{label}{'_' if label else ''}loss", loss)
if self.accuracy is not None:
acc = self.accuracy(y_pred, y)