mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Fixed similarity matrix normalization
This commit is contained in:
6
poetry.lock
generated
6
poetry.lock
generated
@@ -1637,13 +1637,13 @@ test = ["coverage", "pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "mako"
|
||||
version = "1.3.3"
|
||||
version = "1.3.5"
|
||||
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "Mako-1.3.3-py3-none-any.whl", hash = "sha256:5324b88089a8978bf76d1629774fcc2f1c07b82acdf00f4c5dd8ceadfffc4b40"},
|
||||
{file = "Mako-1.3.3.tar.gz", hash = "sha256:e16c01d9ab9c11f7290eef1cfefc093fb5a45ee4a3da09e2fec2e4d1bae54e73"},
|
||||
{file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"},
|
||||
{file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -1,8 +1,23 @@
|
||||
from .model import main
|
||||
def main():
|
||||
from .model import main as test_model
|
||||
from . import semantic_loss
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import nn
|
||||
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version="cross_entropy",
|
||||
)
|
||||
test_model(logger=logger, loss_func=nn.functional.cross_entropy)
|
||||
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version="similarity_weighted_cross_entropy",
|
||||
)
|
||||
test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
|
||||
main(logger)
|
||||
main()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from functools import lru_cache
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -8,11 +9,19 @@ model = nn.Sequential(
|
||||
)
|
||||
|
||||
|
||||
def main(loss_func=nn.functional.cross_entropy, logger=None):
|
||||
# This is just a quick, lazy way to ensure all models are trained on the same dataset
|
||||
@lru_cache(maxsize=1)
|
||||
def get_singleton_dataset():
|
||||
from torchvision.datasets import QMNIST
|
||||
import lightning as L
|
||||
|
||||
from .dataloader import get_dataset
|
||||
|
||||
return get_dataset(dataset=QMNIST)
|
||||
|
||||
|
||||
def main(loss_func=nn.functional.cross_entropy, logger=None):
|
||||
import lightning as L
|
||||
|
||||
from .train import TrainingWrapper
|
||||
|
||||
if logger is None:
|
||||
@@ -20,8 +29,8 @@ def main(loss_func=nn.functional.cross_entropy, logger=None):
|
||||
|
||||
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
|
||||
|
||||
train, val, test = get_dataset(dataset=QMNIST)
|
||||
lmodel = TrainingWrapper(model)
|
||||
train, val, test = get_singleton_dataset()
|
||||
lmodel = TrainingWrapper(model, loss_func=loss_func)
|
||||
trainer = L.Trainer(max_epochs=5, logger=logger)
|
||||
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||
|
||||
|
||||
33
symbolic_nn_tests/semantic_loss.py
Normal file
33
symbolic_nn_tests/semantic_loss.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
|
||||
|
||||
# NOTE: This similarity matrix defines loss scaling factors for misclassification
|
||||
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
|
||||
# penalised less harshly than visually distinct numbers as this mistake is "less
|
||||
# mistaken" given our understanding of the visual characteristics of numerals.
|
||||
# By using this scaling matric we can inject human knowledge into the model via
|
||||
# the loss function, making this an example of a "semantic loss function"
|
||||
SIMILARITY_MATRIX = torch.tensor(
|
||||
[
|
||||
[2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
|
||||
[1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
|
||||
]
|
||||
).to("cuda")
|
||||
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum() # Normalized to sum of 1
|
||||
|
||||
|
||||
def similarity_weighted_cross_entropy(input, target):
|
||||
ce_loss = torch.nn.functional.cross_entropy(input, target)
|
||||
|
||||
penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
|
||||
similarity = (target - input).abs()
|
||||
similarity_penalty = (similarity * penalty_tensor).sum()
|
||||
return ce_loss * similarity_penalty
|
||||
@@ -22,7 +22,8 @@ class TrainingWrapper(L.LightningModule):
|
||||
x, y = collate_batch(batch)
|
||||
y_pred = self.model(x)
|
||||
batch_size = x.shape[0]
|
||||
loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
|
||||
one_hot_y = nn.functional.one_hot(y).type(torch.float64)
|
||||
loss = self.loss_func(y_pred, one_hot_y)
|
||||
acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
|
||||
self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
|
||||
self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
|
||||
|
||||
Reference in New Issue
Block a user