mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Refactored to lay groundwork for TUI
This commit is contained in:
@@ -1,54 +1,25 @@
|
||||
LEARNING_RATE = 10e-5
|
||||
from typing import Optional, Iterable
|
||||
import experiment1
|
||||
|
||||
|
||||
def qmnist_test(loss_func, version, tensorboard=True, wandb=True):
|
||||
from .experiment_1.model import main as test_model
|
||||
|
||||
logger = []
|
||||
|
||||
if tensorboard:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
tb_logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version=version,
|
||||
)
|
||||
logger.append(tb_logger)
|
||||
|
||||
if wandb:
|
||||
import wandb as _wandb
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
wandb_logger = WandbLogger(
|
||||
project="Symbolic_NN_Tests",
|
||||
name=version,
|
||||
dir="wandb",
|
||||
)
|
||||
logger.append(wandb_logger)
|
||||
|
||||
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||
|
||||
if wandb:
|
||||
_wandb.finish()
|
||||
EXPERIMENTS = (experiment1,)
|
||||
|
||||
|
||||
def qmnist_experiment():
|
||||
from .experiment_1 import semantic_loss
|
||||
from torch import nn
|
||||
def main(
|
||||
experiments: Optional[int | Iterable[int]] = None,
|
||||
tensorboard: bool = True,
|
||||
wandb: bool = True,
|
||||
):
|
||||
if experiments is None:
|
||||
experiments = range(1, len(EXPERIMENTS) + 1)
|
||||
elif not isinstance(experiments, Iterable):
|
||||
experiments = (experiments,)
|
||||
|
||||
qmnist_test(nn.functional.cross_entropy, "cross_entropy")
|
||||
qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
|
||||
qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
|
||||
qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
|
||||
qmnist_test(
|
||||
semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy"
|
||||
)
|
||||
qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy")
|
||||
experiment_indeces = (i - 1 for i in experiments)
|
||||
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
|
||||
|
||||
|
||||
def main():
|
||||
qmnist_experiment()
|
||||
for experiment in experiment_funcs:
|
||||
experiment(tensorboard=tensorboard, wandb=wandb)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user