diff --git a/symbolic_nn_tests/experiment1/__init__.py b/symbolic_nn_tests/experiment1/__init__.py
index e845b79..4c9d141 100644
--- a/symbolic_nn_tests/experiment1/__init__.py
+++ b/symbolic_nn_tests/experiment1/__init__.py
@@ -1,7 +1,7 @@
 LEARNING_RATE = 10e-5
 
 
-def test(loss_func, version, tensorboard=True, wandb=True):
+def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
     from .model import main as test_model
 
     logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
         )
         logger.append(wandb_logger)
 
-    test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+    test_model(
+        logger=logger,
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+        lr=LEARNING_RATE,
+    )
 
     if wandb:
         _wandb.finish()
@@ -35,41 +41,53 @@
 def run(tensorboard: bool = True, wandb: bool = True):
     from . import semantic_loss
-    from torch import nn
+    from .model import oh_vs_cat_cross_entropy
 
     test(
-        nn.functional.cross_entropy,
-        "cross_entropy",
+        train_loss=oh_vs_cat_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.similarity_cross_entropy,
-        "similarity_cross_entropy",
+        train_loss=semantic_loss.similarity_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="similarity_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.hasline_cross_entropy,
-        "hasline_cross_entropy",
+        train_loss=semantic_loss.hasline_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="hasline_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.hasloop_cross_entropy,
-        "hasloop_cross_entropy",
+        train_loss=semantic_loss.hasloop_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="hasloop_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
    )
     test(
-        semantic_loss.multisemantic_cross_entropy,
-        "multisemantic_cross_entropy",
+        train_loss=semantic_loss.multisemantic_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="multisemantic_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.garbage_cross_entropy,
-        "garbage_cross_entropy",
+        train_loss=semantic_loss.garbage_cross_entropy,
+        val_loss=oh_vs_cat_cross_entropy,
+        test_loss=oh_vs_cat_cross_entropy,
+        version="garbage_cross_entropy",
         tensorboard=tensorboard,
         wandb=wandb,
     )
diff --git a/symbolic_nn_tests/experiment1/model.py b/symbolic_nn_tests/experiment1/model.py
index a3bfa11..5679e4b 100644
--- a/symbolic_nn_tests/experiment1/model.py
+++ b/symbolic_nn_tests/experiment1/model.py
@@ -38,7 +38,13 @@ def oh_vs_cat_cross_entropy(y_bin, y_cat):
     )
 
 
-def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
+def main(
+    train_loss=oh_vs_cat_cross_entropy,
+    val_loss=oh_vs_cat_cross_entropy,
+    test_loss=oh_vs_cat_cross_entropy,
+    logger=None,
+    **kwargs,
+):
     import lightning as L
 
     from symbolic_nn_tests.train import TrainingWrapper
@@ -49,7 +55,9 @@ def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
 
     train, val, test = get_singleton_dataset()
-    lmodel = TrainingWrapper(model, loss_func=loss_func)
+    lmodel = TrainingWrapper(
+        model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
+    )
     lmodel.configure_optimizers(**kwargs)
     trainer = L.Trainer(max_epochs=20, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
diff --git a/symbolic_nn_tests/experiment2/__init__.py b/symbolic_nn_tests/experiment2/__init__.py
index 906af50..1e756e0 100644
--- a/symbolic_nn_tests/experiment2/__init__.py
+++ b/symbolic_nn_tests/experiment2/__init__.py
@@ -1,7 +1,7 @@
 LEARNING_RATE = 10e-5
 
 
-def test(loss_func, version, tensorboard=True, wandb=True):
+def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
     from .model import main as test_model
 
     logger = []
@@ -27,7 +27,13 @@ def test(loss_func, version, tensorboard=True, wandb=True):
         )
         logger.append(wandb_logger)
 
-    test_model(logger=logger, loss_func=loss_func)
+    test_model(
+        logger=logger,
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+        lr=LEARNING_RATE,
+    )
 
     if wandb:
         _wandb.finish()
@@ -38,14 +44,18 @@ def run(tensorboard: bool = True, wandb: bool = True):
     from . import semantic_loss
 
     test(
-        unpacking_mse_loss,
-        "mse_loss",
+        train_loss=unpacking_mse_loss,
+        val_loss=unpacking_mse_loss,
+        test_loss=unpacking_mse_loss,
+        version="mse_loss",
         tensorboard=tensorboard,
         wandb=wandb,
     )
     test(
-        semantic_loss.positive_slope_linear_loss,
-        "positive_slope_linear_loss",
+        train_loss=semantic_loss.positive_slope_linear_loss,
+        val_loss=unpacking_mse_loss,
+        test_loss=unpacking_mse_loss,
+        version="positive_slope_linear_loss",
         tensorboard=tensorboard,
         wandb=wandb,
     )
diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py
index e7b53de..d1d207a 100644
--- a/symbolic_nn_tests/experiment2/model.py
+++ b/symbolic_nn_tests/experiment2/model.py
@@ -14,7 +14,11 @@ class Model(nn.Module):
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
         self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
         self.ff = nn.Sequential(
-            nn.Linear(17, 128),
+            nn.Linear(17, 512),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Linear(256, 128),
             nn.ReLU(),
             nn.Linear(128, 64),
             nn.ReLU(),
@@ -62,7 +66,13 @@ def unpacking_mse_loss(out, y):
     return nn.functional.mse_loss(y_pred, y)
 
 
-def main(loss_func=unpacking_mse_loss, logger=None, **kwargs):
+def main(
+    train_loss=unpacking_mse_loss,
+    val_loss=unpacking_mse_loss,
+    test_loss=unpacking_mse_loss,
+    logger=None,
+    **kwargs,
+):
     import lightning as L
 
     from symbolic_nn_tests.train import TrainingWrapper
@@ -73,7 +83,12 @@ def main(loss_func=unpacking_mse_loss, logger=None, **kwargs):
         logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
 
     train, val, test = get_singleton_dataset()
-    lmodel = TrainingWrapper(Model(), loss_func=loss_func)
+    lmodel = TrainingWrapper(
+        Model(),
+        train_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
+    )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
     trainer = L.Trainer(max_epochs=10, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
diff --git a/symbolic_nn_tests/train.py b/symbolic_nn_tests/train.py
index a82969f..b1dc604 100644
--- a/symbolic_nn_tests/train.py
+++ b/symbolic_nn_tests/train.py
@@ -3,16 +3,25 @@ import lightning as L
 
 
 class TrainingWrapper(L.LightningModule):
-    def __init__(self, model, loss_func=nn.functional.mse_loss, accuracy=None):
+    def __init__(
+        self,
+        model,
+        train_loss=nn.functional.mse_loss,
+        val_loss=nn.functional.mse_loss,
+        test_loss=nn.functional.mse_loss,
+        accuracy=None,
+    ):
         super().__init__()
         self.model = model
-        self.loss_func = loss_func
+        self.train_loss = train_loss
+        self.val_loss = val_loss
+        self.test_loss = test_loss
         self.accuracy = accuracy
 
-    def _forward_step(self, batch, batch_idx, label=""):
+    def _forward_step(self, batch, batch_idx, loss_func, label=""):
         x, y = batch
         y_pred = self.model(x)
-        loss = self.loss_func(y_pred, y)
+        loss = loss_func(y_pred, y)
         self.log(f"{label}{'_' if label else ''}loss", loss)
         if self.accuracy is not None:
             acc = self.accuracy(y_pred, y)
@@ -20,13 +29,13 @@ class TrainingWrapper(L.LightningModule):
         return loss
 
     def training_step(self, batch, batch_idx):
-        return self._forward_step(batch, batch_idx, label="train")
+        return self._forward_step(batch, batch_idx, self.train_loss, label="train")
 
     def validation_step(self, batch, batch_idx):
-        self._forward_step(batch, batch_idx, label="val")
+        self._forward_step(batch, batch_idx, self.val_loss, label="val")
 
     def test_step(self, batch, batch_idx):
-        self._forward_step(batch, batch_idx, label="test")
+        self._forward_step(batch, batch_idx, self.test_loss, label="test")
 
     def configure_optimizers(self, optimizer=optim.SGD, **kwargs):
         _optimizer = optimizer(self.parameters(), **kwargs)
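
Usage sketch (not part of the patch): the snippet below shows how the per-phase loss
arguments introduced on TrainingWrapper could be exercised, training with a custom loss
while validating and testing with plain MSE so runs stay comparable. Only TrainingWrapper
and its configure_optimizers signature come from the diff above; the toy model, the
synthetic data, and the regularised_mse helper are placeholders invented for illustration.

# hypothetical_usage.py -- assumes the patched symbolic_nn_tests/train.py
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from symbolic_nn_tests.train import TrainingWrapper

# Synthetic regression data (placeholder, only to make the sketch runnable).
x = torch.randn(256, 10)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=32)

# Placeholder model.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))


def regularised_mse(y_pred, y_true):
    # Made-up training loss: MSE plus a small penalty on the outputs.
    return nn.functional.mse_loss(y_pred, y_true) + 0.1 * y_pred.abs().mean()


lmodel = TrainingWrapper(
    model,
    train_loss=regularised_mse,        # custom loss drives optimisation
    val_loss=nn.functional.mse_loss,   # plain metric for the val curves
    test_loss=nn.functional.mse_loss,  # plain metric for the final test
)
# Mirrors how the experiments call it: pass the optimizer class and its kwargs.
lmodel.configure_optimizers(optimizer=torch.optim.Adam, lr=1e-4)

trainer = L.Trainer(max_epochs=2, logger=False)
trainer.fit(model=lmodel, train_dataloaders=loader, val_dataloaders=loader)
trainer.test(model=lmodel, dataloaders=loader)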