First functional training loop

This commit is contained in:
2024-05-13 18:13:05 +01:00
committed by Cian-H
parent 674175b260
commit 708f9952b3
5 changed files with 59 additions and 15 deletions

1
.gitignore vendored
View File

@@ -162,3 +162,4 @@ cython_debug/
#.idea/ #.idea/
datasets/ datasets/
lightning_logs/

View File

@@ -1,5 +1,4 @@
def main(): from .ffnn import main
print("Working")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from torchvision.datasets import Caltech256 from torchvision.datasets import Caltech256
from torchvision.transforms import ToTensor from torchvision.transforms import Compose, Lambda, ToTensor
from torch.utils.data import random_split from torch.utils.data import random_split
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
@@ -15,7 +15,9 @@ def get_dataset(
drop_last: bool = False, drop_last: bool = False,
**kwargs, **kwargs,
): ):
_kwargs = {"transform": ToTensor()} _kwargs = {
"transform": ToTensor(),
}
_kwargs.update(kwargs) _kwargs.update(kwargs)
ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs) ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
train, test, val = ( train, test, val = (

View File

@@ -1,9 +1,33 @@
from .dataloader import get_dataset from torch import nn
model = nn.Sequential(
nn.Flatten(1, -1),
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10),
nn.Softmax(dim=1),
)
def main(): def main():
pass from torchvision.datasets import QMNIST
import lightning as L
from .dataloader import get_dataset
from .trainer import Trainer
if __name__ == "__main__": train, test, val = get_dataset(dataset=QMNIST)
main() training_model = Trainer(model)
trainer = L.Trainer(max_epochs=10)
trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)

View File

@@ -1,19 +1,37 @@
import torch
from torch import nn, optim from torch import nn, optim
import lightning as L import lightning as L
class Trainer(L.LighningModule): def collate_batch(batch):
def __init__(self, model): x, y = zip(*batch)
x = [i[0] for i in x]
y = [torch.tensor(i) for i in y]
x = torch.stack(x).to("cuda")
y = torch.tensor(y).to("cuda")
return x, y
class Trainer(L.LightningModule):
def __init__(self, model, loss_func=nn.functional.cross_entropy):
super().__init__() super().__init__()
self.model = model self.model = model
self.loss_func = loss_func
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
x, y = batch x, y = collate_batch(batch)
y_pred = self.model(x) y_pred = self.model(x)
loss = nn.functional.mse_loss(y_pred, y) loss = self.loss_func(y_pred, y)
self.log("train_loss", loss) self.log("train_loss", loss)
return loss return loss
def configure_optimizers(self): def validation_step(self, batch, batch_idx):
optimizer = optim.Adam(self.parameters(), lr=1e-3) x, y = collate_batch(batch)
return optimizer y_pred = self.model(x)
loss = self.loss_func(y_pred, y)
self.log("val_loss", loss)
return loss
def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
_optimizer = optimizer(self.parameters(), *args, **kwargs)
return _optimizer