mirror of https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00

Commit: Minimized model and prepared for testing new loss functions
.gitignore (vendored, +1)

@@ -163,3 +163,4 @@ cython_debug/
 
 datasets/
 lightning_logs/
+logs/
symbolic_nn_tests/__main__.py

@@ -1,5 +1,8 @@
-from .ffnn import main
+from .model import main
 
 
 if __name__ == "__main__":
-    main()
+    from lightning.pytorch.loggers import TensorBoardLogger
+
+    logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
+    main(logger)
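Note: main() in the new model.py is defined as main(loss_func=..., logger=None), so the positional call main(logger) above actually binds the logger to loss_func. A minimal usage sketch, passing the logger by keyword instead (the keyword form and the log-directory name are assumptions, not what this commit ships):

# Hypothetical sketch: run training with an explicit log directory.
# Passing the logger by keyword avoids binding it to main()'s first
# parameter, loss_func.
from lightning.pytorch.loggers import TensorBoardLogger

from symbolic_nn_tests.model import main

logger = TensorBoardLogger(save_dir=".", name="logs/experiment_01")
main(logger=logger)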
symbolic_nn_tests/dataloader.py

@@ -1,6 +1,6 @@
 from pathlib import Path
 from torchvision.datasets import Caltech256
-from torchvision.transforms import Compose, Lambda, ToTensor
+from torchvision.transforms import ToTensor
 from torch.utils.data import random_split
 from torch.utils.data import BatchSampler
 
@@ -20,7 +20,7 @@ def get_dataset(
     }
     _kwargs.update(kwargs)
     ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
-    train, test, val = (
+    train, val, test = (
         BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
    )
-    return train, test, val
+    return train, val, test
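Note: the reordering above makes get_dataset() return (train, val, test), matching how the new model.py unpacks it. A short usage sketch under that contract (the module path symbolic_nn_tests/dataloader.py is inferred from the imports elsewhere in this commit):

# Sketch: fetch QMNIST and unpack the three BatchSamplers in the
# post-commit order (train, val, test).
from torchvision.datasets import QMNIST

from symbolic_nn_tests.dataloader import get_dataset

train, val, test = get_dataset(dataset=QMNIST)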
symbolic_nn_tests/ffnn.py (deleted, -33)

@@ -1,33 +0,0 @@
-from torch import nn
-
-
-model = nn.Sequential(
-    nn.Flatten(1, -1),
-    nn.Linear(784, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 128),
-    nn.ReLU(),
-    nn.Linear(128, 64),
-    nn.ReLU(),
-    nn.Linear(64, 32),
-    nn.ReLU(),
-    nn.Linear(32, 10),
-    nn.Softmax(dim=1),
-)
-
-
-def main():
-    from torchvision.datasets import QMNIST
-    import lightning as L
-
-    from .dataloader import get_dataset
-    from .trainer import Trainer
-
-    train, test, val = get_dataset(dataset=QMNIST)
-    training_model = Trainer(model)
-    trainer = L.Trainer(max_epochs=10)
-    trainer.fit(model=training_model, train_dataloaders=train, val_dataloaders=val)
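Note: the commit message's "minimized" is literal; the five-hidden-layer MLP deleted above is replaced by a single linear layer in model.py. A quick sketch comparing parameter counts (not part of the repo):

# Parameter counts of the deleted MLP vs. the new single-layer model.
from torch import nn

old_model = nn.Sequential(
    nn.Flatten(1, -1),
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 10), nn.Softmax(dim=1),
)
new_model = nn.Sequential(nn.Flatten(1, -1), nn.Linear(784, 10), nn.Softmax(dim=1))

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(old_model), n_params(new_model))  # 160682 vs 7850, ~20x smaller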
symbolic_nn_tests/model.py (new file, +30)

@@ -0,0 +1,30 @@
+from torch import nn
+
+
+model = nn.Sequential(
+    nn.Flatten(1, -1),
+    nn.Linear(784, 10),
+    nn.Softmax(dim=1),
+)
+
+
+def main(loss_func=nn.functional.cross_entropy, logger=None):
+    from torchvision.datasets import QMNIST
+    import lightning as L
+
+    from .dataloader import get_dataset
+    from .train import TrainingWrapper
+
+    if logger is None:
+        from lightning.pytorch.loggers import TensorBoardLogger
+
+        logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
+
+    train, val, test = get_dataset(dataset=QMNIST)
+    lmodel = TrainingWrapper(model)
+    trainer = L.Trainer(max_epochs=5, logger=logger)
+    trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
+
+
+if __name__ == "__main__":
+    main()
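Note: loss_func is the hook for the "new loss functions" in the commit message, but as written main() never forwards it; TrainingWrapper(model) falls back to the wrapper's own cross_entropy default. A sketch of how a custom loss could be exercised once it is forwarded (the forwarding line is an assumption, not in this commit):

# Sketch: a custom loss matching the (prediction, one-hot float target)
# call shape used in TrainingWrapper._forward_step. Assumes main() is
# updated to pass loss_func through, e.g.
#     lmodel = TrainingWrapper(model, loss_func=loss_func)
import torch

from symbolic_nn_tests.model import main

def soft_l1_loss(y_pred, y_onehot):
    # Cast predictions to float64 to match the one-hot targets' dtype.
    return torch.nn.functional.l1_loss(y_pred.double(), y_onehot)

main(loss_func=soft_l1_loss)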
symbolic_nn_tests/train.py

@@ -12,25 +12,28 @@ def collate_batch(batch):
     return x, y
 
 
-class Trainer(L.LightningModule):
+class TrainingWrapper(L.LightningModule):
     def __init__(self, model, loss_func=nn.functional.cross_entropy):
         super().__init__()
         self.model = model
         self.loss_func = loss_func
 
-    def training_step(self, batch, batch_idx):
+    def _forward_step(self, batch, batch_idx, label=""):
         x, y = collate_batch(batch)
         y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
-        self.log("train_loss", loss)
+        batch_size = x.shape[0]
+        loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
+        acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
+        self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
+        self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
+        return loss, acc
+
+    def training_step(self, batch, batch_idx):
+        loss, _ = self._forward_step(batch, batch_idx, label="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
-        x, y = collate_batch(batch)
-        y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
-        self.log("val_loss", loss)
-        return loss
+        self._forward_step(batch, batch_idx, label="val")
 
     def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
         _optimizer = optimizer(self.parameters(), *args, **kwargs)
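Note: nn.functional.one_hot(y) with no num_classes infers the width from the largest label in the batch, so a batch that happens to lack the highest class yields a target tensor narrower than the model's 10 outputs. A small sketch of the pinned variant (an assumption, not what this commit does):

# Sketch: pinning num_classes keeps the one-hot width stable per batch.
import torch
from torch import nn

y = torch.tensor([0, 3, 1])  # a batch that happens to lack class 9
target = nn.functional.one_hot(y, num_classes=10).type(torch.float64)
print(target.shape)  # torch.Size([3, 10]) regardless of batch contents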