mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added more varied semantic loss functions
This commit is contained in:
@@ -1,22 +1,28 @@
|
||||
def main():
|
||||
LEARNING_RATE = 10e-5
|
||||
|
||||
|
||||
def run_test(loss_func, version):
|
||||
from .model import main as test_model
|
||||
from . import semantic_loss
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version=version,
|
||||
)
|
||||
test_model(lr=LEARNING_RATE)
|
||||
# test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||
|
||||
|
||||
def main():
|
||||
from . import semantic_loss
|
||||
from torch import nn
|
||||
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version="cross_entropy",
|
||||
)
|
||||
test_model(logger=logger, loss_func=nn.functional.cross_entropy)
|
||||
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version="similarity_weighted_cross_entropy",
|
||||
)
|
||||
test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
|
||||
run_test(nn.functional.cross_entropy, "cross_entropy")
|
||||
# run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
|
||||
# run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
|
||||
# run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
|
||||
# run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user