Added local experiment for private tests that shouldnt be pushed.

This commit is contained in:
2024-09-11 10:57:48 +01:00
parent 2c1e9aada0
commit 4036364ea0
3 changed files with 17 additions and 7 deletions

3
.gitignore vendored
View File

@@ -168,3 +168,6 @@ explore/
lightning_logs/
logs/
wandb/
symbolic_nn_tests/local/*
!symbolic_nn_tests/local/README.md

View File

@@ -2,21 +2,23 @@ import typer
from typing import Optional, Iterable
from typing_extensions import Annotated
from loguru import logger
from . import experiment1, experiment2, experiment3
from . import local, experiment1, experiment2, experiment3
EXPERIMENTS = (experiment1, experiment2, experiment3)
EXPERIMENTS = (local, experiment1, experiment2, experiment3)
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
return range(1, len(EXPERIMENTS) + 1) if i is None else list(map(int, i.split(",")))
def parse_int_or_intiterable(i: Optional[str] = None) -> Iterable[int]:
if i is None:
return range(1, len(EXPERIMENTS))
else:
return list(map(int, i.replace("local", "0").split(",")))
def main(
experiments: Annotated[
Optional[str],
typer.Option(
parser=parse_int_or_intiterable,
help="A comma separated list of experiments to be run. Defaults to all.",
),
] = None,
@@ -27,6 +29,8 @@ def main(
bool, typer.Option(help="Whether or not to log via Weights & Biases")
] = True,
):
experiments = parse_int_or_intiterable(experiments)
import torch
# Enable tensor cores for compatible GPUs
@@ -35,8 +39,7 @@ def main(
torch.set_float32_matmul_precision("medium")
for i, n in enumerate(experiments, start=1):
j = n - 1
experiment = EXPERIMENTS[j].run
experiment = EXPERIMENTS[n].run
logger.info(f"Running Experiment {n} ({i}/{len(experiments)})...")
experiment(tensorboard=tensorboard, wandb=wandb)

View File

@@ -0,0 +1,4 @@
This directory is intended to contain local, private experiments
that make use of the framework in this module for fast experimentation.
The best of these exeriments will likely make their way into the public
repo over time.