Made the model smaller to avoid overfitting; also fixed val/test loss wiring in main()

2024-06-07 16:09:51 +01:00
parent c7133a8bb1
commit 08fe32302c


@@ -14,15 +14,7 @@ class Model(nn.Module):
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
         self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
         self.ff = nn.Sequential(
-            nn.Linear(17, 512),
-            nn.ReLU(),
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Linear(128, 64),
-            nn.ReLU(),
-            nn.Linear(64, 32),
+            nn.Linear(17, 32),
             nn.ReLU(),
             nn.Linear(32, 16),
             nn.ReLU(),
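
Net effect of this hunk: the five hidden layers between the 17-dim input and the 32-unit layer collapse into a single Linear(17, 32). A minimal sketch of the resulting head follows, assuming the layers before and after the hunk (the encoders and anything past the final ReLU shown) are unchanged:

```python
# Sketch of the feed-forward head after this commit; not the full Model class.
# The 17-dim input width and any layers beyond the hunk come from code not shown.
import torch.nn as nn

ff = nn.Sequential(
    nn.Linear(17, 32),  # was Linear(17, 512) followed by four more hidden layers
    nn.ReLU(),
    nn.Linear(32, 16),
    nn.ReLU(),
)

# Parameter count for the layers visible in the hunk:
# before: 17*512+512 + 512*256+256 + 256*128+128 + 128*64+64 + 64*32+32 + 32*16+16 = 184,304
# after:  17*32+32 + 32*16+16 = 1,104
print(sum(p.numel() for p in ff.parameters()))  # 1104
```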
@@ -86,8 +78,8 @@ def main(
     lmodel = TrainingWrapper(
         Model(),
         train_loss=train_loss,
-        val_loss=train_loss,
-        test_loss=train_loss,
+        val_loss=val_loss,
+        test_loss=test_loss,
     )
     lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
     trainer = L.Trainer(max_epochs=10, logger=logger)
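
The second hunk is a wiring fix folded into the same commit: val_loss and test_loss were previously both bound to train_loss, so validation and test metrics could never differ from the training criterion. TrainingWrapper's internals are not shown in this diff; the sketch below only illustrates the call-site distinction, with hypothetical loss choices standing in for whatever the repo actually passes:

```python
import torch
import torch.nn.functional as F

# Hypothetical criteria (assumption, not from the diff):
# optimise MSE during training, but monitor MAE for validation and test.
train_loss = F.mse_loss
val_loss = F.l1_loss
test_loss = F.l1_loss

pred, target = torch.randn(8, 1), torch.randn(8, 1)
print(train_loss(pred, target).item(), val_loss(pred, target).item())
```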