Added trainable residual penalty & logging for it

2024-06-11 14:18:56 +01:00
parent a14babd58a
commit 85363021be
3 changed files with 78 additions and 51 deletions

View File

@@ -20,11 +20,15 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True)
import wandb as _wandb
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(
project="Symbolic_NN_Tests",
name=version,
dir="wandb",
)
if isinstance(wandb, WandbLogger):
wandb_logger = wandb
else:
wandb_logger = WandbLogger(
project="Symbolic_NN_Tests",
name=version,
dir="wandb",
log_model="all",
)
logger.append(wandb_logger)
test_model(
@@ -43,19 +47,33 @@ def run(tensorboard: bool = True, wandb: bool = True):
from .model import unpacking_smooth_l1_loss
from . import semantic_loss
# test(
# train_loss=unpacking_smooth_l1_loss,
# val_loss=unpacking_smooth_l1_loss,
# test_loss=unpacking_smooth_l1_loss,
# version="smooth_l1_loss",
# tensorboard=tensorboard,
# wandb=wandb,
# )
version = "positive_slope_linear_loss"
if wandb:
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(
project="Symbolic_NN_Tests",
name=version,
dir="wandb",
log_model="all",
)
else:
wandb_logger = wandb
test(
train_loss=unpacking_smooth_l1_loss,
train_loss=semantic_loss.positive_slope_linear_loss(wandb_logger, version),
val_loss=unpacking_smooth_l1_loss,
test_loss=unpacking_smooth_l1_loss,
version="smooth_l1_loss",
version=version,
tensorboard=tensorboard,
wandb=wandb,
)
test(
train_loss=semantic_loss.positive_slope_linear_loss,
val_loss=unpacking_smooth_l1_loss,
test_loss=unpacking_smooth_l1_loss,
version="positive_slope_linear_loss",
tensorboard=tensorboard,
wandb=wandb,
wandb=wandb_logger,
)
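Since the rendered diff interleaves old and new lines, here is the reworked body of run() assembled in one place as a sketch (the imports and the test() helper are taken from the hunk above; nothing else is assumed). The logger is built once when wandb is truthy, handed to the loss factory so metrics logged inside the loss land in the same run, and then passed through to test(), whose isinstance check reuses that instance instead of constructing a second one.

from .model import unpacking_smooth_l1_loss
from . import semantic_loss

version = "positive_slope_linear_loss"
if wandb:
    from lightning.pytorch.loggers import WandbLogger

    # One logger instance, shared between Lightning and the loss closure
    wandb_logger = WandbLogger(
        project="Symbolic_NN_Tests",
        name=version,
        dir="wandb",
        log_model="all",
    )
else:
    wandb_logger = wandb
test(
    train_loss=semantic_loss.positive_slope_linear_loss(wandb_logger, version),
    val_loss=unpacking_smooth_l1_loss,
    test_loss=unpacking_smooth_l1_loss,
    version=version,
    tensorboard=tensorboard,
    wandb=wandb_logger,
)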

View File

@@ -49,7 +49,7 @@ def get_singleton_dataset():
from symbolic_nn_tests.experiment2.dataset import collate, pubchem
return create_dataset(
dataset=pubchem, collate_fn=collate, batch_size=512, shuffle=True
dataset=pubchem, collate_fn=collate, batch_size=256, shuffle=True
)

View File

@@ -18,43 +18,52 @@ import torch
# proportionality.
def positive_slope_linear_loss(out, y):
x, y_pred = out
x0, x1 = x
def positive_slope_linear_loss(wandb_logger=None, version="", device="cuda"):
a = nn.Parameter(data=torch.randn(1, device=device), requires_grad=True)
# Here, we want to make semantic use of the differential electronegativity of the molecule
# so start by calculating that
mean_electronegativities = torch.tensor(
[i[:, 3].mean() for i in x0], dtype=torch.float32
).to(y_pred.device)
diff_electronegativity = (
torch.tensor(
[
(i[:, 3] - mean).abs().sum()
for i, mean in zip(x0, mean_electronegativities)
],
dtype=torch.float32,
def f(out, y):
x, y_pred = out
x0, x1 = x
# Here, we want to make semantic use of the differential electronegativity of the molecule
# so start by calculating that
mean_electronegativities = torch.tensor(
[i[:, 3].mean() for i in x0], dtype=torch.float32
).to(y_pred.device)
diff_electronegativity = (
torch.tensor(
[
(i[:, 3] - mean).abs().sum()
for i, mean in zip(x0, mean_electronegativities)
],
dtype=torch.float32,
)
* 4.0
).to(y_pred.device)
# Then, we need to get a linear best fit on that. Our semantic info comes from a plot of
# En (y-axis) against differential electronegativity (x-axis), so y_pred plays the role of y here
m, c = linear_fit(diff_electronegativity, y_pred)
# To start with, we want to calculate a penalty based on deviation from a linear relationship
residual_penalty = (
(1 / sech(linear_residuals(diff_electronegativity, y_pred, m, c)))
.abs()
.float()
.mean()
)
* 4.0
).to(y_pred.device)
# Then, we need to get a linear best fit on that. Our semantic info comes from a plot of
# En (y-axis) against differential electronegativity (x-axis), so y_pred plays the role of y here
m, c = linear_fit(diff_electronegativity, y_pred)
# We also need to calculate a penalty that incentivizes a positive slope. For this, I'm using relu
# to scale the slope, as it penalizes negative slopes without just creating a reward hack for
# maximizing the slope.
slope_penalty = (nn.functional.relu(a * (-m)) + 1).mean()
# To start with, we want to calculate a penalty based on deviation from a linear relationship
residual_penalty = (
(1 / sech(linear_residuals(diff_electronegativity, y_pred, m, c)))
.abs()
.float()
.mean()
)
if wandb_logger:
wandb_logger.log_metrics({f"{version}-a": a})
# We also need to calculate a penalty that incentivizes a positive slope. For this, I'm using softplus
# to scale the slope, as it penalizes negative slopes without just creating a reward hack for
# maximizing the slope.
slope_penalty = (nn.functional.softplus(-m) + 1).mean()
# Finally, let's get a smooth L1 loss and scale it based on these penalty functions
return nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
# Finally, let's get a smooth L1 loss and scale it based on these penalty functions
return (
nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
)
return f
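Two notes on the new factory, plus a sketch of calling it directly. The residual penalty is 1 / sech(r), which equals cosh(r): it is exactly 1 when every point sits on the fitted line and grows smoothly with the residuals, so the multiplier can only inflate the smooth L1 term, never discount it. The sketch below exercises the closure on a toy batch; the module path, the atom-feature count, and the batch sizes are assumptions for illustration, while the input structure ((x0, x1), y_pred) and the use of column 3 as per-atom electronegativity follow the code above.

import torch

# Module path assumed from the package layout implied by the dataset import above
from symbolic_nn_tests.experiment2 import semantic_loss

# Build the loss closure on CPU and without a logger, so nothing is sent to wandb
loss_fn = semantic_loss.positive_slope_linear_loss(
    wandb_logger=None, version="debug", device="cpu"
)

# Toy batch: four "molecules" with 5, 7, 6, and 9 atoms and 8 hypothetical features per atom;
# only column 3 (electronegativity) is read by this loss
x0 = [torch.rand(n, 8) for n in (5, 7, 6, 9)]
x1 = None             # unpacked by the closure but not used
y_pred = torch.rand(4)  # one predicted energy per molecule
y = torch.rand(4)       # target energies

loss = loss_fn(((x0, x1), y_pred), y)
print(loss)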