mirror of https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Fixed similarity matrix normalization
poetry.lock (generated): 6 changed lines
@@ -1637,13 +1637,13 @@ test = ["coverage", "pytest", "pytest-cov"]
 
 [[package]]
 name = "mako"
-version = "1.3.3"
+version = "1.3.5"
 description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "Mako-1.3.3-py3-none-any.whl", hash = "sha256:5324b88089a8978bf76d1629774fcc2f1c07b82acdf00f4c5dd8ceadfffc4b40"},
-    {file = "Mako-1.3.3.tar.gz", hash = "sha256:e16c01d9ab9c11f7290eef1cfefc093fb5a45ee4a3da09e2fec2e4d1bae54e73"},
+    {file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"},
+    {file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"},
 ]
 
 [package.dependencies]
@@ -1,8 +1,23 @@
-from .model import main
+def main():
+    from .model import main as test_model
+    from . import semantic_loss
+    from lightning.pytorch.loggers import TensorBoardLogger
+    from torch import nn
+
+    logger = TensorBoardLogger(
+        save_dir=".",
+        name="logs/comparison",
+        version="cross_entropy",
+    )
+    test_model(logger=logger, loss_func=nn.functional.cross_entropy)
+
+    logger = TensorBoardLogger(
+        save_dir=".",
+        name="logs/comparison",
+        version="similarity_weighted_cross_entropy",
+    )
+    test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
 
 
 if __name__ == "__main__":
-    from lightning.pytorch.loggers import TensorBoardLogger
-
-    logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
-    main(logger)
+    main()
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from torch import nn
 
 
@@ -8,11 +9,19 @@ model = nn.Sequential(
 )
 
 
-def main(loss_func=nn.functional.cross_entropy, logger=None):
+# This is just a quick, lazy way to ensure all models are trained on the same dataset
+@lru_cache(maxsize=1)
+def get_singleton_dataset():
     from torchvision.datasets import QMNIST
-    import lightning as L
 
     from .dataloader import get_dataset
+
+    return get_dataset(dataset=QMNIST)
+
+
+def main(loss_func=nn.functional.cross_entropy, logger=None):
+    import lightning as L
+
     from .train import TrainingWrapper
 
     if logger is None:
@@ -20,8 +29,8 @@ def main(loss_func=nn.functional.cross_entropy, logger=None):
 
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
 
-    train, val, test = get_dataset(dataset=QMNIST)
-    lmodel = TrainingWrapper(model)
+    train, val, test = get_singleton_dataset()
+    lmodel = TrainingWrapper(model, loss_func=loss_func)
     trainer = L.Trainer(max_epochs=5, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
 
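Context on the `@lru_cache(maxsize=1)` decorator introduced above: caching the loader means every call returns the same train/val/test objects, so both loss functions are compared on an identical data split. A minimal, self-contained sketch of the pattern (the loader body here is a hypothetical stand-in, not the repo's get_dataset):

from functools import lru_cache


# Hypothetical stand-in for the repo's dataloader; the real get_singleton_dataset()
# returns (train, val, test) dataloaders built from QMNIST.
@lru_cache(maxsize=1)
def get_singleton_dataset():
    print("building dataset...")  # runs only on the first call
    return ("train", "val", "test")


a = get_singleton_dataset()
b = get_singleton_dataset()
assert a is b  # the cached result is reused, so every model trains on the same split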
symbolic_nn_tests/semantic_loss.py (new file): 33 added lines
@@ -0,0 +1,33 @@
+import torch
+
+
+# NOTE: This similarity matrix defines loss scaling factors for misclassification
+# of numbers from our QMNIST dataset. Visually similar numbers (e.g. 3/8) are
+# penalised less harshly than visually distinct numbers, as this mistake is "less
+# mistaken" given our understanding of the visual characteristics of numerals.
+# By using this scaling matrix we can inject human knowledge into the model via
+# the loss function, making this an example of a "semantic loss function".
+SIMILARITY_MATRIX = torch.tensor(
+    [
+        [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
+        [1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
+        [1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
+    ]
+).to("cuda")
+SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum()  # Normalized to sum of 1
+
+
+def similarity_weighted_cross_entropy(input, target):
+    ce_loss = torch.nn.functional.cross_entropy(input, target)
+
+    penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
+    similarity = (target - input).abs()
+    similarity_penalty = (similarity * penalty_tensor).sum()
+    return ce_loss * similarity_penalty
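A quick smoke test of the loss added above, re-created on CPU (the committed module pins SIMILARITY_MATRIX to "cuda"). The matrix values mirror the committed structure (2.0 on the diagonal, 1.5 for the 1/7, 3/8 and 5/6 pairs, 1.0 elsewhere) and the toy logits and labels are made up purely for illustration:

import torch

# CPU re-creation of the committed matrix, including the normalization this commit fixes.
sim = torch.ones(10, 10) + torch.eye(10)
for a, b in [(1, 7), (3, 8), (5, 6)]:
    sim[a, b] = sim[b, a] = 1.5
sim /= sim.sum()  # normalized to sum to 1, as in SIMILARITY_MATRIX


def similarity_weighted_cross_entropy(input, target, matrix=sim):
    ce_loss = torch.nn.functional.cross_entropy(input, target)
    penalty = matrix[target.argmax(dim=1)]   # one penalty row per sample
    similarity = (target - input).abs()
    return ce_loss * (similarity * penalty).sum()


# Toy batch: 4 samples, 10 classes, one-hot float targets as the training wrapper builds them.
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
one_hot = torch.nn.functional.one_hot(labels, num_classes=10).float()
print(similarity_weighted_cross_entropy(logits, one_hot))  # scalar loss tensor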
@@ -22,7 +22,8 @@ class TrainingWrapper(L.LightningModule):
         x, y = collate_batch(batch)
         y_pred = self.model(x)
         batch_size = x.shape[0]
-        loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
+        one_hot_y = nn.functional.one_hot(y).type(torch.float64)
+        loss = self.loss_func(y_pred, one_hot_y)
         acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
         self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
         self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
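The one-hot targets matter for the new loss: similarity_weighted_cross_entropy indexes SIMILARITY_MATRIX with target.argmax(dim=1), so TrainingWrapper must pass dense float targets rather than integer class labels. A standalone illustration (num_classes=10 is made explicit here; the committed code lets one_hot infer it):

import torch
from torch import nn

y = torch.tensor([3, 8, 0])  # integer class labels from the batch
one_hot_y = nn.functional.one_hot(y, num_classes=10).type(torch.float64)
print(one_hot_y.shape)          # torch.Size([3, 10])
print(one_hot_y.argmax(dim=1))  # tensor([3, 8, 0]), recovering the labels for the matrix lookup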