mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added cli for configuring/launching experiments
This commit is contained in:
48
poetry.lock
generated
48
poetry.lock
generated
@@ -2933,6 +2933,24 @@ urllib3 = ">=1.21.1,<3"
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.3.1"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "rich-13.3.1-py3-none-any.whl", hash = "sha256:8aa57747f3fc3e977684f0176a88e789be314a99f99b43b75d1e9cb5dc6db9e9"},
|
||||
{file = "rich-13.3.1.tar.gz", hash = "sha256:125d96d20c92b946b983d0d392b84ff945461e5a06d3867e9f9e575f8697b67f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
markdown-it-py = ">=2.1.0,<3.0.0"
|
||||
pygments = ">=2.14.0,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
version = "0.18.1"
|
||||
@@ -3207,6 +3225,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments
|
||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||
|
||||
[[package]]
|
||||
name = "shellingham"
|
||||
version = "1.5.4"
|
||||
description = "Tool to Detect Surrounding Shell"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
|
||||
{file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.16.0"
|
||||
@@ -3629,6 +3658,23 @@ build = ["cmake (>=3.20)", "lit"]
|
||||
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"]
|
||||
tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
||||
|
||||
[[package]]
|
||||
name = "typer"
|
||||
version = "0.12.3"
|
||||
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"},
|
||||
{file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=8.0.0"
|
||||
rich = ">=10.11.0"
|
||||
shellingham = ">=1.3.0"
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.11.0"
|
||||
@@ -3865,4 +3911,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "035f9d736293c9021a1dcdf67fd794bcc9034180dd2f458d9132e732daca06d9"
|
||||
content-hash = "8bf3620a195c51ef0b5bf1cc4e98299c43e4b288cd00f38c167764fb917522f8"
|
||||
|
||||
@@ -21,6 +21,7 @@ matplotlib-backend-kitty = "^2.1.2"
|
||||
euporie = "^2.8.2"
|
||||
ipykernel = "^6.29.4"
|
||||
tensorboard = "^2.16.2"
|
||||
typer = "^0.12.3"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -1,20 +1,31 @@
|
||||
import typer
|
||||
from typing import Optional, Iterable
|
||||
import experiment1
|
||||
from typing_extensions import Annotated
|
||||
from . import experiment1
|
||||
|
||||
|
||||
EXPERIMENTS = (experiment1,)
|
||||
|
||||
|
||||
def main(
|
||||
experiments: Optional[int | Iterable[int]] = None,
|
||||
tensorboard: bool = True,
|
||||
wandb: bool = True,
|
||||
):
|
||||
if experiments is None:
|
||||
experiments = range(1, len(EXPERIMENTS) + 1)
|
||||
elif not isinstance(experiments, Iterable):
|
||||
experiments = (experiments,)
|
||||
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
|
||||
return range(1, len(EXPERIMENTS) + 1) if i is None else map(int, i.split(","))
|
||||
|
||||
|
||||
def main(
|
||||
experiments: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
parser=parse_int_or_intiterable,
|
||||
help="A comma separated list of experiments to be run. Defaults to all.",
|
||||
),
|
||||
] = None,
|
||||
tensorboard: Annotated[
|
||||
bool, typer.Option(help="Whether or not to log via tensorboard")
|
||||
] = True,
|
||||
wandb: Annotated[
|
||||
bool, typer.Option(help="Whether or not to log via Weights & Biases")
|
||||
] = True,
|
||||
):
|
||||
experiment_indeces = (i - 1 for i in experiments)
|
||||
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
|
||||
|
||||
@@ -23,4 +34,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
typer.run(main)
|
||||
|
||||
75
symbolic_nn_tests/experiment1/__init__.py
Normal file
75
symbolic_nn_tests/experiment1/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
LEARNING_RATE = 10e-5
|
||||
|
||||
|
||||
def test(loss_func, version, tensorboard=True, wandb=True):
|
||||
from .model import main as test_model
|
||||
|
||||
logger = []
|
||||
|
||||
if tensorboard:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
tb_logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version=version,
|
||||
)
|
||||
logger.append(tb_logger)
|
||||
|
||||
if wandb:
|
||||
import wandb as _wandb
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
wandb_logger = WandbLogger(
|
||||
project="Symbolic_NN_Tests",
|
||||
name=version,
|
||||
dir="wandb",
|
||||
)
|
||||
logger.append(wandb_logger)
|
||||
|
||||
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||
|
||||
if wandb:
|
||||
_wandb.finish()
|
||||
|
||||
|
||||
def run(tensorboard: bool = True, wandb: bool = True):
|
||||
from . import semantic_loss
|
||||
from torch import nn
|
||||
|
||||
test(
|
||||
nn.functional.cross_entropy,
|
||||
"cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
semantic_loss.similarity_cross_entropy,
|
||||
"similarity_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
semantic_loss.hasline_cross_entropy,
|
||||
"hasline_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
semantic_loss.hasloop_cross_entropy,
|
||||
"hasloop_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
semantic_loss.multisemantic_cross_entropy,
|
||||
"multisemantic_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
semantic_loss.garbage_cross_entropy,
|
||||
"garbage_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
Reference in New Issue
Block a user