Added cli for configuring/launching experiments

This commit is contained in:
2024-05-23 11:06:11 +01:00
parent 9536beea04
commit 2720887f88
6 changed files with 145 additions and 12 deletions

48
poetry.lock generated
View File

@@ -2933,6 +2933,24 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.3.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.3.1-py3-none-any.whl", hash = "sha256:8aa57747f3fc3e977684f0176a88e789be314a99f99b43b75d1e9cb5dc6db9e9"},
{file = "rich-13.3.1.tar.gz", hash = "sha256:125d96d20c92b946b983d0d392b84ff945461e5a06d3867e9f9e575f8697b67f"},
]
[package.dependencies]
markdown-it-py = ">=2.1.0,<3.0.0"
pygments = ">=2.14.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "rpds-py"
version = "0.18.1"
@@ -3207,6 +3225,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "shellingham"
version = "1.5.4"
description = "Tool to Detect Surrounding Shell"
optional = false
python-versions = ">=3.7"
files = [
{file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
{file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
]
[[package]]
name = "six"
version = "1.16.0"
@@ -3629,6 +3658,23 @@ build = ["cmake (>=3.20)", "lit"]
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"]
tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
[[package]]
name = "typer"
version = "0.12.3"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
optional = false
python-versions = ">=3.7"
files = [
{file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"},
{file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"},
]
[package.dependencies]
click = ">=8.0.0"
rich = ">=10.11.0"
shellingham = ">=1.3.0"
typing-extensions = ">=3.7.4.3"
[[package]]
name = "typing-extensions"
version = "4.11.0"
@@ -3865,4 +3911,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "035f9d736293c9021a1dcdf67fd794bcc9034180dd2f458d9132e732daca06d9"
content-hash = "8bf3620a195c51ef0b5bf1cc4e98299c43e4b288cd00f38c167764fb917522f8"

View File

@@ -21,6 +21,7 @@ matplotlib-backend-kitty = "^2.1.2"
euporie = "^2.8.2"
ipykernel = "^6.29.4"
tensorboard = "^2.16.2"
typer = "^0.12.3"
[build-system]

View File

@@ -1,20 +1,31 @@
import typer
from typing import Optional, Iterable
import experiment1
from typing_extensions import Annotated
from . import experiment1
EXPERIMENTS = (experiment1,)
def main(
experiments: Optional[int | Iterable[int]] = None,
tensorboard: bool = True,
wandb: bool = True,
):
if experiments is None:
experiments = range(1, len(EXPERIMENTS) + 1)
elif not isinstance(experiments, Iterable):
experiments = (experiments,)
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
return range(1, len(EXPERIMENTS) + 1) if i is None else map(int, i.split(","))
def main(
experiments: Annotated[
Optional[str],
typer.Option(
parser=parse_int_or_intiterable,
help="A comma separated list of experiments to be run. Defaults to all.",
),
] = None,
tensorboard: Annotated[
bool, typer.Option(help="Whether or not to log via tensorboard")
] = True,
wandb: Annotated[
bool, typer.Option(help="Whether or not to log via Weights & Biases")
] = True,
):
experiment_indeces = (i - 1 for i in experiments)
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
@@ -23,4 +34,4 @@ def main(
if __name__ == "__main__":
main()
typer.run(main)

View File

@@ -0,0 +1,75 @@
LEARNING_RATE = 10e-5
def test(loss_func, version, tensorboard=True, wandb=True):
from .model import main as test_model
logger = []
if tensorboard:
from lightning.pytorch.loggers import TensorBoardLogger
tb_logger = TensorBoardLogger(
save_dir=".",
name="logs/comparison",
version=version,
)
logger.append(tb_logger)
if wandb:
import wandb as _wandb
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(
project="Symbolic_NN_Tests",
name=version,
dir="wandb",
)
logger.append(wandb_logger)
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
if wandb:
_wandb.finish()
def run(tensorboard: bool = True, wandb: bool = True):
from . import semantic_loss
from torch import nn
test(
nn.functional.cross_entropy,
"cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)
test(
semantic_loss.similarity_cross_entropy,
"similarity_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)
test(
semantic_loss.hasline_cross_entropy,
"hasline_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)
test(
semantic_loss.hasloop_cross_entropy,
"hasloop_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)
test(
semantic_loss.multisemantic_cross_entropy,
"multisemantic_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)
test(
semantic_loss.garbage_cross_entropy,
"garbage_cross_entropy",
tensorboard=tensorboard,
wandb=wandb,
)