Added separate loss funcs for train, val, and test

This commit is contained in:
2024-06-06 22:26:23 +01:00
parent ea77a055f8
commit 1c21ee25d7
5 changed files with 93 additions and 33 deletions

View File

@@ -1,7 +1,7 @@
LEARNING_RATE = 10e-5 LEARNING_RATE = 10e-5
def test(loss_func, version, tensorboard=True, wandb=True): def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
from .model import main as test_model from .model import main as test_model
logger = [] logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
) )
logger.append(wandb_logger) logger.append(wandb_logger)
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) test_model(
logger=logger,
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,
lr=LEARNING_RATE,
)
if wandb: if wandb:
_wandb.finish() _wandb.finish()
@@ -35,41 +41,53 @@ def test(loss_func, version, tensorboard=True, wandb=True):
def run(tensorboard: bool = True, wandb: bool = True): def run(tensorboard: bool = True, wandb: bool = True):
from . import semantic_loss from . import semantic_loss
from torch import nn from .model import oh_vs_cat_cross_entropy
test( test(
nn.functional.cross_entropy, train_loss=oh_vs_cat_cross_entropy,
"cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.similarity_cross_entropy, train_loss=semantic_loss.similarity_cross_entropy,
"similarity_cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="similarity_cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.hasline_cross_entropy, train_loss=semantic_loss.hasline_cross_entropy,
"hasline_cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="hasline_cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.hasloop_cross_entropy, train_loss=semantic_loss.hasloop_cross_entropy,
"hasloop_cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="hasloop_cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.multisemantic_cross_entropy, train_loss=semantic_loss.multisemantic_cross_entropy,
"multisemantic_cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="multisemantic_cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.garbage_cross_entropy, train_loss=semantic_loss.garbage_cross_entropy,
"garbage_cross_entropy", val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
version="garbage_cross_entropy",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )

View File

@@ -38,7 +38,13 @@ def oh_vs_cat_cross_entropy(y_bin, y_cat):
) )
def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs): def main(
train_loss=oh_vs_cat_cross_entropy,
val_loss=oh_vs_cat_cross_entropy,
test_loss=oh_vs_cat_cross_entropy,
logger=None,
**kwargs,
):
import lightning as L import lightning as L
from symbolic_nn_tests.train import TrainingWrapper from symbolic_nn_tests.train import TrainingWrapper
@@ -49,7 +55,9 @@ def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn") logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
train, val, test = get_singleton_dataset() train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(model, loss_func=loss_func) lmodel = TrainingWrapper(
model, train_loss=train_loss, val_loss=val_loss, test_loss=val_loss
)
lmodel.configure_optimizers(**kwargs) lmodel.configure_optimizers(**kwargs)
trainer = L.Trainer(max_epochs=20, logger=logger) trainer = L.Trainer(max_epochs=20, logger=logger)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val) trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)

View File

@@ -1,7 +1,7 @@
LEARNING_RATE = 10e-5 LEARNING_RATE = 10e-5
def test(loss_func, version, tensorboard=True, wandb=True): def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
from .model import main as test_model from .model import main as test_model
logger = [] logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
) )
logger.append(wandb_logger) logger.append(wandb_logger)
test_model(logger=logger, loss_func=loss_func) test_model(
logger=logger,
train_loss=train_loss,
val_loss=val_loss,
test_loss=test_loss,
lr=LEARNING_RATE,
)
if wandb: if wandb:
_wandb.finish() _wandb.finish()
@@ -38,14 +44,18 @@ def run(tensorboard: bool = True, wandb: bool = True):
from . import semantic_loss from . import semantic_loss
test( test(
unpacking_mse_loss, train_loss=unpacking_mse_loss,
"mse_loss", val_loss=unpacking_mse_loss,
test_loss=unpacking_mse_loss,
version="mse_loss",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )
test( test(
semantic_loss.positive_slope_linear_loss, train_loss=semantic_loss.positive_slope_linear_loss,
"positive_slope_linear_loss", val_loss=unpacking_mse_loss,
test_loss=unpacking_mse_loss,
version="positive_slope_linear_loss",
tensorboard=tensorboard, tensorboard=tensorboard,
wandb=wandb, wandb=wandb,
) )

View File

@@ -14,7 +14,11 @@ class Model(nn.Module):
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder) self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder) self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
self.ff = nn.Sequential( self.ff = nn.Sequential(
nn.Linear(17, 128), nn.Linear(17, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(), nn.ReLU(),
nn.Linear(128, 64), nn.Linear(128, 64),
nn.ReLU(), nn.ReLU(),
@@ -62,7 +66,13 @@ def unpacking_mse_loss(out, y):
return nn.functional.mse_loss(y_pred, y) return nn.functional.mse_loss(y_pred, y)
def main(loss_func=unpacking_mse_loss, logger=None, **kwargs): def main(
train_loss=unpacking_mse_loss,
val_loss=unpacking_mse_loss,
test_loss=unpacking_mse_loss,
logger=None,
**kwargs,
):
import lightning as L import lightning as L
from symbolic_nn_tests.train import TrainingWrapper from symbolic_nn_tests.train import TrainingWrapper
@@ -73,7 +83,12 @@ def main(loss_func=unpacking_mse_loss, logger=None, **kwargs):
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn") logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
train, val, test = get_singleton_dataset() train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(Model(), loss_func=loss_func) lmodel = TrainingWrapper(
Model(),
train_loss=train_loss,
val_loss=train_loss,
test_loss=train_loss,
)
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs) lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
trainer = L.Trainer(max_epochs=10, logger=logger) trainer = L.Trainer(max_epochs=10, logger=logger)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val) trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)

View File

@@ -3,16 +3,25 @@ import lightning as L
class TrainingWrapper(L.LightningModule): class TrainingWrapper(L.LightningModule):
def __init__(self, model, loss_func=nn.functional.mse_loss, accuracy=None): def __init__(
self,
model,
train_loss=nn.functional.mse_loss,
val_loss=nn.functional.mse_loss,
test_loss=nn.functional.mse_loss,
accuracy=None,
):
super().__init__() super().__init__()
self.model = model self.model = model
self.loss_func = loss_func self.train_loss = train_loss
self.val_loss = val_loss
self.test_loss = val_loss
self.accuracy = accuracy self.accuracy = accuracy
def _forward_step(self, batch, batch_idx, label=""): def _forward_step(self, batch, batch_idx, loss_func, label=""):
x, y = batch x, y = batch
y_pred = self.model(x) y_pred = self.model(x)
loss = self.loss_func(y_pred, y) loss = loss_func(y_pred, y)
self.log(f"{label}{'_' if label else ''}loss", loss) self.log(f"{label}{'_' if label else ''}loss", loss)
if self.accuracy is not None: if self.accuracy is not None:
acc = self.accuracy(y_pred, y) acc = self.accuracy(y_pred, y)
@@ -20,13 +29,13 @@ class TrainingWrapper(L.LightningModule):
return loss return loss
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
return self._forward_step(batch, batch_idx, label="train") return self._forward_step(batch, batch_idx, self.train_loss, label="train")
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
self._forward_step(batch, batch_idx, label="val") self._forward_step(batch, batch_idx, self.val_loss, label="val")
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
self._forward_step(batch, batch_idx, label="test") self._forward_step(batch, batch_idx, self.test_loss, label="test")
def configure_optimizers(self, optimizer=optim.SGD, **kwargs): def configure_optimizers(self, optimizer=optim.SGD, **kwargs):
_optimizer = optimizer(self.parameters(), **kwargs) _optimizer = optimizer(self.parameters(), **kwargs)