Added more varied semantic loss functions

This commit is contained in:
2024-05-15 13:37:12 +01:00
parent 01127de4b3
commit 6600a79f71
3 changed files with 74 additions and 23 deletions

View File

@@ -1,22 +1,28 @@
def main():
LEARNING_RATE = 10e-5
def run_test(loss_func, version):
from .model import main as test_model
from . import semantic_loss
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version=version,
)
test_model(lr=LEARNING_RATE)
# test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
def main():
from . import semantic_loss
from torch import nn
logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version="cross_entropy",
)
test_model(logger=logger, loss_func=nn.functional.cross_entropy)
logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version="similarity_weighted_cross_entropy",
)
test_model(logger=logger, loss_func=semantic_loss.similarity_weighted_cross_entropy)
run_test(nn.functional.cross_entropy, "cross_entropy")
# run_test(semantic_loss.similarity_cross_entropy, "similarity_cross_entropy")
# run_test(semantic_loss.hasline_cross_entropy, "hasline_cross_entropy")
# run_test(semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy")
# run_test(semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy")
if __name__ == "__main__":