Mirror of https://github.com/Cian-H/symbolic_nn_tests.git
Synced 2025-12-22 22:22:01 +00:00

Commit: First functional training loop
.gitignore (vendored), 1 line added:
@@ -162,3 +162,4 @@ cython_debug/
 #.idea/
 
 datasets/
+lightning_logs/
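The new entry is the default directory that Lightning's Trainer writes logs and checkpoints into, created here by the trainer.fit() call added further down. As a hedged aside, the directory can be redirected rather than ignored; default_root_dir is a real Lightning Trainer argument, though the runs/ path below is purely illustrative:

    import lightning as L

    # lightning_logs/ is created by trainer.fit() unless redirected:
    trainer = L.Trainer(max_epochs=10, default_root_dir="runs/")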
Package entry point (likely __main__.py, given the __main__ guard):

@@ -1,5 +1,4 @@
-def main():
-    print("Working")
+from .ffnn import main
 
 
 if __name__ == "__main__":
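With the "Working" stub gone, the entry point simply re-exports main from ffnn, so running the package as a script kicks off training. A minimal sketch of the invocation, assuming the package is named symbolic_nn_tests after the repository (the package name is not shown in the diff itself):

    # python -m symbolic_nn_tests
    # ...is now equivalent to:
    from symbolic_nn_tests.ffnn import main

    main()  # runs the Lightning training loop defined in ffnn.py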
dataloader.py (the .dataloader module), first hunk:

@@ -1,6 +1,6 @@
 from pathlib import Path
 from torchvision.datasets import Caltech256
-from torchvision.transforms import ToTensor
+from torchvision.transforms import Compose, Lambda, ToTensor
 from torch.utils.data import random_split
 from torch.utils.data import BatchSampler
 
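Compose and Lambda are newly imported but unused in the hunks shown, presumably staged for building transform pipelines. A sketch of the kind of pipeline they enable (not code from this commit):

    from torchvision.transforms import Compose, Lambda, ToTensor

    # Chain ToTensor with an arbitrary callable wrapped in Lambda:
    transform = Compose(
        [
            ToTensor(),
            Lambda(lambda img: img.flatten()),
        ]
    )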
dataloader.py, second hunk:

@@ -15,7 +15,9 @@ def get_dataset(
     drop_last: bool = False,
     **kwargs,
 ):
-    _kwargs = {"transform": ToTensor()}
+    _kwargs = {
+        "transform": ToTensor(),
+    }
     _kwargs.update(kwargs)
     ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
     train, test, val = (
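Reflowing the dict is cosmetic, but it highlights that _kwargs.update(kwargs) lets callers override the default ToTensor transform. A hedged example of such an override; only the plain dataset=QMNIST call appears in this commit, and the Compose pipeline is illustrative:

    from torchvision.datasets import QMNIST
    from torchvision.transforms import Compose, Lambda, ToTensor

    # transform= replaces the ToTensor() default inside get_dataset:
    train, test, val = get_dataset(
        dataset=QMNIST,
        transform=Compose([ToTensor(), Lambda(lambda img: img.flatten())]),
    )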
ffnn.py (the .ffnn module):

@@ -1,9 +1,33 @@
-from .dataloader import get_dataset
+from torch import nn
 
 
+model = nn.Sequential(
+    nn.Flatten(1, -1),
+    nn.Linear(784, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 128),
+    nn.ReLU(),
+    nn.Linear(128, 64),
+    nn.ReLU(),
+    nn.Linear(64, 32),
+    nn.ReLU(),
+    nn.Linear(32, 10),
+    nn.Softmax(dim=1),
+)
+
+
 def main():
-    pass
+    from torchvision.datasets import QMNIST
+    import lightning as L
+
+    from .dataloader import get_dataset
+    from .trainer import Trainer
 
-
-if __name__ == "__main__":
-    main()
+    train, test, val = get_dataset(dataset=QMNIST)
+    training_model = Trainer(model)
+    trainer = L.Trainer(max_epochs=10)
+    trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)
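The dimensions line up with QMNIST: images are 1x28x28, so Flatten(1, -1) yields the 784 features the first Linear expects, and Linear(32, 10) plus Softmax covers the ten digit classes. A quick sanity check one could run against the model above (the dummy batch is illustrative):

    import torch

    x = torch.randn(8, 1, 28, 28)  # dummy batch of 8 QMNIST-shaped images
    y = model(x)
    assert y.shape == (8, 10)  # one row of class probabilities per image
    assert torch.allclose(y.sum(dim=1), torch.ones(8))  # Softmax rows sum to 1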
trainer.py (the .trainer module):

@@ -1,19 +1,37 @@
+import torch
 from torch import nn, optim
 import lightning as L
 
 
-class Trainer(L.LighningModule):
-    def __init__(self, model):
+def collate_batch(batch):
+    x, y = zip(*batch)
+    x = [i[0] for i in x]
+    y = [torch.tensor(i) for i in y]
+    x = torch.stack(x).to("cuda")
+    y = torch.tensor(y).to("cuda")
+    return x, y
+
+
+class Trainer(L.LightningModule):
+    def __init__(self, model, loss_func=nn.functional.cross_entropy):
         super().__init__()
         self.model = model
+        self.loss_func = loss_func
 
     def training_step(self, batch, batch_idx):
-        x, y = batch
+        x, y = collate_batch(batch)
         y_pred = self.model(x)
-        loss = nn.functional.mse_loss(y_pred, y)
+        loss = self.loss_func(y_pred, y)
         self.log("train_loss", loss)
         return loss
 
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=1e-3)
-        return optimizer
+    def validation_step(self, batch, batch_idx):
+        x, y = collate_batch(batch)
+        y_pred = self.model(x)
+        loss = self.loss_func(y_pred, y)
+        self.log("val_loss", loss)
+        return loss
+
+    def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
+        _optimizer = optimizer(self.parameters(), *args, **kwargs)
+        return _optimizer
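Two hedged observations on the new hooks, neither stated in the commit itself. First, Lightning invokes configure_optimizers() with no arguments, so the optim.Adam default is what fit() actually uses; behaviour only matches the old hard-coded version because Adam's built-in lr of 1e-3 equals the value it used to pass explicitly. Second, nn.functional.cross_entropy applies log-softmax internally, so pairing it with the model's trailing nn.Softmax still trains, just with dampened gradients. The swappable pieces can be exercised as below; the nll_loss choice and the sample batch are illustrative, and collate_batch needs a CUDA device because it moves tensors to "cuda" unconditionally:

    import torch
    from torch import nn

    # Construct the LightningModule with an alternative loss function:
    training_model = Trainer(model, loss_func=nn.functional.nll_loss)

    # One reading of collate_batch: each sample is (image, label), and the
    # i[0] indexing drops QMNIST's channel dimension before stacking:
    batch = [(torch.randn(1, 28, 28), 3), (torch.randn(1, 28, 28), 7)]
    x, y = collate_batch(batch)
    # x: shape (2, 28, 28) on cuda; y: tensor([3, 7]) on cuda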