Fixed similarity matrix normalization

This commit is contained in:
2024-05-14 13:09:17 +01:00
parent 66b23a8e76
commit 389c47ef28
5 changed files with 71 additions and 13 deletions

6
poetry.lock generated
View File

@@ -1637,13 +1637,13 @@ test = ["coverage", "pytest", "pytest-cov"]
[[package]]
name = "mako"
version = "1.3.3"
version = "1.3.5"
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
optional = false
python-versions = ">=3.8"
files = [
{file = "Mako-1.3.3-py3-none-any.whl", hash = "sha256:5324b88089a8978bf76d1629774fcc2f1c07b82acdf00f4c5dd8ceadfffc4b40"},
{file = "Mako-1.3.3.tar.gz", hash = "sha256:e16c01d9ab9c11f7290eef1cfefc093fb5a45ee4a3da09e2fec2e4d1bae54e73"},
{file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"},
{file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"},
]
[package.dependencies]

View File

@@ -1,8 +1,23 @@
from .model import main
def main():
from .model import main as test_model
from . import semantic_loss
from lightning.pytorch.loggers import TensorBoardLogger
from torch import nn
logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version="cross_entropy",
)
test_model(logger=logger, loss_func=nn.functional.cross_entropy)
logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version="similarity_weighted_cross_entropy",
)
test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
if __name__ == "__main__":
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
main(logger)
main()

View File

@@ -1,3 +1,4 @@
from functools import lru_cache
from torch import nn
@@ -8,11 +9,19 @@ model = nn.Sequential(
)
def main(loss_func=nn.functional.cross_entropy, logger=None):
# This is just a quick, lazy way to ensure all models are trained on the same dataset
@lru_cache(maxsize=1)
def get_singleton_dataset():
from torchvision.datasets import QMNIST
import lightning as L
from .dataloader import get_dataset
return get_dataset(dataset=QMNIST)
def main(loss_func=nn.functional.cross_entropy, logger=None):
import lightning as L
from .train import TrainingWrapper
if logger is None:
@@ -20,8 +29,8 @@ def main(loss_func=nn.functional.cross_entropy, logger=None):
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
train, val, test = get_dataset(dataset=QMNIST)
lmodel = TrainingWrapper(model)
train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(model, loss_func=loss_func)
trainer = L.Trainer(max_epochs=5, logger=logger)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)

View File

@@ -0,0 +1,33 @@
import torch
# NOTE: This similarity matrix defines loss scaling factors for misclassification
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
# penalised less harshly than visually distinct numbers as this mistake is "less
# mistaken" given our understanding of the visual characteristics of numerals.
# By using this scaling matric we can inject human knowledge into the model via
# the loss function, making this an example of a "semantic loss function"
SIMILARITY_MATRIX = torch.tensor(
[
[2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
[1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
[1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
]
).to("cuda")
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum() # Normalized to sum of 1
def similarity_weighted_cross_entropy(input, target):
ce_loss = torch.nn.functional.cross_entropy(input, target)
penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
similarity = (target - input).abs()
similarity_penalty = (similarity * penalty_tensor).sum()
return ce_loss * similarity_penalty

View File

@@ -22,7 +22,8 @@ class TrainingWrapper(L.LightningModule):
x, y = collate_batch(batch)
y_pred = self.model(x)
batch_size = x.shape[0]
loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
one_hot_y = nn.functional.one_hot(y).type(torch.float64)
loss = self.loss_func(y_pred, one_hot_y)
acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)