Mirror of https://github.com/Cian-H/symbolic_nn_tests.git, synced 2025-12-22 14:11:59 +00:00
Added separate loss funcs for train, val, and test
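This commit replaces the single loss_func that was threaded through the experiments with three separate loss functions: train_loss, val_loss, and test_loss. The shared TrainingWrapper in symbolic_nn_tests.train accepts all three and picks the matching one in training_step, validation_step, and test_step, and the experiment runners and their .model main() entry points are updated to pass the losses through by keyword. An unchanged context line worth noting in the runners: LEARNING_RATE = 10e-5 evaluates to 1e-4, not 1e-5. A minimal sketch of the API change (the intended keyword form; the per-file diffs below show how the call sites actually wire these arguments):

    # before: one loss function for every stage
    lmodel = TrainingWrapper(model, loss_func=loss_func)

    # after: one loss function per stage
    lmodel = TrainingWrapper(
        model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
    )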
Runner for the cross-entropy experiments (defines the test() helper and the run() entry point):

@@ -1,7 +1,7 @@
 LEARNING_RATE = 10e-5


-def test(loss_func, version, tensorboard=True, wandb=True):
+def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
     from .model import main as test_model

     logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
         )
         logger.append(wandb_logger)

-    test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+    test_model(
+        logger=logger,
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+        lr=LEARNING_RATE,
+    )

     if wandb:
         _wandb.finish()
@@ -35,41 +41,53 @@ def test(loss_func, version, tensorboard=True, wandb=True):

 def run(tensorboard: bool = True, wandb: bool = True):
     from . import semantic_loss
-    from torch import nn
+    from .model import oh_vs_cat_cross_entropy

     test(
-        nn.functional.cross_entropy,
-        "cross_entropy",
+        train_loss=oh_vs_cat_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.similarity_cross_entropy,
-        "similarity_cross_entropy",
+        train_loss=semantic_loss.similarity_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="similarity_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.hasline_cross_entropy,
-        "hasline_cross_entropy",
+        train_loss=semantic_loss.hasline_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="hasline_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.hasloop_cross_entropy,
-        "hasloop_cross_entropy",
+        train_loss=semantic_loss.hasloop_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="hasloop_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.multisemantic_cross_entropy,
-        "multisemantic_cross_entropy",
+        train_loss=semantic_loss.multisemantic_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="multisemantic_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.garbage_cross_entropy,
-        "garbage_cross_entropy",
+        train_loss=semantic_loss.garbage_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="garbage_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
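A pattern worth noting in the run() above: only train_loss changes from call to call, while val_loss and test_loss are pinned to oh_vs_cat_cross_entropy, so the validation and test curves stay directly comparable across the semantic-loss variants. Purely as an illustration of that pattern (this refactor is not part of the commit), the six calls could be driven from a single mapping of version name to training loss:

    # Hypothetical condensation of run(); assumes the same package layout as the diff above.
    def run(tensorboard: bool = True, wandb: bool = True):
        from . import semantic_loss
        from .model import oh_vs_cat_cross_entropy

        train_losses = {
            "cross_entropy": oh_vs_cat_cross_entropy,
            "similarity_cross_entropy": semantic_loss.similarity_cross_entropy,
            "hasline_cross_entropy": semantic_loss.hasline_cross_entropy,
            "hasloop_cross_entropy": semantic_loss.hasloop_cross_entropy,
            "multisemantic_cross_entropy": semantic_loss.multisemantic_cross_entropy,
            "garbage_cross_entropy": semantic_loss.garbage_cross_entropy,
        }
        for version, train_loss in train_losses.items():
            # Validation and test always use the plain one-hot-vs-categorical loss.
            test(
                train_loss=train_loss,
                val_loss=oh_vs_cat_cross_entropy,
                test_loss=oh_vs_cat_cross_entropy,
                version=version,
                tensorboard=tensorboard,
                wandb=wandb,
            )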
The cross-entropy experiments' .model module (defines oh_vs_cat_cross_entropy and the main() entry point):

@@ -38,7 +38,13 @@ def oh_vs_cat_cross_entropy(y_bin, y_cat):
    )


-def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
+def main(
+    train_loss=oh_vs_cat_cross_entropy,
+    val_loss=oh_vs_cat_cross_entropy,
+    test_loss=oh_vs_cat_cross_entropy,
+    logger=None,
+    **kwargs,
+):
     import lightning as L

     from symbolic_nn_tests.train import TrainingWrapper
@@ -49,7 +55,9 @@ def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")

     train, val, test = get_singleton_dataset()
-    lmodel = TrainingWrapper(model, loss_func=loss_func)
+    lmodel = TrainingWrapper(
+        model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
+    )
     lmodel.configure_optimizers(**kwargs)
     trainer = L.Trainer(max_epochs=20, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
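The body of oh_vs_cat_cross_entropy is not part of this diff; only its signature, oh_vs_cat_cross_entropy(y_bin, y_cat), appears in the hunk header above. A minimal sketch of what such a helper typically looks like, assuming y_bin holds per-class logits (the "one-hot" side) and y_cat holds integer class labels (the "categorical" side):

    import torch
    from torch import nn

    # Sketch only: the real implementation in .model may differ.
    def oh_vs_cat_cross_entropy(y_bin, y_cat):
        return nn.functional.cross_entropy(y_bin, y_cat.to(torch.long))

With the new main() signature, an experiment can then train on a semantic loss while validating and testing on the plain loss, e.g. main(train_loss=semantic_loss.similarity_cross_entropy, val_loss=oh_vs_cat_cross_entropy, test_loss=oh_vs_cat_cross_entropy, lr=LEARNING_RATE), which is exactly what the runner's test_model() call does.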
Runner for the MSE experiments:

@@ -1,7 +1,7 @@
 LEARNING_RATE = 10e-5


-def test(loss_func, version, tensorboard=True, wandb=True):
+def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
     from .model import main as test_model

     logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
         )
         logger.append(wandb_logger)

-    test_model(logger=logger, loss_func=loss_func)
+    test_model(
+        logger=logger,
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+        lr=LEARNING_RATE,
+    )

     if wandb:
         _wandb.finish()
@@ -38,14 +44,18 @@ def run(tensorboard: bool = True, wandb: bool = True):
     from . import semantic_loss

     test(
-        unpacking_mse_loss,
-        "mse_loss",
+        train_loss=unpacking_mse_loss,
+        val_loss=unpacking_mse_loss,
+        test_loss=unpacking_mse_loss,
+        version="mse_loss",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.positive_slope_linear_loss,
-        "positive_slope_linear_loss",
+        train_loss=semantic_loss.positive_slope_linear_loss,
+        val_loss=unpacking_mse_loss,
+        test_loss=unpacking_mse_loss,
+        version="positive_slope_linear_loss",
         tensorboard=tensorboard,
         wandb=wandb,
     )
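unpacking_mse_loss is the default loss for the MSE experiments; the next diff shows its closing line, return nn.functional.mse_loss(y_pred, y), but not how y_pred is obtained from the model output out. A minimal sketch, assuming the model returns a tuple whose first element is the prediction:

    from torch import nn

    # Sketch only: everything except the final return line is an assumption.
    def unpacking_mse_loss(out, y):
        y_pred, *_ = out
        return nn.functional.mse_loss(y_pred, y)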
The MSE experiments' .model module (defines the Model class and unpacking_mse_loss):

@@ -14,7 +14,11 @@ class Model(nn.Module):
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
         self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
         self.ff = nn.Sequential(
-            nn.Linear(17, 128),
+            nn.Linear(17, 512),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Linear(256, 128),
             nn.ReLU(),
             nn.Linear(128, 64),
             nn.ReLU(),
@@ -62,7 +66,13 @@ def unpacking_mse_loss(out, y):
     return nn.functional.mse_loss(y_pred, y)


-def main(loss_func=unpacking_mse_loss, logger=None, **kwargs):
+def main(
+    train_loss=unpacking_mse_loss,
+    val_loss=unpacking_mse_loss,
+    test_loss=unpacking_mse_loss,
+    logger=None,
+    **kwargs,
+):
     import lightning as L

     from symbolic_nn_tests.train import TrainingWrapper
@@ -73,7 +83,12 @@ def main(loss_func=unpacking_mse_loss, logger=None, **kwargs):
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")

     train, val, test = get_singleton_dataset()
-    lmodel = TrainingWrapper(Model(), loss_func=loss_func)
+    lmodel = TrainingWrapper(
+        Model(),
+        train_loss=train_loss,
+        val_loss=train_loss,
+        test_loss=train_loss,
+    )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
     trainer = L.Trainer(max_epochs=10, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
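The other change in this module widens the feed-forward head: the single nn.Linear(17, 128) becomes a 17 → 512 → 256 → 128 stack, with the existing 128 → 64 layers left untouched. A quick way to see what this does to capacity, counting parameters for just the layers visible in the hunk (roughly 10.6k before versus 181.7k after):

    import torch.nn as nn

    def n_params(module: nn.Module) -> int:
        return sum(p.numel() for p in module.parameters())

    old_prefix = nn.Sequential(
        nn.Linear(17, 128), nn.ReLU(),
        nn.Linear(128, 64), nn.ReLU(),
    )
    new_prefix = nn.Sequential(
        nn.Linear(17, 512), nn.ReLU(),
        nn.Linear(512, 256), nn.ReLU(),
        nn.Linear(256, 128), nn.ReLU(),
        nn.Linear(128, 64), nn.ReLU(),
    )
    print(n_params(old_prefix), n_params(new_prefix))  # 10560 181696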
symbolic_nn_tests.train, which defines the shared TrainingWrapper:

@@ -3,16 +3,25 @@ import lightning as L


 class TrainingWrapper(L.LightningModule):
-    def __init__(self, model, loss_func=nn.functional.mse_loss, accuracy=None):
+    def __init__(
+        self,
+        model,
+        train_loss=nn.functional.mse_loss,
+        val_loss=nn.functional.mse_loss,
+        test_loss=nn.functional.mse_loss,
+        accuracy=None,
+    ):
         super().__init__()
         self.model = model
-        self.loss_func = loss_func
+        self.train_loss = train_loss
+        self.val_loss = val_loss
+        self.test_loss = val_loss
         self.accuracy = accuracy

-    def _forward_step(self, batch, batch_idx, label=""):
+    def _forward_step(self, batch, batch_idx, loss_func, label=""):
         x, y = batch
         y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
+        loss = loss_func(y_pred, y)
         self.log(f"{label}{'_' if label else ''}loss", loss)
         if self.accuracy is not None:
             acc = self.accuracy(y_pred, y)
@@ -20,13 +29,13 @@ class TrainingWrapper(L.LightningModule):
         return loss

     def training_step(self, batch, batch_idx):
-        return self._forward_step(batch, batch_idx, label="train")
+        return self._forward_step(batch, batch_idx, self.train_loss, label="train")

     def validation_step(self, batch, batch_idx):
-        self._forward_step(batch, batch_idx, label="val")
+        self._forward_step(batch, batch_idx, self.val_loss, label="val")

     def test_step(self, batch, batch_idx):
-        self._forward_step(batch, batch_idx, label="test")
+        self._forward_step(batch, batch_idx, self.test_loss, label="test")

     def configure_optimizers(self, optimizer=optim.SGD, **kwargs):
         _optimizer = optimizer(self.parameters(), **kwargs)
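Taken together, each Lightning hook now dispatches to its own loss function: training_step uses train_loss, validation_step uses val_loss, and test_step uses test_loss, all funnelled through the shared _forward_step. Note that, as committed, __init__ assigns self.test_loss = val_loss, and the two main() wrappers pass val_loss (cross-entropy model) or train_loss (MSE model) as the test loss, so in practice the test metric still mirrors another stage; whether that is intentional is not evident from the diff. A self-contained usage sketch with a toy model and dataset (the real experiments use get_singleton_dataset() and the Model classes above; the explicit configure_optimizers() call mirrors the repo's own pattern, and anything beyond its first visible line is assumed to behave as before):

    import lightning as L
    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    from symbolic_nn_tests.train import TrainingWrapper

    # Toy stand-ins, purely for illustration.
    x, y = torch.randn(64, 17), torch.randn(64, 1)
    dl = DataLoader(TensorDataset(x, y), batch_size=16)
    model = nn.Sequential(nn.Linear(17, 32), nn.ReLU(), nn.Linear(32, 1))

    wrapper = TrainingWrapper(
        model,
        train_loss=nn.functional.l1_loss,   # loss that is optimised
        val_loss=nn.functional.mse_loss,    # loss logged as val_loss
        test_loss=nn.functional.mse_loss,   # loss logged as test_loss
    )
    wrapper.configure_optimizers(optimizer=optim.NAdam, lr=1e-4)

    trainer = L.Trainer(max_epochs=1, logger=False)
    trainer.fit(model=wrapper, train_dataloaders=dl, val_dataloaders=dl)
    trainer.test(model=wrapper, dataloaders=dl)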