From 2720887f8806cfd105eb8869a1b4e7d94ecc5d6f Mon Sep 17 00:00:00 2001 From: Cian Hughes Date: Thu, 23 May 2024 11:06:11 +0100 Subject: [PATCH] Added cli for configuring/launching experiments --- poetry.lock | 48 +++++++++++- pyproject.toml | 1 + symbolic_nn_tests/__main__.py | 33 +++++--- symbolic_nn_tests/experiment1/__init__.py | 75 +++++++++++++++++++ .../{experiment_1 => experiment1}/model.py | 0 .../semantic_loss.py | 0 6 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 symbolic_nn_tests/experiment1/__init__.py rename symbolic_nn_tests/{experiment_1 => experiment1}/model.py (100%) rename symbolic_nn_tests/{experiment_1 => experiment1}/semantic_loss.py (100%) diff --git a/poetry.lock b/poetry.lock index 2a60f50..0126971 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2933,6 +2933,24 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.3.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.3.1-py3-none-any.whl", hash = "sha256:8aa57747f3fc3e977684f0176a88e789be314a99f99b43b75d1e9cb5dc6db9e9"}, + {file = "rich-13.3.1.tar.gz", hash = "sha256:125d96d20c92b946b983d0d392b84ff945461e5a06d3867e9f9e575f8697b67f"}, +] + +[package.dependencies] +markdown-it-py = ">=2.1.0,<3.0.0" +pygments = ">=2.14.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.18.1" @@ -3207,6 +3225,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "shellingham" +version = "1.5.4" +description = "Tool to Detect Surrounding Shell" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.16.0" @@ -3629,6 +3658,23 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +[[package]] +name = "typer" +version = "0.12.3" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +optional = false +python-versions = ">=3.7" +files = [ + {file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"}, + {file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"}, +] + +[package.dependencies] +click = ">=8.0.0" +rich = ">=10.11.0" +shellingham = ">=1.3.0" +typing-extensions = ">=3.7.4.3" + [[package]] name = "typing-extensions" version = "4.11.0" @@ -3865,4 +3911,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "035f9d736293c9021a1dcdf67fd794bcc9034180dd2f458d9132e732daca06d9" +content-hash = "8bf3620a195c51ef0b5bf1cc4e98299c43e4b288cd00f38c167764fb917522f8" diff --git a/pyproject.toml b/pyproject.toml index 7f4e178..42fafb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ matplotlib-backend-kitty = "^2.1.2" euporie = "^2.8.2" ipykernel = "^6.29.4" tensorboard = "^2.16.2" +typer = "^0.12.3" [build-system] diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index a798059..3a59f26 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -1,20 +1,31 @@ +import typer from typing import Optional, Iterable -import experiment1 +from typing_extensions import Annotated +from . import experiment1 EXPERIMENTS = (experiment1,) -def main( - experiments: Optional[int | Iterable[int]] = None, - tensorboard: bool = True, - wandb: bool = True, -): - if experiments is None: - experiments = range(1, len(EXPERIMENTS) + 1) - elif not isinstance(experiments, Iterable): - experiments = (experiments,) +def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]: + return range(1, len(EXPERIMENTS) + 1) if i is None else map(int, i.split(",")) + +def main( + experiments: Annotated[ + Optional[str], + typer.Option( + parser=parse_int_or_intiterable, + help="A comma separated list of experiments to be run. Defaults to all.", + ), + ] = None, + tensorboard: Annotated[ + bool, typer.Option(help="Whether or not to log via tensorboard") + ] = True, + wandb: Annotated[ + bool, typer.Option(help="Whether or not to log via Weights & Biases") + ] = True, +): experiment_indeces = (i - 1 for i in experiments) experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces] @@ -23,4 +34,4 @@ def main( if __name__ == "__main__": - main() + typer.run(main) diff --git a/symbolic_nn_tests/experiment1/__init__.py b/symbolic_nn_tests/experiment1/__init__.py new file mode 100644 index 0000000..e845b79 --- /dev/null +++ b/symbolic_nn_tests/experiment1/__init__.py @@ -0,0 +1,75 @@ +LEARNING_RATE = 10e-5 + + +def test(loss_func, version, tensorboard=True, wandb=True): + from .model import main as test_model + + logger = [] + + if tensorboard: + from lightning.pytorch.loggers import TensorBoardLogger + + tb_logger = TensorBoardLogger( + save_dir=".", + name="logs/comparison", + version=version, + ) + logger.append(tb_logger) + + if wandb: + import wandb as _wandb + from lightning.pytorch.loggers import WandbLogger + + wandb_logger = WandbLogger( + project="Symbolic_NN_Tests", + name=version, + dir="wandb", + ) + logger.append(wandb_logger) + + test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE) + + if wandb: + _wandb.finish() + + +def run(tensorboard: bool = True, wandb: bool = True): + from . import semantic_loss + from torch import nn + + test( + nn.functional.cross_entropy, + "cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) + test( + semantic_loss.similarity_cross_entropy, + "similarity_cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) + test( + semantic_loss.hasline_cross_entropy, + "hasline_cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) + test( + semantic_loss.hasloop_cross_entropy, + "hasloop_cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) + test( + semantic_loss.multisemantic_cross_entropy, + "multisemantic_cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) + test( + semantic_loss.garbage_cross_entropy, + "garbage_cross_entropy", + tensorboard=tensorboard, + wandb=wandb, + ) diff --git a/symbolic_nn_tests/experiment_1/model.py b/symbolic_nn_tests/experiment1/model.py similarity index 100% rename from symbolic_nn_tests/experiment_1/model.py rename to symbolic_nn_tests/experiment1/model.py diff --git a/symbolic_nn_tests/experiment_1/semantic_loss.py b/symbolic_nn_tests/experiment1/semantic_loss.py similarity index 100% rename from symbolic_nn_tests/experiment_1/semantic_loss.py rename to symbolic_nn_tests/experiment1/semantic_loss.py