diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py index 8bb70bd..48001d1 100644 --- a/symbolic_nn_tests/experiment2/model.py +++ b/symbolic_nn_tests/experiment2/model.py @@ -14,15 +14,7 @@ class Model(nn.Module): self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder) self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder) self.ff = nn.Sequential( - nn.Linear(17, 512), - nn.ReLU(), - nn.Linear(512, 256), - nn.ReLU(), - nn.Linear(256, 128), - nn.ReLU(), - nn.Linear(128, 64), - nn.ReLU(), - nn.Linear(64, 32), + nn.Linear(17, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), @@ -86,8 +78,8 @@ def main( lmodel = TrainingWrapper( Model(), train_loss=train_loss, - val_loss=train_loss, - test_loss=train_loss, + val_loss=val_loss, + test_loss=test_loss, ) lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs) trainer = L.Trainer(max_epochs=10, logger=logger)