diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py
index 261c4db..ef2412a 100644
--- a/symbolic_nn_tests/__main__.py
+++ b/symbolic_nn_tests/__main__.py
@@ -1,22 +1,28 @@
-def main():
+LEARNING_RATE = 10e-5
+
+
+def run_test(loss_func, version):
     from .model import main as test_model
-    from . import semantic_loss
     from lightning.pytorch.loggers import TensorBoardLogger
+
+    logger = TensorBoardLogger(
+        save_dir=".",
+        name="logs/comparison",
+        version=version,
+    )
+    test_model(lr=LEARNING_RATE)
+    # test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+
+
+def main():
+    from . import semantic_loss
     from torch import nn
 
-    logger = TensorBoardLogger(
-        save_dir=".",
-        name="logs/comparison",
-        version="cross_entropy",
-    )
-    test_model(logger=logger, loss_func=nn.functional.cross_entropy)
-
-    logger = TensorBoardLogger(
-        save_dir=".",
-        name="logs/comparison",
-        version="similarity_weighted_cross_entropy",
-    )
-    test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
+    run_test(nn.functional.cross_entropy, "cross_entropy")
+    # run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
+    # run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
+    # run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
+    # run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
 
 
 if __name__ == "__main__":
diff --git a/symbolic_nn_tests/semantic_loss.py b/symbolic_nn_tests/semantic_loss.py
index f77a6e5..7634e4b 100644
--- a/symbolic_nn_tests/semantic_loss.py
+++ b/symbolic_nn_tests/semantic_loss.py
@@ -1,6 +1,18 @@
 import torch
 
 
+def create_semantic_cross_entropy(semantic_matrix):
+    def semantic_cross_entropy(input, target):
+        ce_loss = torch.nn.functional.cross_entropy(input, target)
+
+        penalty_tensor = semantic_matrix[target.argmax(dim=1)]
+        abs_diff = (target - input).abs()
+        semantic_penalty = (abs_diff * penalty_tensor).sum()
+        return ce_loss * semantic_penalty
+
+    return semantic_cross_entropy
+
+
 # NOTE: This similarity matrix defines loss scaling factors for misclassification
 # of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
 # penalised less harshly than visually distinct numbers as this mistake is "less
@@ -23,11 +35,41 @@ SIMILARITY_MATRIX = torch.tensor(
 ).to("cuda")
 SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum()  # Normalized to sum of 1
 
+similarity_cross_entropy = create_semantic_cross_entropy(SIMILARITY_MATRIX)
 
-def similarity_weighted_cross_entropy(input, target):
-    ce_loss = torch.nn.functional.cross_entropy(input, target)
-
-    penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
-    similarity = (target - input).abs()
-    similarity_penalty = (similarity * penalty_tensor).sum()
-    return ce_loss * similarity_penalty
+
+# NOTE: The following matrix encodes a simpler semantic penalty for correctly/incorrectly
+# identifying shapes with straight lines in their representation. This can be a bit fuzzy
+# in cases like "9" though.
+HASLINE_MATRIX = torch.tensor(
+    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+    [False, True, False, False, True, True, False, True, False, True]
+).to("cuda")
+HASLINE_MATRIX = torch.stack([i ^ HASLINE_MATRIX for i in HASLINE_MATRIX]).type(
+    torch.float64
+)
+HASLINE_MATRIX += 1
+HASLINE_MATRIX /= HASLINE_MATRIX.sum()  # Normalize to sum of 1
+
+hasline_cross_entropy = create_semantic_cross_entropy(HASLINE_MATRIX)
+
+
+# NOTE: Similarly, we can do the same for closed circular loops in a numeric character
+HASLOOP_MATRIX = torch.tensor(
+    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+    [True, False, False, False, False, False, True, False, True, True]
+).to("cuda")
+HASLOOP_MATRIX = torch.stack([i ^ HASLOOP_MATRIX for i in HASLOOP_MATRIX]).type(
+    torch.float64
+)
+HASLOOP_MATRIX += 1
+HASLOOP_MATRIX /= HASLOOP_MATRIX.sum()  # Normalize to sum of 1
+
+hasloop_cross_entropy = create_semantic_cross_entropy(HASLOOP_MATRIX)
+
+
+# NOTE: We can also combine all of these semantic matrices
+MULTISEMANTIC_MATRIX = SIMILARITY_MATRIX * HASLINE_MATRIX * HASLOOP_MATRIX
+MULTISEMANTIC_MATRIX /= MULTISEMANTIC_MATRIX.sum()
+
+multisemantic_cross_entropy = create_semantic_cross_entropy(MULTISEMANTIC_MATRIX)
diff --git a/symbolic_nn_tests/train.py b/symbolic_nn_tests/train.py
index 6df0ec2..fe76f02 100644
--- a/symbolic_nn_tests/train.py
+++ b/symbolic_nn_tests/train.py
@@ -36,6 +36,9 @@ class TrainingWrapper(L.LightningModule):
     def validation_step(self, batch, batch_idx):
         self._forward_step(batch, batch_idx, label="val")
 
-    def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
-        _optimizer = optimizer(self.parameters(), *args, **kwargs)
+    def test_step(self, batch, batch_idx):
+        self._forward_step(batch, batch_idx, label="test")
+
+    def configure_optimizers(self, optimizer=optim.SGD, **kwargs):
+        _optimizer = optimizer(self.parameters(), **kwargs)
         return _optimizer
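For reviewers who want to poke at the new loss factory, here is a minimal sketch of how `create_semantic_cross_entropy` behaves. It assumes one-hot float targets (the closure calls `target.argmax(dim=1)` and takes `(target - input).abs()` elementwise), and it substitutes a hypothetical uniform `toy_matrix` for the real matrices; note that importing `semantic_loss` itself moves its matrices to `"cuda"` at import time, so a GPU build of torch is assumed there.

```python
import torch

from symbolic_nn_tests.semantic_loss import create_semantic_cross_entropy

# Hypothetical stand-in for SIMILARITY_MATRIX et al., kept on CPU for the demo.
toy_matrix = torch.ones(10, 10, dtype=torch.float64)
toy_matrix /= toy_matrix.sum()  # normalised to sum to 1, like the real matrices

loss_fn = create_semantic_cross_entropy(toy_matrix)

logits = torch.randn(4, 10)  # raw model outputs for a batch of 4
targets = torch.nn.functional.one_hot(
    torch.tensor([3, 8, 0, 9]), num_classes=10
).float()  # one-hot targets, so argmax(dim=1) recovers the class index

# cross_entropy scaled by the matrix-weighted sum of |target - input|
loss = loss_fn(logits, targets)
print(loss)
```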
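The XOR construction behind `HASLINE_MATRIX` and `HASLOOP_MATRIX` is compact but easy to misread, so here is the same trick worked on a hypothetical three-class property; entry (r, c) is True exactly where classes r and c disagree on the boolean property.

```python
import torch

# Toy property: which of three classes contain a straight line.
has_prop = torch.tensor([False, True, True])

# Stack one XOR row per class: True where the property differs.
mismatch = torch.stack([p ^ has_prop for p in has_prop]).type(torch.float64)
# tensor([[0., 1., 1.],
#         [1., 0., 0.],
#         [1., 0., 0.]])

mismatch += 1               # agreeing pairs cost 1, disagreeing pairs cost 2
mismatch /= mismatch.sum()  # normalize to sum of 1, as in the patch
```

After the `+= 1` shift, a misclassification that also flips the semantic property is penalised twice as hard as one that preserves it, before normalisation.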
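One caveat on the `train.py` hunk: Lightning invokes `configure_optimizers()` with no arguments, so the new `optim.SGD` default is constructed without any kwargs. If a run needs an explicit learning rate, a subclass could bake it in; `SGDWrapper` and the `lr`/`momentum` values below are illustrative, not part of the patch.

```python
from torch import optim

from symbolic_nn_tests.train import TrainingWrapper


class SGDWrapper(TrainingWrapper):
    # Hypothetical override: since Lightning passes no arguments, lr and
    # momentum must be supplied here rather than by the caller.
    def configure_optimizers(self):
        return super().configure_optimizers(optimizer=optim.SGD, lr=1e-4, momentum=0.9)
```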