Minimized model and prepared for testing new loss functions

commit 66b23a8e76
parent 708f9952b3
2024-05-14 13:04:52 +01:00
6 changed files with 51 additions and 47 deletions

.gitignore

@@ -163,3 +163,4 @@ cython_debug/
 datasets/
 lightning_logs/
+logs/
 

@@ -1,5 +1,8 @@
-from .ffnn import main
+from .model import main
 
 if __name__ == "__main__":
-    main()
+    from lightning.pytorch.loggers import TensorBoardLogger
+
+    logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
+    main(logger)
 
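Because main() now takes the logger as a parameter, the entry point alone decides where metrics go. A minimal sketch of swapping in a different logger, assuming Lightning 2.x; the CSVLogger choice and the logs/ffnn-csv directory are illustrative, not part of the commit:

from lightning.pytorch.loggers import CSVLogger

from .model import main

if __name__ == "__main__":
    # Illustrative swap: CSVLogger writes plain metrics.csv files instead
    # of TensorBoard event files; handy when TensorBoard is not installed.
    logger = CSVLogger(save_dir=".", name="logs/ffnn-csv")
    main(logger)

Either way the run lands under logs/, which the .gitignore change above keeps out of version control.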

@@ -1,6 +1,6 @@
 from pathlib import Path
 
 from torchvision.datasets import Caltech256
-from torchvision.transforms import Compose, Lambda, ToTensor
+from torchvision.transforms import ToTensor
 from torch.utils.data import random_split
 from torch.utils.data import BatchSampler
@@ -20,7 +20,7 @@ def get_dataset(
     }
     _kwargs.update(kwargs)
     ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
-    train, test, val = (
+    train, val, test = (
         BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
     )
-    return train, test, val
+    return train, val, test
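The reordering matters because callers unpack the tuple positionally; everything downstream now reads train, val, test. For reference, a self-contained sketch of the split-and-batch pattern used here, with a toy dataset and illustrative split sizes:

import torch
from torch.utils.data import BatchSampler, TensorDataset, random_split

# Toy stand-in for QMNIST/Caltech256: 100 flattened 28x28 "images".
ds = TensorDataset(torch.randn(100, 784), torch.randint(0, 10, (100,)))

# Each Subset from random_split is wrapped in a BatchSampler, so iterating
# a split yields lists of (x, y) pairs for collate_batch to stack later.
train, val, test = (
    BatchSampler(part, batch_size=32, drop_last=False)
    for part in random_split(ds, [70, 15, 15])
)

first_batch = next(iter(train))  # list of 32 (x, y) tuples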

@@ -1,33 +0,0 @@
-from torch import nn
-
-model = nn.Sequential(
-    nn.Flatten(1, -1),
-    nn.Linear(784, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 64),
-    nn.ReLU(),
-    nn.Linear(64, 32),
-    nn.ReLU(),
-    nn.Linear(32, 10),
-    nn.Softmax(dim=1),
-)
-
-
-def main():
-    from torchvision.datasets import QMNIST
-
-    import lightning as L
-
-    from .dataloader import get_dataset
-    from .trainer import Trainer
-
-    train, test, val = get_dataset(dataset=QMNIST)
-    training_model = Trainer(model)
-    trainer = L.Trainer(max_epochs=10)
-    trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)

@@ -0,0 +1,30 @@
+from torch import nn
+
+model = nn.Sequential(
+    nn.Flatten(1, -1),
+    nn.Linear(784, 10),
+    nn.Softmax(dim=1),
+)
+
+
+def main(loss_func=nn.functional.cross_entropy, logger=None):
+    from torchvision.datasets import QMNIST
+
+    import lightning as L
+
+    from .dataloader import get_dataset
+    from .train import TrainingWrapper
+
+    if logger is None:
+        from lightning.pytorch.loggers import TensorBoardLogger
+
+        logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
+
+    train, val, test = get_dataset(dataset=QMNIST)
+    lmodel = TrainingWrapper(model)
+    trainer = L.Trainer(max_epochs=5, logger=logger)
+    trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
+
+
+if __name__ == "__main__":
+    main()
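"Minimized" is easy to quantify: the six hidden layers are gone, leaving a single linear layer over the flattened 28x28 input. Rebuilding both heads from the diffs above gives the parameter counts (a quick check, not part of the commit):

from torch import nn

old = nn.Sequential(
    nn.Flatten(1, -1), nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10), nn.Softmax(dim=1),
)
new = nn.Sequential(nn.Flatten(1, -1), nn.Linear(784, 10), nn.Softmax(dim=1))

print(sum(p.numel() for p in old.parameters()))  # 160682
print(sum(p.numel() for p in new.parameters()))  # 7850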

@@ -12,25 +12,28 @@ def collate_batch(batch):
     return x, y
 
 
-class Trainer(L.LightningModule):
+class TrainingWrapper(L.LightningModule):
     def __init__(self, model, loss_func=nn.functional.cross_entropy):
         super().__init__()
         self.model = model
         self.loss_func = loss_func
 
-    def training_step(self, batch, batch_idx):
+    def _forward_step(self, batch, batch_idx, label=""):
         x, y = collate_batch(batch)
         y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
-        self.log("train_loss", loss)
+        batch_size = x.shape[0]
+        loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
+        acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
+        self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
+        self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
+        return loss, acc
+
+    def training_step(self, batch, batch_idx):
+        loss, _ = self._forward_step(batch, batch_idx, label="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
-        x, y = collate_batch(batch)
-        y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
-        self.log("val_loss", loss)
-        return loss
+        self._forward_step(batch, batch_idx, label="val")
 
     def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
         _optimizer = optimizer(self.parameters(), *args, **kwargs)
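This is the hook the commit message refers to: any callable taking (y_pred, one_hot_targets) can now be passed through main() into TrainingWrapper. A hedged sketch of trying a custom loss; the soft_l1 name and definition are illustrative, not from the commit. Note that nn.functional.one_hot without num_classes infers its width from the largest label in each batch, so small batches may need an explicit num_classes=10 in practice.

import torch
from torch import nn

from .model import main


def soft_l1(y_pred: torch.Tensor, y_onehot: torch.Tensor) -> torch.Tensor:
    # Mean absolute error between the softmax output and the one-hot
    # targets built in _forward_step; the cast keeps dtypes aligned.
    return (y_pred - y_onehot.to(y_pred.dtype)).abs().mean()


main(loss_func=soft_l1)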