Mirror of https://github.com/Cian-H/symbolic_nn_tests.git (synced 2025-12-22 22:22:01 +00:00)
Refactored slightly to allow other experiments
@@ -1,8 +1,8 @@
 LEARNING_RATE = 10e-5
 
 
-def run_test(loss_func, version):
-    from .model import main as test_model
+def qmnist_test(loss_func, version):
+    from .qmnist.model import main as test_model
     from lightning.pytorch.loggers import TensorBoardLogger
     from lightning.pytorch.loggers import WandbLogger
     import wandb
@@ -12,26 +12,34 @@ def run_test(loss_func, version):
         name="logs/comparison",
         version=version,
     )
-    wandb_logger = WandbLogger(
-        project="Symbolic_NN_Tests",
-        name=version,
-        dir="wandb",
-    )
-    logger = [tb_logger, wandb_logger]
+    # wandb_logger = WandbLogger(
+    #     project="Symbolic_NN_Tests",
+    #     name=version,
+    #     dir="wandb",
+    # )
+    logger = [
+        tb_logger,
+    ]  # wandb_logger]
     test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
     wandb.finish()
 
 
-def main():
-    from . import semantic_loss
+def qmnist_experiment():
+    from .qmnist import semantic_loss
     from torch import nn
 
-    run_test(nn.functional.cross_entropy, "cross_entropy")
-    run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
-    run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
-    run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
-    run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
-    run_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy")
+    qmnist_test(nn.functional.cross_entropy, "cross_entropy")
+    qmnist_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
+    qmnist_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
+    qmnist_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
+    qmnist_test(
+        semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy"
+    )
+    qmnist_test(semantic_loss.garbage_cross_entropy, "garbage_cross_entropy")
 
 
+def main():
+    qmnist_experiment()
+
+
 if __name__ == "__main__":
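
The point of the rename shows up in the new main(): it is now just a dispatcher, so further experiments can be registered beside the QMNIST one. A minimal sketch of how a second experiment might be appended to this module, assuming a hypothetical other_dataset package that mirrors the qmnist layout (the module name, its model.main, the version labels, and the TensorBoardLogger save_dir are assumptions, not part of this commit):

def other_test(loss_func, version):
    # Mirrors qmnist_test: a per-run TensorBoard logger, then train/evaluate.
    # `other_dataset` is a hypothetical sibling package of `qmnist`.
    from .other_dataset.model import main as test_model
    from lightning.pytorch.loggers import TensorBoardLogger

    tb_logger = TensorBoardLogger(
        "tb_logs",  # save_dir sits outside the hunk above, so this value is assumed
        name="logs/comparison",
        version=version,
    )
    test_model(logger=[tb_logger], loss_func=loss_func, lr=LEARNING_RATE)


def other_experiment():
    from torch import nn

    other_test(nn.functional.cross_entropy, "other/cross_entropy")


def main():
    qmnist_experiment()
    other_experiment()  # new experiments just append here
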
@@ -14,7 +14,7 @@ model = nn.Sequential(
 def get_singleton_dataset():
     from torchvision.datasets import QMNIST
 
-    from .dataloader import get_dataset
+    from symbolic_nn_tests.dataloader import get_dataset
 
     return get_dataset(dataset=QMNIST)
 
@@ -22,7 +22,7 @@ def get_singleton_dataset():
 def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
     import lightning as L
 
-    from .train import TrainingWrapper
+    from symbolic_nn_tests.train import TrainingWrapper
 
     if logger is None:
         from lightning.pytorch.loggers import TensorBoardLogger
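
The other half of the refactor is the switch from relative to absolute imports in the model module. Because the shared helpers are now addressed as symbolic_nn_tests.dataloader and symbolic_nn_tests.train, a model module can live anywhere in the package tree and still find them. A hedged sketch of a sibling experiment's model module under that assumption; FashionMNIST is a stand-in for whatever dataset the new experiment would use, and only get_dataset(dataset=...) is taken from the diff above:

# Hypothetical sibling module alongside the qmnist one.
from symbolic_nn_tests.dataloader import get_dataset


def get_singleton_dataset():
    from torchvision.datasets import FashionMNIST

    # Same shape as the QMNIST version; the shared loader handles the rest.
    return get_dataset(dataset=FashionMNIST)
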