mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Final attempt at Expt2
This commit is contained in:
@@ -9,8 +9,8 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
self.return_module_y = return_module_y
|
self.return_module_y = return_module_y
|
||||||
|
|
||||||
self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
|
self.x0_encoder = nn.TransformerEncoderLayer(7, 7, dim_feedforward=512)
|
||||||
self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
|
self.x1_encoder = nn.TransformerEncoderLayer(10, 10, dim_feedforward=1024)
|
||||||
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
|
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
|
||||||
self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
|
self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
|
||||||
self.ff = nn.Sequential(
|
self.ff = nn.Sequential(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from symbolic_nn_tests.experiment2.math import linear_fit, linear_residuals, sech
|
from symbolic_nn_tests.experiment2.math import linear_fit, linear_residuals
|
||||||
from random import random
|
from random import random
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
@@ -22,7 +22,7 @@ import torch
|
|||||||
class PositiveSlopeLinearLoss(nn.Module):
|
class PositiveSlopeLinearLoss(nn.Module):
|
||||||
def __init__(self, wandb_logger=None, version="", device="cuda", log_freq=50):
|
def __init__(self, wandb_logger=None, version="", device="cuda", log_freq=50):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.params = [random()]
|
self.params = [random(), random(), random()]
|
||||||
self.wandb_logger = wandb_logger
|
self.wandb_logger = wandb_logger
|
||||||
self.version = version
|
self.version = version
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -54,17 +54,15 @@ class PositiveSlopeLinearLoss(nn.Module):
|
|||||||
m, c = linear_fit(diff_electronegativity, y_pred)
|
m, c = linear_fit(diff_electronegativity, y_pred)
|
||||||
|
|
||||||
# To start with, we want to calculate a penalty based on deviation from a linear relationship
|
# To start with, we want to calculate a penalty based on deviation from a linear relationship
|
||||||
residual_penalty = (
|
r = linear_residuals(diff_electronegativity, y_pred, m, c)
|
||||||
(1 / sech(linear_residuals(diff_electronegativity, y_pred, m, c)))
|
residual_penalty = ((self.params[0] * r * r) + 1).float().mean()
|
||||||
.abs()
|
|
||||||
.float()
|
|
||||||
.mean()
|
|
||||||
)
|
|
||||||
|
|
||||||
# We also need to calculate a penalty that incentivizes a positive slope. For this, I'm using relu
|
# We also need to calculate a penalty that incentivizes a positive slope. For this, I'm using relu
|
||||||
# to scale the slope as it will penalise negative slopes without just creating a reward hack for
|
# to scale the slope as it will penalise negative slopes without just creating a reward hack for
|
||||||
# maximizing slope.
|
# maximizing slope.
|
||||||
slope_penalty = (nn.functional.relu(self.params[0] * (-m)) + 1).mean()
|
slope_penalty = (
|
||||||
|
nn.functional.softplus(self.params[1] * (-m), beta=self.params[2]) + 1
|
||||||
|
).mean()
|
||||||
|
|
||||||
if self.wandb_logger and (self.steps_since_log >= self.log_freq):
|
if self.wandb_logger and (self.steps_since_log >= self.log_freq):
|
||||||
self.wandb_logger.log_metrics({f"{self.version}-a": self.params})
|
self.wandb_logger.log_metrics({f"{self.version}-a": self.params})
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class TrainingWrapper(_TrainingWrapper):
|
|||||||
def __init__(self, *args, loss_rate_target=-10, **kwargs):
|
def __init__(self, *args, loss_rate_target=-10, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.loss_optimizer = Optimizer(
|
self.loss_optimizer = Optimizer(
|
||||||
[(0.0, 1000.0)],
|
[(0.0, 512.0), (0.0, 1024.0), (0.0, 512.0)],
|
||||||
base_estimator=RandomForestRegressor(
|
base_estimator=RandomForestRegressor(
|
||||||
n_jobs=-1,
|
n_jobs=-1,
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user