From 85363021be943ed2d239df2ecf308ddf743ac90b Mon Sep 17 00:00:00 2001
From: Cian Hughes
Date: Tue, 11 Jun 2024 14:18:56 +0100
Subject: [PATCH] Added trainable residual penalty & logging for it

---
 symbolic_nn_tests/experiment2/__init__.py    | 50 ++++++++----
 symbolic_nn_tests/experiment2/model.py       |  2 +-
 .../experiment2/semantic_loss.py             | 77 +++++++++++--------
 3 files changed, 78 insertions(+), 51 deletions(-)

diff --git a/symbolic_nn_tests/experiment2/__init__.py b/symbolic_nn_tests/experiment2/__init__.py
index 4e99af8..45f7204 100644
--- a/symbolic_nn_tests/experiment2/__init__.py
+++ b/symbolic_nn_tests/experiment2/__init__.py
@@ -20,11 +20,15 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
         import wandb as _wandb
         from lightning.pytorch.loggers import WandbLogger
 
-        wandb_logger = WandbLogger(
-            project="Symbolic_NN_Tests",
-            name=version,
-            dir="wandb",
-        )
+        if isinstance(wandb, WandbLogger):
+            wandb_logger = wandb
+        else:
+            wandb_logger = WandbLogger(
+                project="Symbolic_NN_Tests",
+                name=version,
+                dir="wandb",
+                log_model="all",
+            )
         logger.append(wandb_logger)
 
     test_model(
@@ -43,19 +47,33 @@ def run(tensorboard: bool = True, wandb: bool = True):
     from .model import unpacking_smooth_l1_loss
     from . import semantic_loss
 
+    # test(
+    #     train_loss=unpacking_smooth_l1_loss,
+    #     val_loss=unpacking_smooth_l1_loss,
+    #     test_loss=unpacking_smooth_l1_loss,
+    #     version="smooth_l1_loss",
+    #     tensorboard=tensorboard,
+    #     wandb=wandb,
+    # )
+
+    version = "positive_slope_linear_loss"
+    if wandb:
+        from lightning.pytorch.loggers import WandbLogger
+
+        wandb_logger = WandbLogger(
+            project="Symbolic_NN_Tests",
+            name=version,
+            dir="wandb",
+            log_model="all",
+        )
+    else:
+        wandb_logger = wandb
+
     test(
-        train_loss=unpacking_smooth_l1_loss,
+        train_loss=semantic_loss.positive_slope_linear_loss(wandb_logger, version),
         val_loss=unpacking_smooth_l1_loss,
         test_loss=unpacking_smooth_l1_loss,
-        version="smooth_l1_loss",
+        version=version,
         tensorboard=tensorboard,
-        wandb=wandb,
-    )
-    test(
-        train_loss=semantic_loss.positive_slope_linear_loss,
-        val_loss=unpacking_smooth_l1_loss,
-        test_loss=unpacking_smooth_l1_loss,
-        version="positive_slope_linear_loss",
-        tensorboard=tensorboard,
-        wandb=wandb,
+        wandb=wandb_logger,
     )
diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py
index 6a403be..7877fb7 100644
--- a/symbolic_nn_tests/experiment2/model.py
+++ b/symbolic_nn_tests/experiment2/model.py
@@ -49,7 +49,7 @@ def get_singleton_dataset():
     from symbolic_nn_tests.experiment2.dataset import collate, pubchem
 
     return create_dataset(
-        dataset=pubchem, collate_fn=collate, batch_size=512, shuffle=True
+        dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
     )
 
 
diff --git a/symbolic_nn_tests/experiment2/semantic_loss.py b/symbolic_nn_tests/experiment2/semantic_loss.py
index 5faaeca..5603353 100644
--- a/symbolic_nn_tests/experiment2/semantic_loss.py
+++ b/symbolic_nn_tests/experiment2/semantic_loss.py
@@ -18,43 +18,52 @@ import torch
 # proportionality.
 
 
-def positive_slope_linear_loss(out, y):
-    x, y_pred = out
-    x0, x1 = x
+def positive_slope_linear_loss(wandb_logger=None, version="", device="cuda"):
+    a = nn.Parameter(data=torch.randn(1), requires_grad=True).to(device)
 
-    # Here, we want to make semantic use of the differential electronegativity of the molecule
-    # so start by calculating that
-    mean_electronegativities = torch.tensor(
-        [i[:, 3].mean() for i in x0], dtype=torch.float32
-    ).to(y_pred.device)
-    diff_electronegativity = (
-        torch.tensor(
-            [
-                (i[:, 3] - mean).abs().sum()
-                for i, mean in zip(x0, mean_electronegativities)
-            ],
-            dtype=torch.float32,
+    def f(out, y):
+        x, y_pred = out
+        x0, x1 = x
+
+        # Here, we want to make semantic use of the differential electronegativity of the molecule
+        # so start by calculating that
+        mean_electronegativities = torch.tensor(
+            [i[:, 3].mean() for i in x0], dtype=torch.float32
+        ).to(y_pred.device)
+        diff_electronegativity = (
+            torch.tensor(
+                [
+                    (i[:, 3] - mean).abs().sum()
+                    for i, mean in zip(x0, mean_electronegativities)
+                ],
+                dtype=torch.float32,
+            )
+            * 4.0
+        ).to(y_pred.device)
+
+        # Then, we need to get a linear best fit on that. Our semantic info is based on a graph of
+        # En (y) vs differential electronegativity on the x vs y axes, so y_pred is y here
+        m, c = linear_fit(diff_electronegativity, y_pred)
+
+        # To start with, we want to calculate a penalty based on deviation from a linear relationship
+        residual_penalty = (
+            (1 / sech(linear_residuals(diff_electronegativity, y_pred, m, c)))
+            .abs()
+            .float()
+            .mean()
         )
-        * 4.0
-    ).to(y_pred.device)
 
-    # Then, we need to get a linear best fit on that. Our semantic info is based on a graph of
-    # En (y) vs differential electronegativity on the x vs y axes, so y_pred is y here
-    m, c = linear_fit(diff_electronegativity, y_pred)
+        # We also need to calculate a penalty that incentivizes a positive slope. For this, im using relu
+        # to scale the slope as it will penalise negative slopes without just creating a reward hack for
+        # maximizing slope.
+        slope_penalty = (nn.functional.relu(a * (-m)) + 1).mean()
 
-    # To start with, we want to calculate a penalty based on deviation from a linear relationship
-    residual_penalty = (
-        (1 / sech(linear_residuals(diff_electronegativity, y_pred, m, c)))
-        .abs()
-        .float()
-        .mean()
-    )
+        if wandb_logger:
+            wandb_logger.log_metrics({f"{version}-a": a})
 
-    # We also need to calculate a penalty that incentivizes a positive slope. For this, im using softplus
-    # to scale the slope as it will penalise negative slopes without just creating a reward hack for
-    # maximizing slope.
-    slope_penalty = (nn.functional.softplus(-m) + 1).mean()
-
-    # Finally, let's get a smooth L1 loss and scale it based on these penalty functions
-    return nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
+        # Finally, let's get a smooth L1 loss and scale it based on these penalty functions
+        return (
+            nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
+        )
+
+    return f
 
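Below is a minimal, self-contained sketch of the closure-based loss pattern this patch adopts: a factory builds the loss once, carrying a trainable scale and an optional W&B logger, and the returned callable is what gets passed as train_loss. The names make_scaled_smooth_l1 and its toy penalty are illustrative stand-ins, not the repo's positive_slope_linear_loss; the sketch also creates the parameter directly on the target device (so it stays a leaf tensor an optimizer can update) and logs it via .item(), which are assumptions about how one might wire this up rather than a description of the patch itself.

    # Minimal sketch (not the repo's code) of a loss factory with a trainable scale.
    import torch
    from torch import nn


    def make_scaled_smooth_l1(logger=None, version="demo", device="cpu"):
        # Creating the parameter directly on the device keeps it a leaf tensor,
        # so it can be handed to an optimizer and actually trained.
        a = nn.Parameter(torch.randn(1, device=device))

        def f(y_pred, y):
            # Toy stand-in penalty: relu keeps the factor >= 1 and only grows
            # with a positive learned scale, mirroring the slope-penalty idea.
            penalty = (nn.functional.relu(a) + 1).mean()
            if logger is not None:
                # Lightning loggers expose log_metrics(); log the scalar value.
                logger.log_metrics({f"{version}-a": a.item()})
            return nn.functional.smooth_l1_loss(y_pred, y) * penalty

        return f, a  # return `a` so it can join the optimizer's parameter list


    loss_fn, a = make_scaled_smooth_l1()
    y_pred = torch.randn(8, requires_grad=True)
    y = torch.randn(8)
    loss = loss_fn(y_pred, y)
    loss.backward()  # gradients flow to both y_pred and the trainable scale `a`
    print(a.grad)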