diff --git a/.gitignore b/.gitignore index d9c2d02..57eaa8d 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ explore/ lightning_logs/ logs/ wandb/ + +symbolic_nn_tests/local/* +!symbolic_nn_tests/local/README.md diff --git a/symbolic_nn_tests/__main__.py b/symbolic_nn_tests/__main__.py index dc476c1..b2ee0f0 100644 --- a/symbolic_nn_tests/__main__.py +++ b/symbolic_nn_tests/__main__.py @@ -2,21 +2,23 @@ import typer from typing import Optional, Iterable from typing_extensions import Annotated from loguru import logger -from . import experiment1, experiment2, experiment3 +from . import local, experiment1, experiment2, experiment3 -EXPERIMENTS = (experiment1, experiment2, experiment3) +EXPERIMENTS = (local, experiment1, experiment2, experiment3) -def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]: - return range(1, len(EXPERIMENTS) + 1) if i is None else list(map(int, i.split(","))) +def parse_int_or_intiterable(i: Optional[str] = None) -> Iterable[int]: + if i is None: + return range(1, len(EXPERIMENTS)) + else: + return list(map(int, i.replace("local", "0").split(","))) def main( experiments: Annotated[ Optional[str], typer.Option( - parser=parse_int_or_intiterable, help="A comma separated list of experiments to be run. Defaults to all.", ), ] = None, @@ -27,6 +29,8 @@ def main( bool, typer.Option(help="Whether or not to log via Weights & Biases") ] = True, ): + experiments = parse_int_or_intiterable(experiments) + import torch # Enable tensor cores for compatible GPUs @@ -35,8 +39,7 @@ def main( torch.set_float32_matmul_precision("medium") for i, n in enumerate(experiments, start=1): - j = n - 1 - experiment = EXPERIMENTS[j].run + experiment = EXPERIMENTS[n].run logger.info(f"Running Experiment {n} ({i}/{len(experiments)})...") experiment(tensorboard=tensorboard, wandb=wandb) diff --git a/symbolic_nn_tests/local/README.md b/symbolic_nn_tests/local/README.md new file mode 100644 index 0000000..c5dc9f3 --- /dev/null +++ b/symbolic_nn_tests/local/README.md @@ -0,0 +1,4 @@ +This directory is intended to contain local, private experiments +that make use of the framework in this module for fast experimentation. +The best of these exeriments will likely make their way into the public +repo over time.