diff --git a/.gitignore b/.gitignore
index 9972241..527b9a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,3 +162,4 @@ cython_debug/
 #.idea/
 
 datasets/
+lightning_logs/
diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py
index 5ed3351..251ba28 100644
--- a/symbolic_nn_tests/__main__.py
+++ b/symbolic_nn_tests/__main__.py
@@ -1,5 +1,4 @@
-def main():
-    print("Working")
+from .ffnn import main
 
 
 if __name__ == "__main__":
diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py
index 3abbc36..103bc3b 100644
--- a/symbolic_nn_tests/dataloader.py
+++ b/symbolic_nn_tests/dataloader.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
 from torchvision.datasets import Caltech256
-from torchvision.transforms import ToTensor
+from torchvision.transforms import Compose, Lambda, ToTensor
 from torch.utils.data import random_split
 from torch.utils.data import BatchSampler
@@ -15,7 +15,9 @@ def get_dataset(
     drop_last: bool = False,
     **kwargs,
 ):
-    _kwargs = {"transform": ToTensor()}
+    _kwargs = {
+        "transform": ToTensor(),
+    }
     _kwargs.update(kwargs)
     ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
     train, test, val = (
diff --git a/symbolic_nn_tests/ffnn.py b/symbolic_nn_tests/ffnn.py
index b4da3d7..436c0b8 100644
--- a/symbolic_nn_tests/ffnn.py
+++ b/symbolic_nn_tests/ffnn.py
@@ -1,9 +1,33 @@
-from .dataloader import get_dataset
+from torch import nn
+
+
+model = nn.Sequential(
+    nn.Flatten(1, -1),
+    nn.Linear(784, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 64),
+    nn.ReLU(),
+    nn.Linear(64, 32),
+    nn.ReLU(),
+    nn.Linear(32, 10),
+    nn.Softmax(dim=1),
+)
 
 
 def main():
-    pass
+    from torchvision.datasets import QMNIST
+    import lightning as L
+    from .dataloader import get_dataset
+    from .trainer import Trainer
 
 
-if __name__ == "__main__":
-    main()
+    train, test, val = get_dataset(dataset=QMNIST)
+    training_model = Trainer(model)
+    trainer = L.Trainer(max_epochs=10)
+    trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)
diff --git a/symbolic_nn_tests/trainer.py b/symbolic_nn_tests/trainer.py
index 3bc781b..bc191fc 100644
--- a/symbolic_nn_tests/trainer.py
+++ b/symbolic_nn_tests/trainer.py
@@ -1,19 +1,37 @@
+import torch
 from torch import nn, optim
 import lightning as L
 
 
-class Trainer(L.LighningModule):
-    def __init__(self, model):
+def collate_batch(batch):
+    x, y = zip(*batch)
+    x = [i[0] for i in x]
+    y = [torch.tensor(i) for i in y]
+    x = torch.stack(x).to("cuda")
+    y = torch.tensor(y).to("cuda")
+    return x, y
+
+
+class Trainer(L.LightningModule):
+    def __init__(self, model, loss_func=nn.functional.cross_entropy):
         super().__init__()
         self.model = model
+        self.loss_func = loss_func
 
     def training_step(self, batch, batch_idx):
-        x, y = batch
+        x, y = collate_batch(batch)
         y_pred = self.model(x)
-        loss = nn.functional.mse_loss(y_pred, y)
+        loss = self.loss_func(y_pred, y)
         self.log("train_loss", loss)
         return loss
 
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=1e-3)
-        return optimizer
+    def validation_step(self, batch, batch_idx):
+        x, y = collate_batch(batch)
+        y_pred = self.model(x)
+        loss = self.loss_func(y_pred, y)
+        self.log("val_loss", loss)
+        return loss
+
+    def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
+        _optimizer = optimizer(self.parameters(), *args, **kwargs)
+        return _optimizer
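
Review notes (hedged sketches, not part of the patch):

With __main__.py now re-exporting main from ffnn.py, the package entry
point launches training directly:

    python -m symbolic_nn_tests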
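
One thing worth flagging in ffnn.py: torch.nn.functional.cross_entropy
expects unnormalized logits and applies log_softmax internally, so the
trailing nn.Softmax(dim=1) normalizes twice and dampens gradients. Assuming
the default loss stays cross_entropy, a sketch where the head stops at the
last Linear layer:

    from torch import nn

    # Sketch: end on raw logits when pairing with F.cross_entropy, which
    # applies log_softmax itself; apply softmax only at inference time,
    # e.g. torch.softmax(model(x), dim=1).
    model = nn.Sequential(
        nn.Flatten(1, -1),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 32),
        nn.ReLU(),
        nn.Linear(32, 10),  # logits out; no nn.Softmax here
    )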
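
In trainer.py, collate_batch pins both tensors to "cuda", which breaks
CPU-only runs, and it wraps the labels in torch.tensor twice (PyTorch warns
when re-wrapping a list of 0-dim tensors). A device-neutral sketch that
keeps the patch's (sample, label) unpacking; the device parameter here is
an illustrative addition, meant to receive the LightningModule's
self.device rather than a hardcoded string:

    import torch

    def collate_batch(batch, device=None):
        # batch is a list of (sample, label) pairs; build each tensor
        # once and let the caller decide where it lives.
        x, y = zip(*batch)
        x = torch.stack([i[0] for i in x]).to(device)
        y = torch.tensor(y, device=device)  # no double wrapping
        return x, y

Called as collate_batch(batch, self.device) from training_step and
validation_step; alternatively, passed as a DataLoader collate_fn, Lightning's
own batch transfer would handle device placement entirely.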
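
Lightning calls configure_optimizers() with no arguments, so the
optimizer=optim.Adam, *args, **kwargs parameters added in trainer.py can
only ever take their defaults; the old explicit lr=1e-3 is silently replaced
by Adam's identical built-in default. If a pluggable optimizer is the goal,
one sketch (names illustrative) threads the factory through __init__:

    import lightning as L
    from torch import nn, optim

    class Trainer(L.LightningModule):
        def __init__(self, model, loss_func=nn.functional.cross_entropy,
                     optimizer=optim.Adam, **optim_kwargs):
            super().__init__()
            self.model = model
            self.loss_func = loss_func
            self._optimizer = optimizer        # factory, not an instance
            self._optim_kwargs = optim_kwargs  # e.g. lr=1e-3

        def configure_optimizers(self):
            # Lightning invokes this without arguments.
            return self._optimizer(self.parameters(), **self._optim_kwargs)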