Mirror of https://github.com/Cian-H/symbolic_nn_tests.git (synced 2025-12-22 14:11:59 +00:00)
Made model smaller to avoid overfitting
@@ -14,15 +14,7 @@ class Model(nn.Module):
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
         self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
         self.ff = nn.Sequential(
-            nn.Linear(17, 512),
-            nn.ReLU(),
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Linear(128, 64),
-            nn.ReLU(),
-            nn.Linear(64, 32),
+            nn.Linear(17, 32),
             nn.ReLU(),
             nn.Linear(32, 16),
             nn.ReLU(),
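For context, here is a minimal sketch of what the feed-forward head looks like after this change. Only the layers visible in the hunk above are included; the surrounding encoder setup (x0_encoder, x1_encoder, create_xval_encoding_fn) belongs to the real Model class and is not reproduced. Counting just these layers, the old stack (17 -> 512 -> 256 -> 128 -> 64 -> 32 -> 16) had roughly 184k trainable parameters, while the new one (17 -> 32 -> 16) has about 1.1k.

import torch
from torch import nn

# Sketch of the shrunken feed-forward head introduced by this commit.
# The real Model builds this inside __init__ alongside its x0/x1 encoders.
ff = nn.Sequential(
    nn.Linear(17, 32),  # previously nn.Linear(17, 512)
    nn.ReLU(),
    nn.Linear(32, 16),
    nn.ReLU(),
)

# Quick shape check: a batch of 4 feature vectors of width 17 maps to width 16.
x = torch.randn(4, 17)
print(ff(x).shape)  # torch.Size([4, 16])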
@@ -86,8 +78,8 @@ def main(
     lmodel = TrainingWrapper(
         Model(),
         train_loss=train_loss,
-        val_loss=train_loss,
-        test_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
     )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
     trainer = L.Trainer(max_epochs=10, logger=logger)
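The second hunk fixes what looks like a copy-paste slip: the validation and test stages were both wired to the training loss. Below is a hypothetical, simplified stand-in for TrainingWrapper, sketched only to show why each stage needs its own loss callable; the repository's actual class will differ (for instance, its configure_optimizers takes the optimizer class as an argument, as the call above shows).

import lightning as L
import torch

# Hypothetical stand-in for the repository's TrainingWrapper, not its real code.
class SimpleTrainingWrapper(L.LightningModule):
    def __init__(self, model, train_loss, val_loss, test_loss):
        super().__init__()
        self.model = model
        self.train_loss = train_loss
        self.val_loss = val_loss
        self.test_loss = test_loss

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.train_loss(self.model(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # Before this commit, the training loss was (incorrectly) used here too.
        self.log("val_loss", self.val_loss(self.model(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", self.test_loss(self.model(x), y))

    def configure_optimizers(self):
        return torch.optim.NAdam(self.parameters())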