Mirror of https://github.com/Cian-H/symbolic_nn_tests.git, synced 2025-12-22 22:22:01 +00:00
Added confusion matrices to experiments
src/symbolic_nn_tests/callbacks/tensorboard.py (new file, +9)

@@ -0,0 +1,9 @@
+from lightning.pytorch.callbacks import Callback
+
+
+class MyPrintingCallback(Callback):
+    def on_train_start(self, trainer, pl_module):
+        print("Training is starting")
+
+    def on_train_end(self, trainer, pl_module):
+        print("Training is ending")
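
MyPrintingCallback is a minimal Lightning Callback: hooks such as on_train_start and on_train_end fire automatically around trainer.fit(). A sketch of how such a callback would be attached (the Trainer arguments here are illustrative, not taken from this commit):

import lightning as L

from symbolic_nn_tests.callbacks.tensorboard import MyPrintingCallback

# Lightning calls on_train_start before the first epoch and on_train_end
# after the last one; no manual invocation is needed.
trainer = L.Trainer(max_epochs=1, callbacks=[MyPrintingCallback()])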

src/symbolic_nn_tests/callbacks/wandb.py (new file, +35)

@@ -0,0 +1,35 @@
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.loggers.wandb import WandbLogger
+import torch
+import wandb
+
+
+def get_wandb_logger(loggers: list) -> WandbLogger:
+    for logger in loggers:
+        if isinstance(logger, WandbLogger):
+            break
+    return logger
+
+
+class ConfusionMatrixCallback(Callback):
+    def __init__(self, class_names=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.class_names = class_names
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        if trainer.state.stage != "sanity_check":
+            y_pred = torch.concat(pl_module.epoch_step_preds)
+            y_pred = torch.argmax(y_pred, axis=1)
+            y = torch.concat(tuple(map(lambda xy: xy[1], trainer.val_dataloaders)))
+            logger = get_wandb_logger(trainer.loggers)
+            logger.experiment.log(
+                {
+                    f"confusion_matrix_epoch_{trainer.current_epoch}": wandb.plot.confusion_matrix(
+                        probs=None,
+                        y_true=y.numpy(),
+                        preds=y_pred.numpy(),
+                        class_names=self.class_names,
+                        title=f"confusion_matrix_epoch_{trainer.current_epoch}",
+                    )
+                }
+            )
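
ConfusionMatrixCallback makes two assumptions about its environment: the LightningModule must expose an epoch_step_preds list of per-batch validation predictions (added to TrainingWrapper at the end of this commit), and a WandbLogger must be among trainer.loggers for get_wandb_logger to find. A wiring sketch, with model and dataloaders omitted:

import lightning as L
from lightning.pytorch.loggers import WandbLogger

from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback

# One logger and one callback, mirroring how the experiment test() functions
# below assemble their `logger` and `callbacks` lists.
trainer = L.Trainer(
    logger=[WandbLogger(project="Symbolic_NN_Tests", dir="wandb")],
    callbacks=[ConfusionMatrixCallback(class_names=list(range(10)))],
)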

@@ -5,6 +5,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
     from .model import main as test_model
 
     logger = []
+    callbacks = []
 
     if tensorboard:
         from lightning.pytorch.loggers import TensorBoardLogger
@@ -19,6 +20,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
     if wandb:
         import wandb as _wandb
         from lightning.pytorch.loggers import WandbLogger
+        from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
 
         wandb_logger = WandbLogger(
             project="Symbolic_NN_Tests",
@@ -26,9 +28,11 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
             dir="wandb",
         )
         logger.append(wandb_logger)
+        callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
 
     test_model(
         logger=logger,
+        trainer_callbacks=callbacks,
         train_loss=train_loss,
         val_loss=val_loss,
         test_loss=test_loss,

@@ -47,6 +47,7 @@ def main(
     val_loss=oh_vs_cat_cross_entropy,
     test_loss=oh_vs_cat_cross_entropy,
     logger=None,
+    trainer_callbacks=None,
     **kwargs,
 ):
     import lightning as L
@@ -63,7 +64,7 @@ def main(
         model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
     )
     lmodel.configure_optimizers(**kwargs)
-    trainer = L.Trainer(max_epochs=20, logger=logger)
+    trainer = L.Trainer(max_epochs=20, logger=logger, callbacks=trainer_callbacks)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)
 
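Together these hunks thread the new parameter through: test() builds a callbacks list and passes it as trainer_callbacks, and main() forwards it unchanged to L.Trainer. Assuming main is in scope (imported from this experiment's model module), a direct call would look like:

from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback

# Hypothetical direct invocation; every other argument keeps the defaults
# visible in the hunk above.
main(trainer_callbacks=[ConfusionMatrixCallback(class_names=list(range(10)))])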
@@ -13,6 +13,7 @@ def test(
     from .model import main as test_model
 
     logger = []
+    callbacks = []
 
     if tensorboard:
         from lightning.pytorch.loggers import TensorBoardLogger
@@ -27,6 +28,7 @@ def test(
     if wandb:
         import wandb as _wandb
         from lightning.pytorch.loggers import WandbLogger
+        from symbolic_nn_tests.callbacks.wandb import ConfusionMatrixCallback
 
         if isinstance(wandb, WandbLogger):
             wandb_logger = wandb
@@ -38,9 +40,11 @@ def test(
                 log_model="all",
             )
         logger.append(wandb_logger)
+        callbacks.append(ConfusionMatrixCallback(class_names=list(map(int, range(10)))))
 
     test_model(
         logger=logger,
+        trainer_callbacks=callbacks,
         train_loss=train_loss,
         val_loss=val_loss,
         test_loss=test_loss,

@@ -67,6 +67,7 @@ def main(
     val_loss=unpacking_smooth_l1_loss,
     test_loss=unpacking_smooth_l1_loss,
     logger=None,
+    trainer_callbacks=None,
     semantic_trainer=False,
     **kwargs,
 ):
@@ -90,7 +91,7 @@ def main(
         test_loss=test_loss,
     )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
-    trainer = L.Trainer(max_epochs=5, logger=logger)
+    trainer = L.Trainer(max_epochs=5, logger=logger, callbacks=trainer_callbacks)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
     trainer.test(dataloaders=test)
 
@@ -17,11 +17,17 @@ class TrainingWrapper(L.LightningModule):
         self.val_loss = val_loss
         self.test_loss = val_loss
         self.accuracy = accuracy
+        self.epoch_step_preds = []
 
     def _forward_step(self, batch, batch_idx, loss_func, label=""):
         x, y = batch
         y_pred = self.model(x)
         loss = loss_func(y_pred, y)
+        # Add tracking of y_pred for each step in RAM (for more advanced plots)
+        if batch_idx == 0:
+            self.epoch_step_preds = []
+        self.epoch_step_preds.append(y_pred.cpu())
+        # Add enhanced logging for more granularity
         self.log(f"{label}{'_' if label else ''}loss", loss)
         if self.accuracy is not None:
             acc = self.accuracy(y_pred, y)
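
The reset-at-first-batch pattern means only the current epoch's predictions are ever held, and .cpu() moves each batch off the accelerator so the buffer lives in RAM rather than VRAM. A self-contained sketch of the same lifecycle with fake logits (shapes are illustrative):

import torch

epoch_step_preds = []
for batch_idx, y_pred in enumerate(torch.randn(4, 32, 10)):  # 4 fake batches of 32x10 logits
    if batch_idx == 0:
        epoch_step_preds = []  # drop the previous epoch's predictions
    epoch_step_preds.append(y_pred.cpu())

# At validation-epoch end the callback concatenates and argmaxes, as in wandb.py:
y_pred = torch.concat(epoch_step_preds)
y_pred = torch.argmax(y_pred, axis=1)  # -> shape (128,), one class index per sample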