diff --git a/poetry.lock b/poetry.lock
index 9c8c5d3..2a60f50 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1637,13 +1637,13 @@ test = ["coverage", "pytest", "pytest-cov"]
 
 [[package]]
 name = "mako"
-version = "1.3.3"
+version = "1.3.5"
 description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "Mako-1.3.3-py3-none-any.whl", hash = "sha256:5324b88089a8978bf76d1629774fcc2f1c07b82acdf00f4c5dd8ceadfffc4b40"},
-    {file = "Mako-1.3.3.tar.gz", hash = "sha256:e16c01d9ab9c11f7290eef1cfefc093fb5a45ee4a3da09e2fec2e4d1bae54e73"},
+    {file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"},
+    {file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"},
 ]
 
 [package.dependencies]
diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py
index 48b7ea0..261c4db 100644
--- a/symbolic_nn_tests/__main__.py
+++ b/symbolic_nn_tests/__main__.py
@@ -1,8 +1,23 @@
-from .model import main
+def main():
+    from .model import main as test_model
+    from . import semantic_loss
+    from lightning.pytorch.loggers import TensorBoardLogger
+    from torch import nn
+
+    logger = TensorBoardLogger(
+        save_dir=".",
+        name="logs/comparison",
+        version="cross_entropy",
+    )
+    test_model(logger=logger, loss_func=nn.functional.cross_entropy)
+
+    logger = TensorBoardLogger(
+        save_dir=".",
+        name="logs/comparison",
+        version="similarity_weighted_cross_entropy",
+    )
+    test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
 
 
 if __name__ == "__main__":
-    from lightning.pytorch.loggers import TensorBoardLogger
-
-    logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
-    main(logger)
+    main()
diff --git a/symbolic_nn_tests/model.py b/symbolic_nn_tests/model.py
index 2c28591..2b88952 100644
--- a/symbolic_nn_tests/model.py
+++ b/symbolic_nn_tests/model.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from torch import nn
 
 
@@ -8,11 +9,19 @@ model = nn.Sequential(
 )
 
 
-def main(loss_func=nn.functional.cross_entropy, logger=None):
+# This is just a quick, lazy way to ensure all models are trained on the same dataset
+@lru_cache(maxsize=1)
+def get_singleton_dataset():
     from torchvision.datasets import QMNIST
-    import lightning as L
 
     from .dataloader import get_dataset
+
+    return get_dataset(dataset=QMNIST)
+
+
+def main(loss_func=nn.functional.cross_entropy, logger=None):
+    import lightning as L
+
     from .train import TrainingWrapper
 
     if logger is None:
@@ -20,8 +29,8 @@ def main(loss_func=nn.functional.cross_entropy, logger=None):
 
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
 
-    train, val, test = get_dataset(dataset=QMNIST)
-    lmodel = TrainingWrapper(model)
+    train, val, test = get_singleton_dataset()
+    lmodel = TrainingWrapper(model, loss_func=loss_func)
     trainer = L.Trainer(max_epochs=5, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
 
diff --git a/symbolic_nn_tests/semantic_loss.py b/symbolic_nn_tests/semantic_loss.py
new file mode 100644
index 0000000..f77a6e5
--- /dev/null
+++ b/symbolic_nn_tests/semantic_loss.py
@@ -0,0 +1,33 @@
+import torch
+
+
+# NOTE: This similarity matrix defines loss scaling factors for misclassification
+# of numbers from our QMNIST dataset. Visually similar numbers (e.g. 3/8) are
+# penalised less harshly than visually distinct numbers, as this mistake is "less
+# mistaken" given our understanding of the visual characteristics of numerals.
+# By using this scaling matrix we can inject human knowledge into the model via
+# the loss function, making this an example of a "semantic loss function".
+SIMILARITY_MATRIX = torch.tensor(
+    [
+        [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
+        [1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
+        [1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
+    ]
+).to("cuda")
+SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum()  # Normalized to sum to 1
+
+
+def similarity_weighted_cross_entropy(input, target):
+    ce_loss = torch.nn.functional.cross_entropy(input, target)
+
+    penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
+    similarity = (target - input).abs()
+    similarity_penalty = (similarity * penalty_tensor).sum()
+    return ce_loss * similarity_penalty
diff --git a/symbolic_nn_tests/train.py b/symbolic_nn_tests/train.py
index d4f77d9..6df0ec2 100644
--- a/symbolic_nn_tests/train.py
+++ b/symbolic_nn_tests/train.py
@@ -22,7 +22,8 @@ class TrainingWrapper(L.LightningModule):
         x, y = collate_batch(batch)
         y_pred = self.model(x)
         batch_size = x.shape[0]
-        loss = self.loss_func(y_pred, nn.functional.one_hot(y).type(torch.float64))
+        one_hot_y = nn.functional.one_hot(y).type(torch.float64)
+        loss = self.loss_func(y_pred, one_hot_y)
         acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
         self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
         self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)