mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added overview and plan for expt 2
This commit is contained in:
@@ -2,7 +2,7 @@ LEARNING_RATE = 10e-5
|
||||
|
||||
|
||||
def qmnist_test(loss_func, version):
|
||||
from .qmnist.model import main as test_model
|
||||
from .experiment_1.model import main as test_model
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
import wandb
|
||||
@@ -12,20 +12,18 @@ def qmnist_test(loss_func, version):
|
||||
name="logs/comparison",
|
||||
version=version,
|
||||
)
|
||||
# wandb_logger = WandbLogger(
|
||||
# project="Symbolic_NN_Tests",
|
||||
# name=version,
|
||||
# dir="wandb",
|
||||
# )
|
||||
logger = [
|
||||
tb_logger,
|
||||
] # wandb_logger]
|
||||
wandb_logger = WandbLogger(
|
||||
project="Symbolic_NN_Tests",
|
||||
name=version,
|
||||
dir="wandb",
|
||||
)
|
||||
logger = [tb_logger, wandb_logger]
|
||||
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||
wandb.finish()
|
||||
|
||||
|
||||
def qmnist_experiment():
|
||||
from .qmnist import semantic_loss
|
||||
from .experiment_1 import semantic_loss
|
||||
from torch import nn
|
||||
|
||||
qmnist_test(nn.functional.cross_entropy, "cross_entropy")
|
||||
|
||||
Reference in New Issue
Block a user