mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Added cli for configuring/launching experiments
This commit is contained in:
48
poetry.lock
generated
48
poetry.lock
generated
@@ -2933,6 +2933,24 @@ urllib3 = ">=1.21.1,<3"
|
|||||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rich"
|
||||||
|
version = "13.3.1"
|
||||||
|
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
files = [
|
||||||
|
{file = "rich-13.3.1-py3-none-any.whl", hash = "sha256:8aa57747f3fc3e977684f0176a88e789be314a99f99b43b75d1e9cb5dc6db9e9"},
|
||||||
|
{file = "rich-13.3.1.tar.gz", hash = "sha256:125d96d20c92b946b983d0d392b84ff945461e5a06d3867e9f9e575f8697b67f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
markdown-it-py = ">=2.1.0,<3.0.0"
|
||||||
|
pygments = ">=2.14.0,<3.0.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rpds-py"
|
name = "rpds-py"
|
||||||
version = "0.18.1"
|
version = "0.18.1"
|
||||||
@@ -3207,6 +3225,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments
|
|||||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||||
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shellingham"
|
||||||
|
version = "1.5.4"
|
||||||
|
description = "Tool to Detect Surrounding Shell"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
|
||||||
|
{file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "six"
|
name = "six"
|
||||||
version = "1.16.0"
|
version = "1.16.0"
|
||||||
@@ -3629,6 +3658,23 @@ build = ["cmake (>=3.20)", "lit"]
|
|||||||
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"]
|
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"]
|
||||||
tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typer"
|
||||||
|
version = "0.12.3"
|
||||||
|
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"},
|
||||||
|
{file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=8.0.0"
|
||||||
|
rich = ">=10.11.0"
|
||||||
|
shellingham = ">=1.3.0"
|
||||||
|
typing-extensions = ">=3.7.4.3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.11.0"
|
version = "4.11.0"
|
||||||
@@ -3865,4 +3911,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "035f9d736293c9021a1dcdf67fd794bcc9034180dd2f458d9132e732daca06d9"
|
content-hash = "8bf3620a195c51ef0b5bf1cc4e98299c43e4b288cd00f38c167764fb917522f8"
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ matplotlib-backend-kitty = "^2.1.2"
|
|||||||
euporie = "^2.8.2"
|
euporie = "^2.8.2"
|
||||||
ipykernel = "^6.29.4"
|
ipykernel = "^6.29.4"
|
||||||
tensorboard = "^2.16.2"
|
tensorboard = "^2.16.2"
|
||||||
|
typer = "^0.12.3"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -1,20 +1,31 @@
|
|||||||
|
import typer
|
||||||
from typing import Optional, Iterable
|
from typing import Optional, Iterable
|
||||||
import experiment1
|
from typing_extensions import Annotated
|
||||||
|
from . import experiment1
|
||||||
|
|
||||||
|
|
||||||
EXPERIMENTS = (experiment1,)
|
EXPERIMENTS = (experiment1,)
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
|
||||||
experiments: Optional[int | Iterable[int]] = None,
|
return range(1, len(EXPERIMENTS) + 1) if i is None else map(int, i.split(","))
|
||||||
tensorboard: bool = True,
|
|
||||||
wandb: bool = True,
|
|
||||||
):
|
|
||||||
if experiments is None:
|
|
||||||
experiments = range(1, len(EXPERIMENTS) + 1)
|
|
||||||
elif not isinstance(experiments, Iterable):
|
|
||||||
experiments = (experiments,)
|
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
experiments: Annotated[
|
||||||
|
Optional[str],
|
||||||
|
typer.Option(
|
||||||
|
parser=parse_int_or_intiterable,
|
||||||
|
help="A comma separated list of experiments to be run. Defaults to all.",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
tensorboard: Annotated[
|
||||||
|
bool, typer.Option(help="Whether or not to log via tensorboard")
|
||||||
|
] = True,
|
||||||
|
wandb: Annotated[
|
||||||
|
bool, typer.Option(help="Whether or not to log via Weights & Biases")
|
||||||
|
] = True,
|
||||||
|
):
|
||||||
experiment_indeces = (i - 1 for i in experiments)
|
experiment_indeces = (i - 1 for i in experiments)
|
||||||
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
|
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
|
||||||
|
|
||||||
@@ -23,4 +34,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
typer.run(main)
|
||||||
|
|||||||
75
symbolic_nn_tests/experiment1/__init__.py
Normal file
75
symbolic_nn_tests/experiment1/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
LEARNING_RATE = 10e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test(loss_func, version, tensorboard=True, wandb=True):
|
||||||
|
from .model import main as test_model
|
||||||
|
|
||||||
|
logger = []
|
||||||
|
|
||||||
|
if tensorboard:
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
|
tb_logger = TensorBoardLogger(
|
||||||
|
save_dir=".",
|
||||||
|
name="logs/comparison",
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
logger.append(tb_logger)
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
import wandb as _wandb
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
|
||||||
|
wandb_logger = WandbLogger(
|
||||||
|
project="Symbolic_NN_Tests",
|
||||||
|
name=version,
|
||||||
|
dir="wandb",
|
||||||
|
)
|
||||||
|
logger.append(wandb_logger)
|
||||||
|
|
||||||
|
test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
_wandb.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def run(tensorboard: bool = True, wandb: bool = True):
|
||||||
|
from . import semantic_loss
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
test(
|
||||||
|
nn.functional.cross_entropy,
|
||||||
|
"cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
semantic_loss.similarity_cross_entropy,
|
||||||
|
"similarity_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
semantic_loss.hasline_cross_entropy,
|
||||||
|
"hasline_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
semantic_loss.hasloop_cross_entropy,
|
||||||
|
"hasloop_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
semantic_loss.multisemantic_cross_entropy,
|
||||||
|
"multisemantic_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
semantic_loss.garbage_cross_entropy,
|
||||||
|
"garbage_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user