From 1454a701e050e49577f089038c5520c4ecd4ccdc Mon Sep 17 00:00:00 2001
From: Cian Hughes
Date: Wed, 11 Sep 2024 16:52:04 +0100
Subject: [PATCH] Added confusion matrices to experiments

---
 .../callbacks/tensorboard.py                  |  9 +++++
 src/symbolic_nn_tests/callbacks/wandb.py      | 35 +++++++++++++++++++
 src/symbolic_nn_tests/experiment1/__init__.py |  4 +++
 src/symbolic_nn_tests/experiment1/model.py    |  3 +-
 src/symbolic_nn_tests/experiment2/__init__.py |  4 +++
 src/symbolic_nn_tests/experiment2/model.py    |  3 +-
 src/symbolic_nn_tests/train.py                |  6 ++++
 7 files changed, 62 insertions(+), 2 deletions(-)
 create mode 100644 src/symbolic_nn_tests/callbacks/tensorboard.py
 create mode 100644 src/symbolic_nn_tests/callbacks/wandb.py

diff --git a/src/symbolic_nn_tests/callbacks/tensorboard.py b/src/symbolic_nn_tests/callbacks/tensorboard.py
new file mode 100644
index 0000000..758b6d3
--- /dev/null
+++ b/src/symbolic_nn_tests/callbacks/tensorboard.py
@@ -0,0 +1,9 @@
+from lightning.pytorch.callbacks import Callback
+
+
+class MyPrintingCallback(Callback):
+    def on_train_start(self, trainer, pl_module):
+        print("Training is starting")
+
+    def on_train_end(self, trainer, pl_module):
+        print("Training is ending")
diff --git a/src/symbolic_nn_tests/callbacks/wandb.py b/src/symbolic_nn_tests/callbacks/wandb.py
new file mode 100644
index 0000000..0022e40
--- /dev/null
+++ b/src/symbolic_nn_tests/callbacks/wandb.py
@@ -0,0 +1,35 @@
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.loggers.wandb import WandbLogger
+import torch
+import wandb
+
+
+def get_wandb_logger(loggers: list) -> WandbLogger:
+    for logger in loggers:
+        if isinstance(logger, WandbLogger):
+            break
+    return logger
+
+
+class ConfusionMatrixCallback(Callback):
+    def __init__(self, class_names=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.class_names = class_names
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        if trainer.state.stage != "sanity_check":
+            y_pred = torch.concat(pl_module.epoch_step_preds)
+            y_pred = torch.argmax(y_pred, axis=1)
+            y = torch.concat(tuple(map(lambda xy: xy[1], trainer.val_dataloaders)))
+            logger = get_wandb_logger(trainer.loggers)
+            logger.experiment.log(
+                {
+                    f"confusion_matrix_epoch_{trainer.current_epoch}": wandb.plot.confusion_matrix(
+                        probs=None,
+                        y_true=y.numpy(),
+                        preds=y_pred.numpy(),
+                        class_names=self.class_names,
+                        title=f"confusion_matrix_epoch_{trainer.current_epoch}",
+                    )
+                }
+            )
diff --git a/src/symbolic_nn_tests/experiment1/__init__.py b/src/symbolic_nn_tests/experiment1/__init__.py
index 4c9d141..dccaf1a 100644
--- a/src/symbolic_nn_tests/experiment1/__init__.py
+++ b/src/symbolic_nn_tests/experiment1/__init__.py
@@ -5,6 +5,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
     from .model import main as test_model
 
     logger = []
+    callbacks = []
     if tensorboard:
         from lightning.pytorch.loggers import TensorBoardLogger
 
@@ -19,6 +20,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
     if wandb:
         import wandb as _wandb
         from lightning.pytorch.loggers import WandbLogger
+        from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
 
         wandb_logger = WandbLogger(
             project="Symbolic_NN_Tests",
@@ -26,9 +28,11 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
             dir="wandb",
         )
         logger.append(wandb_logger)
+        callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
 
     test_model(
         logger=logger,
+        trainer_callbacks=callbacks,
         train_loss=train_loss,
         val_loss=val_loss,
         test_loss=test_loss,
diff --git a/src/symbolic_nn_tests/experiment1/model.py b/src/symbolic_nn_tests/experiment1/model.py
index 45e9080..de9aea2 100644
--- a/src/symbolic_nn_tests/experiment1/model.py
+++ b/src/symbolic_nn_tests/experiment1/model.py
@@ -47,6 +47,7 @@ def main(
     val_loss=oh_vs_cat_cross_entropy,
     test_loss=oh_vs_cat_cross_entropy,
     logger=None,
+    trainer_callbacks=None,
     **kwargs,
 ):
     import lightning as L
@@ -63,7 +64,7 @@ def main(
         model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
     )
     lmodel.configure_optimizers(**kwargs)
-    trainer = L.Trainer(max_epochs=20, logger=logger)
+    trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)
 
diff --git a/src/symbolic_nn_tests/experiment2/__init__.py b/src/symbolic_nn_tests/experiment2/__init__.py
index 9bf85df..7cd4280 100644
--- a/src/symbolic_nn_tests/experiment2/__init__.py
+++ b/src/symbolic_nn_tests/experiment2/__init__.py
@@ -13,6 +13,7 @@ def test(
     from .model import main as test_model
 
     logger = []
+    callbacks = []
     if tensorboard:
         from lightning.pytorch.loggers import TensorBoardLogger
 
@@ -27,6 +28,7 @@ def test(
     if wandb:
         import wandb as _wandb
         from lightning.pytorch.loggers import WandbLogger
+        from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
 
         if isinstance(wandb, WandbLogger):
             wandb_logger = wandb
@@ -38,9 +40,11 @@ def test(
                 log_model="all",
             )
         logger.append(wandb_logger)
+        callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
 
     test_model(
         logger=logger,
+        trainer_callbacks=callbacks,
         train_loss=train_loss,
         val_loss=val_loss,
         test_loss=test_loss,
diff --git a/src/symbolic_nn_tests/experiment2/model.py b/src/symbolic_nn_tests/experiment2/model.py
index 1277f1b..cbc9c6d 100644
--- a/src/symbolic_nn_tests/experiment2/model.py
+++ b/src/symbolic_nn_tests/experiment2/model.py
@@ -67,6 +67,7 @@ def main(
     val_loss=unpacking_smooth_l1_loss,
     test_loss=unpacking_smooth_l1_loss,
     logger=None,
+    trainer_callbacks=None,
     semantic_trainer=False,
     **kwargs,
 ):
@@ -90,7 +91,7 @@ def main(
         test_loss=test_loss,
     )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
-    trainer = L.Trainer(max_epochs=5, logger=logger)
+    trainer = L.Trainer(max_epochs=5, logger=logger, callbacks=trainer_callbacks)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)
 
diff --git a/src/symbolic_nn_tests/train.py b/src/symbolic_nn_tests/train.py
index b1dc604..04f1c7e 100644
--- a/src/symbolic_nn_tests/train.py
+++ b/src/symbolic_nn_tests/train.py
@@ -17,11 +17,17 @@ class TrainingWrapper(L.LightningModule):
         self.val_loss = val_loss
         self.test_loss = val_loss
         self.accuracy = accuracy
+        self.epoch_step_preds = []
 
     def _forward_step(self, batch, batch_idx, loss_func, label=""):
         x, y = batch
         y_pred = self.model(x)
         loss = loss_func(y_pred, y)
+        # Add tracking of y_pred for each step in RAM (for more advanced plots)
+        if batch_idx == 0:
+            self.epoch_step_preds = []
+        self.epoch_step_preds.append(y_pred.cpu())
+        # Add enhanced logging for more granularity
        self.log(f"{label}{'_' if label else ''}loss", loss)
         if self.accuracy is not None:
             acc = self.accuracy(y_pred, y)
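
Usage note (not part of the patch): a minimal sketch of how the new ConfusionMatrixCallback
might be attached to a Lightning Trainer outside the experiment entry points, assuming the
TrainingWrapper module and W&B setup shown above; lmodel, train, and val are placeholders
for a wrapped model and its dataloaders.

    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback

    # The callback concatenates pl_module.epoch_step_preds (collected by TrainingWrapper
    # during validation steps), takes the argmax, gathers targets from the validation
    # dataloaders, and logs a per-epoch confusion matrix to the attached WandbLogger.
    wandb_logger = WandbLogger(project="Symbolic_NN_Tests", dir="wandb")
    cm_callback = ConfusionMatrixCallback(class_names=list(range(10)))
    trainer = L.Trainer(max_epochs=5, logger=[wandb_logger], callbacks=[cm_callback])
    # trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)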