From 08fe32302c8473cf382baa9eaff75644d5d61624 Mon Sep 17 00:00:00 2001 From: Cian Hughes Date: Fri, 7 Jun 2024 16:09:51 +0100 Subject: [PATCH] Made model smaller to avoid overfitting --- symbolic_nn_tests/experiment2/model.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py index 8bb70bd..48001d1 100644 --- a/symbolic_nn_tests/experiment2/model.py +++ b/symbolic_nn_tests/experiment2/model.py @@ -14,15 +14,7 @@ class Model(nn.Module): self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder) self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder) self.ff = nn.Sequential( - nn.Linear(17, 512), - nn.ReLU(), - nn.Linear(512, 256), - nn.ReLU(), - nn.Linear(256, 128), - nn.ReLU(), - nn.Linear(128, 64), - nn.ReLU(), - nn.Linear(64, 32), + nn.Linear(17, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), @@ -86,8 +78,8 @@ def main( lmodel = TrainingWrapper( Model(), train_loss=train_loss, - val_loss=train_loss, - test_loss=train_loss, + val_loss=val_loss, + test_loss=test_loss, ) lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs) trainer = L.Trainer(max_epochs=10, logger=logger)