mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Aded confusion matrices to experiments
This commit is contained in:
9
src/symbolic_nn_tests/callbacks/tensorboard.py
Normal file
9
src/symbolic_nn_tests/callbacks/tensorboard.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
|
||||
class MyPrintingCallback(Callback):
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
print("Training is starting")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
print("Training is ending")
|
||||
35
src/symbolic_nn_tests/callbacks/wandb.py
Normal file
35
src/symbolic_nn_tests/callbacks/wandb.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.loggers.wandb import WandbLogger
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
|
||||
def get_wandb_logger(loggers: list) -> WandbLogger:
|
||||
for logger in loggers:
|
||||
if isinstance(logger, WandbLogger):
|
||||
break
|
||||
return logger
|
||||
|
||||
|
||||
class ConfusionMatrixCallback(Callback):
|
||||
def __init__(self, class_names=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.class_names = class_names
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
if trainer.state.stage != "sanity_check":
|
||||
y_pred = torch.concat(pl_module.epoch_step_preds)
|
||||
y_pred = torch.argmax(y_pred, axis=1)
|
||||
y = torch.concat(tuple(map(lambda xy: xy[1], trainer.val_dataloaders)))
|
||||
logger = get_wandb_logger(trainer.loggers)
|
||||
logger.experiment.log(
|
||||
{
|
||||
f"confusion_matrix_epoch_{trainer.current_epoch}": wandb.plot.confusion_matrix(
|
||||
probs=None,
|
||||
y_true=y.numpy(),
|
||||
preds=y_pred.numpy(),
|
||||
class_names=self.class_names,
|
||||
title=f"confusion_matrix_epoch_{trainer.current_epoch}",
|
||||
)
|
||||
}
|
||||
)
|
||||
@@ -5,6 +5,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
||||
from .model import main as test_model
|
||||
|
||||
logger = []
|
||||
callbacks = []
|
||||
|
||||
if tensorboard:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
@@ -19,6 +20,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
||||
if wandb:
|
||||
import wandb as _wandb
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
|
||||
|
||||
wandb_logger = WandbLogger(
|
||||
project="Symbolic_NN_Tests",
|
||||
@@ -26,9 +28,11 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
|
||||
dir="wandb",
|
||||
)
|
||||
logger.append(wandb_logger)
|
||||
callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
|
||||
|
||||
test_model(
|
||||
logger=logger,
|
||||
trainer_callbacks=callbacks,
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
|
||||
@@ -47,6 +47,7 @@ def main(
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
logger=None,
|
||||
trainer_callbacks=None,
|
||||
**kwargs,
|
||||
):
|
||||
import lightning as L
|
||||
@@ -63,7 +64,7 @@ def main(
|
||||
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
||||
)
|
||||
lmodel.configure_optimizers(**kwargs)
|
||||
trainer = L.Trainer(max_epochs=20, logger=logger)
|
||||
trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
|
||||
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||
trainer.test(dataloaders=test)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ def test(
|
||||
from .model import main as test_model
|
||||
|
||||
logger = []
|
||||
callbacks = []
|
||||
|
||||
if tensorboard:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
@@ -27,6 +28,7 @@ def test(
|
||||
if wandb:
|
||||
import wandb as _wandb
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
|
||||
|
||||
if isinstance(wandb, WandbLogger):
|
||||
wandb_logger = wandb
|
||||
@@ -38,9 +40,11 @@ def test(
|
||||
log_model="all",
|
||||
)
|
||||
logger.append(wandb_logger)
|
||||
callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
|
||||
|
||||
test_model(
|
||||
logger=logger,
|
||||
trainer_callbacks=callbacks,
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
|
||||
@@ -67,6 +67,7 @@ def main(
|
||||
val_loss=unpacking_smooth_l1_loss,
|
||||
test_loss=unpacking_smooth_l1_loss,
|
||||
logger=None,
|
||||
trainer_callbacks=None,
|
||||
semantic_trainer=False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -90,7 +91,7 @@ def main(
|
||||
test_loss=test_loss,
|
||||
)
|
||||
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
|
||||
trainer = L.Trainer(max_epochs=5, logger=logger)
|
||||
trainer = L.Trainer(max_epochs=5, logger=logger, callbacks=trainer_callbacks)
|
||||
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||
trainer.test(dataloaders=test)
|
||||
|
||||
|
||||
@@ -17,11 +17,17 @@ class TrainingWrapper(L.LightningModule):
|
||||
self.val_loss = val_loss
|
||||
self.test_loss = val_loss
|
||||
self.accuracy = accuracy
|
||||
self.epoch_step_preds = []
|
||||
|
||||
def _forward_step(self, batch, batch_idx, loss_func, label=""):
|
||||
x, y = batch
|
||||
y_pred = self.model(x)
|
||||
loss = loss_func(y_pred, y)
|
||||
# Add tracking of y_pred for each step in RAM (for more advanced plots)
|
||||
if batch_idx == 0:
|
||||
self.epoch_step_preds = []
|
||||
self.epoch_step_preds.append(y_pred.cpu())
|
||||
# Add enhanced logging for more granularity
|
||||
self.log(f"{label}{'_' if label else ''}loss", loss)
|
||||
if self.accuracy is not None:
|
||||
acc = self.accuracy(y_pred, y)
|
||||
|
||||
Reference in New Issue
Block a user