mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Made model smaller to avoid overfitting
This commit is contained in:
@@ -14,15 +14,7 @@ class Model(nn.Module):
|
||||
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
|
||||
self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
|
||||
self.ff = nn.Sequential(
|
||||
nn.Linear(17, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 256),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 32),
|
||||
nn.Linear(17, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 16),
|
||||
nn.ReLU(),
|
||||
@@ -86,8 +78,8 @@ def main(
|
||||
lmodel = TrainingWrapper(
|
||||
Model(),
|
||||
train_loss=train_loss,
|
||||
val_loss=train_loss,
|
||||
test_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
)
|
||||
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
|
||||
trainer = L.Trainer(max_epochs=10, logger=logger)
|
||||
|
||||
Reference in New Issue
Block a user