Mirror of https://github.com/Cian-H/symbolic_nn_tests.git (synced 2025-12-23 14:42:01 +00:00)
Modified model to use SGD for less optimal training
There's no point having a loss-function testing sandbox where the baseline trains perfectly in 2 epochs. I've deliberately de-optimised the training in this commit to ensure the testing sandbox is actually useful.
@@ -19,7 +19,7 @@ def get_singleton_dataset():
     return get_dataset(dataset=QMNIST)
 
 
-def main(loss_func=nn.functional.cross_entropy, logger=None):
+def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
     import lightning as L
 
     from .train import TrainingWrapper
@@ -31,7 +31,8 @@ def main(loss_func=nn.functional.cross_entropy, logger=None):
 
     train, val, test = get_singleton_dataset()
     lmodel = TrainingWrapper(model, loss_func=loss_func)
-    trainer = L.Trainer(max_epochs=5, logger=logger)
+    lmodel.configure_optimizers(**kwargs)
+    trainer = L.Trainer(max_epochs=20, logger=logger)
     trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
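
The wrapper's side of this change is not shown in the diff, since train.py is untouched here. Below is a minimal sketch of how main()'s forwarded kwargs could reach an SGD optimizer, assuming a conventional LightningModule wrapper; the class body, the default learning rate, and the kwarg-stashing configure_optimizers are illustrative assumptions, not the repository's actual code:

import lightning as L
import torch
from torch import nn


class TrainingWrapper(L.LightningModule):
    """Hypothetical sketch: wraps a model and an arbitrary loss function."""

    def __init__(self, model: nn.Module, loss_func=nn.functional.cross_entropy):
        super().__init__()
        self.model = model
        self.loss_func = loss_func
        self.optimizer_kwargs = {}

    def configure_optimizers(self, **kwargs):
        # main() calls this manually with its **kwargs; stash them so the
        # no-argument call Lightning makes during fit() reuses the same
        # settings instead of falling back to the defaults.
        if kwargs:
            self.optimizer_kwargs = kwargs
        # Plain SGD rather than an adaptive optimizer: convergence is slower,
        # so differences between loss functions stay visible across epochs.
        return torch.optim.SGD(self.parameters(), **{"lr": 1e-2, **self.optimizer_kwargs})

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_func(self.model(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_func(self.model(x), y)
        self.log("val_loss", loss)
        return loss

With a wrapper shaped like this, a call such as main(lr=0.05, momentum=0.9) would thread optimizer settings straight through to SGD, and the raised 20-epoch budget gives the now-slower baseline room to separate good loss functions from bad ones.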