From 66b23a8e7679894ad98e11d0b28b73e545f00833 Mon Sep 17 00:00:00 2001 From: Cian-H Date: Tue, 14 May 2024 13:04:52 +0100 Subject: [PATCH] Minimized model and prepared for testing new loss functions --- .gitignore | 1 + symbolic_nn_tests/__main__.py | 7 +++-- symbolic_nn_tests/dataloader.py | 6 ++-- symbolic_nn_tests/ffnn.py | 33 ---------------------- symbolic_nn_tests/model.py | 30 ++++++++++++++++++++ symbolic_nn_tests/{trainer.py => train.py} | 21 ++++++++------ 6 files changed, 51 insertions(+), 47 deletions(-) delete mode 100644 symbolic_nn_tests/ffnn.py create mode 100644 symbolic_nn_tests/model.py rename symbolic_nn_tests/{trainer.py => train.py} (56%) diff --git a/.gitignore b/.gitignore index 527b9a3..550a371 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ cython_debug/ datasets/ lightning_logs/ +logs/ diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index 251ba28..48b7ea0 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -1,5 +1,8 @@ -from .ffnn import main +from .model import main if __name__ == "__main__": - main() + from lightning.pytorch.loggers import TensorBoardLogger + + logger = TensorBoardLogger(save_dir=".", name="logs/ffnn") + main(logger) diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py index 103bc3b..5ad0142 100644 --- a/symbolic_nn_tests/dataloader.py +++ b/symbolic_nn_tests/dataloader.py @@ -1,6 +1,6 @@ from pathlib import Path from torchvision.datasets import Caltech256 -from torchvision.transforms import Compose, Lambda, ToTensor +from torchvision.transforms import ToTensor from torch.utils.data import random_split from torch.utils.data import BatchSampler @@ -20,7 +20,7 @@ def get_dataset( } _kwargs.update(kwargs) ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs) - train, test, val = ( + train, val, test = ( BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split) ) - return train, test, val + return train, val, test diff --git a/symbolic_nn_tests/ffnn.py b/symbolic_nn_tests/ffnn.py deleted file mode 100644 index 436c0b8..0000000 --- a/symbolic_nn_tests/ffnn.py +++ /dev/null @@ -1,33 +0,0 @@ -from torch import nn - - -model = nn.Sequential( - nn.Flatten(1, -1), - nn.Linear(784, 128), - nn.ReLU(), - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 64), - nn.ReLU(), - nn.Linear(64, 32), - nn.ReLU(), - nn.Linear(32, 10), - nn.Softmax(dim=1), -) - - -def main(): - from torchvision.datasets import QMNIST - import lightning as L - - from .dataloader import get_dataset - from .trainer import Trainer - - train, test, val = get_dataset(dataset=QMNIST) - training_model = Trainer(model) - trainer = L.Trainer(max_epochs=10) - trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val) diff --git a/symbolic_nn_tests/model.py b/symbolic_nn_tests/model.py new file mode 100644 index 0000000..2c28591 --- /dev/null +++ b/symbolic_nn_tests/model.py @@ -0,0 +1,30 @@ +from torch import nn + + +model = nn.Sequential( + nn.Flatten(1, -1), + nn.Linear(784, 10), + nn.Softmax(dim=1), +) + + +def main(loss_func=nn.functional.cross_entropy, logger=None): + from torchvision.datasets import QMNIST + import lightning as L + + from .dataloader import get_dataset + from .train import TrainingWrapper + + if logger is None: + from lightning.pytorch.loggers import TensorBoardLogger + + logger = TensorBoardLogger(save_dir=".", name="logs/ffnn") + + train, val, test = get_dataset(dataset=QMNIST) + lmodel = TrainingWrapper(model) + trainer = L.Trainer(max_epochs=5, logger=logger) + trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val) + + +if __name__ == "__main__": + main() diff --git a/symbolic_nn_tests/trainer.py b/symbolic_nn_tests/train.py similarity index 56% rename from symbolic_nn_tests/trainer.py rename to symbolic_nn_tests/train.py index bc191fc..d4f77d9 100644 --- a/symbolic_nn_tests/trainer.py +++ b/symbolic_nn_tests/train.py @@ -12,25 +12,28 @@ def collate_batch(batch): return x, y -class Trainer(L.LightningModule): +class TrainingWrapper(L.LightningModule): def __init__(self, model, loss_func=nn.functional.cross_entropy): super().__init__() self.model = model self.loss_func = loss_func - def training_step(self, batch, batch_idx): + def _forward_step(self, batch, batch_idx, label=""): x, y = collate_batch(batch) y_pred = self.model(x) - loss = self.loss_func(y_pred, y) - self.log("train_loss", loss) + batch_size = x.shape[0] + loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64)) + acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size + self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size) + self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size) + return loss, acc + + def training_step(self, batch, batch_idx): + loss, _ = self._forward_step(batch, batch_idx, label="train") return loss def validation_step(self, batch, batch_idx): - x, y = collate_batch(batch) - y_pred = self.model(x) - loss = self.loss_func(y_pred, y) - self.log("val_loss", loss) - return loss + self._forward_step(batch, batch_idx, label="val") def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs): _optimizer = optimizer(self.parameters(), *args, **kwargs)