mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Added local experiment for private tests that shouldnt be pushed.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -168,3 +168,6 @@ explore/
|
|||||||
lightning_logs/
|
lightning_logs/
|
||||||
logs/
|
logs/
|
||||||
wandb/
|
wandb/
|
||||||
|
|
||||||
|
symbolic_nn_tests/local/*
|
||||||
|
!symbolic_nn_tests/local/README.md
|
||||||
|
|||||||
@@ -2,21 +2,23 @@ import typer
|
|||||||
from typing import Optional, Iterable
|
from typing import Optional, Iterable
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from . import experiment1, experiment2, experiment3
|
from . import local, experiment1, experiment2, experiment3
|
||||||
|
|
||||||
|
|
||||||
EXPERIMENTS = (experiment1, experiment2, experiment3)
|
EXPERIMENTS = (local, experiment1, experiment2, experiment3)
|
||||||
|
|
||||||
|
|
||||||
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
|
def parse_int_or_intiterable(i: Optional[str] = None) -> Iterable[int]:
|
||||||
return range(1, len(EXPERIMENTS) + 1) if i is None else list(map(int, i.split(",")))
|
if i is None:
|
||||||
|
return range(1, len(EXPERIMENTS))
|
||||||
|
else:
|
||||||
|
return list(map(int, i.replace("local", "0").split(",")))
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiments: Annotated[
|
experiments: Annotated[
|
||||||
Optional[str],
|
Optional[str],
|
||||||
typer.Option(
|
typer.Option(
|
||||||
parser=parse_int_or_intiterable,
|
|
||||||
help="A comma separated list of experiments to be run. Defaults to all.",
|
help="A comma separated list of experiments to be run. Defaults to all.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
@@ -27,6 +29,8 @@ def main(
|
|||||||
bool, typer.Option(help="Whether or not to log via Weights & Biases")
|
bool, typer.Option(help="Whether or not to log via Weights & Biases")
|
||||||
] = True,
|
] = True,
|
||||||
):
|
):
|
||||||
|
experiments = parse_int_or_intiterable(experiments)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Enable tensor cores for compatible GPUs
|
# Enable tensor cores for compatible GPUs
|
||||||
@@ -35,8 +39,7 @@ def main(
|
|||||||
torch.set_float32_matmul_precision("medium")
|
torch.set_float32_matmul_precision("medium")
|
||||||
|
|
||||||
for i, n in enumerate(experiments, start=1):
|
for i, n in enumerate(experiments, start=1):
|
||||||
j = n - 1
|
experiment = EXPERIMENTS[n].run
|
||||||
experiment = EXPERIMENTS[j].run
|
|
||||||
logger.info(f"Running Experiment {n} ({i}/{len(experiments)})...")
|
logger.info(f"Running Experiment {n} ({i}/{len(experiments)})...")
|
||||||
experiment(tensorboard=tensorboard, wandb=wandb)
|
experiment(tensorboard=tensorboard, wandb=wandb)
|
||||||
|
|
||||||
|
|||||||
4
symbolic_nn_tests/local/README.md
Normal file
4
symbolic_nn_tests/local/README.md
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
This directory is intended to contain local, private experiments
|
||||||
|
that make use of the framework in this module for fast experimentation.
|
||||||
|
The best of these exeriments will likely make their way into the public
|
||||||
|
repo over time.
|
||||||
Reference in New Issue
Block a user