mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
First functional training loop
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -162,3 +162,4 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
datasets/
|
||||
lightning_logs/
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
def main():
|
||||
print("Working")
|
||||
from .ffnn import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
from torchvision.datasets import Caltech256
|
||||
from torchvision.transforms import ToTensor
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
from torch.utils.data import random_split
|
||||
from torch.utils.data import BatchSampler
|
||||
|
||||
@@ -15,7 +15,9 @@ def get_dataset(
|
||||
drop_last: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
_kwargs = {"transform": ToTensor()}
|
||||
_kwargs = {
|
||||
"transform": ToTensor(),
|
||||
}
|
||||
_kwargs.update(kwargs)
|
||||
ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
|
||||
train, test, val = (
|
||||
|
||||
@@ -1,9 +1,33 @@
|
||||
from .dataloader import get_dataset
|
||||
from torch import nn
|
||||
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Flatten(1, -1),
|
||||
nn.Linear(784, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 10),
|
||||
nn.Softmax(dim=1),
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
pass
|
||||
from torchvision.datasets import QMNIST
|
||||
import lightning as L
|
||||
|
||||
from .dataloader import get_dataset
|
||||
from .trainer import Trainer
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
train, test, val = get_dataset(dataset=QMNIST)
|
||||
training_model = Trainer(model)
|
||||
trainer = L.Trainer(max_epochs=10)
|
||||
trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)
|
||||
|
||||
@@ -1,19 +1,37 @@
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
import lightning as L
|
||||
|
||||
|
||||
class Trainer(L.LighningModule):
|
||||
def __init__(self, model):
|
||||
def collate_batch(batch):
|
||||
x, y = zip(*batch)
|
||||
x = [i[0] for i in x]
|
||||
y = [torch.tensor(i) for i in y]
|
||||
x = torch.stack(x).to("cuda")
|
||||
y = torch.tensor(y).to("cuda")
|
||||
return x, y
|
||||
|
||||
|
||||
class Trainer(L.LightningModule):
|
||||
def __init__(self, model, loss_func=nn.functional.cross_entropy):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.loss_func = loss_func
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
x, y = collate_batch(batch)
|
||||
y_pred = self.model(x)
|
||||
loss = nn.functional.mse_loss(y_pred, y)
|
||||
loss = self.loss_func(y_pred, y)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = optim.Adam(self.parameters(), lr=1e-3)
|
||||
return optimizer
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = collate_batch(batch)
|
||||
y_pred = self.model(x)
|
||||
loss = self.loss_func(y_pred, y)
|
||||
self.log("val_loss", loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
|
||||
_optimizer = optimizer(self.parameters(), *args, **kwargs)
|
||||
return _optimizer
|
||||
|
||||
Reference in New Issue
Block a user