mirror of https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added more varied semantic loss functions
@@ -1,22 +1,28 @@
def main():
LEARNING_RATE = 10e-5


def run_test(loss_func, version):
    from .model import main as test_model
    from . import semantic_loss
    from lightning.pytorch.loggers import TensorBoardLogger

    logger = TensorBoardLogger(
        save_dir=".",
        name="logs/comparison",
        version=version,
    )
    test_model(lr=LEARNING_RATE)
    # test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)


def main():
    from . import semantic_loss
    from torch import nn

    logger = TensorBoardLogger(
        save_dir=".",
        name="logs/comparison",
        version="cross_entropy",
    )
    test_model(logger=logger, loss_func=nn.functional.cross_entropy)

    logger = TensorBoardLogger(
        save_dir=".",
        name="logs/comparison",
        version="similarity_weighted_cross_entropy",
    )
    test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
    run_test(nn.functional.cross_entropy, "cross_entropy")
    # run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
    # run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
    # run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
    # run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")


if __name__ == "__main__":
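Aside, not part of the commit: with save_dir=".", name="logs/comparison", and a per-loss version string, Lightning's TensorBoardLogger writes each run into its own directory, so the loss variants can be compared side by side in one TensorBoard instance. A minimal sketch of what the configuration in run_test above resolves to:

    # Illustrative only; mirrors the logger configuration used in run_test above.
    from lightning.pytorch.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir=".", name="logs/comparison", version="cross_entropy")
    print(logger.log_dir)  # expected: ./logs/comparison/cross_entropy
    # Runs for the other loss functions land in sibling directories, e.g.
    # ./logs/comparison/similarity_weighted_cross_entropy, and can be viewed together
    # with `tensorboard --logdir logs/comparison`.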
@@ -1,6 +1,18 @@
import torch


def create_semantic_cross_entropy(semantic_matrix):
    def semantic_cross_entropy(input, target):
        ce_loss = torch.nn.functional.cross_entropy(input, target)

        penalty_tensor = semantic_matrix[target.argmax(dim=1)]
        abs_diff = (target - input).abs()
        semantic_penalty = (abs_diff * penalty_tensor).sum()
        return ce_loss * semantic_penalty

    return semantic_cross_entropy

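Not part of the commit, but a minimal sketch of how the factory is meant to be used, with a made-up 3-class penalty matrix and one-hot targets (the closure relies on target.argmax(dim=1), so it expects one-hot or probability targets rather than class indices):

    import torch

    # Hypothetical 3-class penalty matrix: entry [i][j] scales the cost of confusing
    # class i with class j. Values are illustrative only.
    penalties = torch.tensor(
        [
            [0.0, 1.0, 2.0],
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 0.0],
        ]
    )
    penalties /= penalties.sum()  # same normalisation the module applies to its matrices

    loss_fn = create_semantic_cross_entropy(penalties)

    logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 1.5]])
    targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # one-hot targets
    loss = loss_fn(logits, targets)  # cross-entropy scaled by the summed semantic penalty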
# NOTE: This similarity matrix defines loss scaling factors for misclassification
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
# penalised less harshly than visually distinct numbers as this mistake is "less
@@ -23,11 +35,41 @@ SIMILARITY_MATRIX = torch.tensor(
).to("cuda")
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum()  # Normalized to sum of 1

similarity_cross_entropy = create_semantic_cross_entropy(SIMILARITY_MATRIX)


def similarity_weighted_cross_entropy(input, target):
    ce_loss = torch.nn.functional.cross_entropy(input, target)

    penalty_tensor = SIMILARITY_MATRIX[target.argmax(dim=1)]
    similarity = (target - input).abs()
    similarity_penalty = (similarity * penalty_tensor).sum()
    return ce_loss * similarity_penalty
# NOTE: The following matrix encodes a simpler semantic penalty for correctly/incorrectly
# identifying shapes with straight lines in their representation. This can be a bit fuzzy
# in cases like "9" though.
HASLINE_MATRIX = torch.tensor(
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
    [False, True, False, False, True, True, False, True, False, True]
).to("cuda")
HASLINE_MATRIX = torch.stack([i ^ HASLINE_MATRIX for i in HASLINE_MATRIX]).type(
    torch.float64
)
HASLINE_MATRIX += 1
HASLINE_MATRIX /= HASLINE_MATRIX.sum()  # Normalize to sum of 1

hasline_cross_entropy = create_semantic_cross_entropy(HASLINE_MATRIX)

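Not part of the commit: a small sketch of what the boolean outer-XOR construction above produces, shrunk to a hypothetical 3-class property vector so the resulting matrix is easy to read:

    import torch

    # Hypothetical 3-class stand-in for HASLINE_MATRIX's initial boolean vector.
    has_line = torch.tensor([False, True, True])

    # Outer XOR: entry [i][j] is True exactly when classes i and j disagree on the property.
    penalty = torch.stack([i ^ has_line for i in has_line]).type(torch.float64)
    # tensor([[0., 1., 1.],
    #         [1., 0., 0.],
    #         [1., 0., 0.]])

    penalty += 1              # agreeing pairs cost 1, disagreeing pairs cost 2
    penalty /= penalty.sum()  # normalised to sum to 1, matching the module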
# NOTE: Similarly, we can do the same for closed circular loops in a numeric character
HASLOOP_MATRIX = torch.tensor(
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
    [True, False, False, False, False, False, True, False, True, True]
).to("cuda")
HASLOOP_MATRIX = torch.stack([i ^ HASLOOP_MATRIX for i in HASLOOP_MATRIX]).type(
    torch.float64
)
HASLOOP_MATRIX += 1
HASLOOP_MATRIX /= HASLOOP_MATRIX.sum()  # Normalize to sum of 1

hasloop_cross_entropy = create_semantic_cross_entropy(HASLOOP_MATRIX)


# NOTE: We can also combine all of these semantic matrices
MULTISEMANTIC_MATRIX = SIMILARITY_MATRIX * HASLINE_MATRIX * HASLOOP_MATRIX
MULTISEMANTIC_MATRIX /= MULTISEMANTIC_MATRIX.sum()

multisemantic_cross_entropy = create_semantic_cross_entropy(MULTISEMANTIC_MATRIX)

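Again not from the commit: combining matrices element-wise, as MULTISEMANTIC_MATRIX does, multiplies the per-pair costs and re-normalises, so class pairs that disagree on several properties are penalised the most. Continuing the hypothetical 3-class sketch from above:

    # `penalty` is the hypothetical 3-class matrix from the previous sketch; `loop_penalty`
    # is a second made-up matrix built the same way for a different property.
    has_loop = torch.tensor([True, False, True])
    loop_penalty = torch.stack([i ^ has_loop for i in has_loop]).type(torch.float64)
    loop_penalty += 1
    loop_penalty /= loop_penalty.sum()

    combined = loop_penalty * penalty
    combined /= combined.sum()  # pairs that disagree on both properties dominate the cost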
@@ -36,6 +36,9 @@ class TrainingWrapper(L.LightningModule):
    def validation_step(self, batch, batch_idx):
        self._forward_step(batch, batch_idx, label="val")

    def configure_optimizers(self, optimizer=optim.Adam, *args, **kwargs):
        _optimizer = optimizer(self.parameters(), *args, **kwargs)
    def test_step(self, batch, batch_idx):
        self._forward_step(batch, batch_idx, label="test")

    def configure_optimizers(self, optimizer=optim.SGD, **kwargs):
        _optimizer = optimizer(self.parameters(), **kwargs)
        return _optimizer

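One caveat worth noting, though it is not part of the diff: Lightning's Trainer calls configure_optimizers() with no arguments, so the keyword defaults are what actually run, and nothing in this hunk passes a learning rate. A hypothetical way to pin the optimizer settings while reusing the method above (the subclass name and hyperparameter values are illustrative):

    from torch import optim

    # Hypothetical subclass; TrainingWrapper is the class modified in this hunk.
    # The Trainer invokes configure_optimizers() without arguments, so the optimizer
    # and its hyperparameters are fixed here instead.
    class SGDWrapper(TrainingWrapper):
        def configure_optimizers(self):
            return super().configure_optimizers(optimizer=optim.SGD, lr=1e-4, momentum=0.9)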