mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 22:22:01 +00:00
Refactor of experiment2 for convenience
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import typer
|
||||
from typing import Optional, Iterable
|
||||
from typing_extensions import Annotated
|
||||
from . import experiment1
|
||||
from loguru import logger
|
||||
from . import experiment1, experiment2
|
||||
|
||||
|
||||
EXPERIMENTS = (experiment1,)
|
||||
EXPERIMENTS = (experiment1, experiment2)
|
||||
|
||||
|
||||
def parse_int_or_intiterable(i: Optional[str]) -> Iterable[int]:
|
||||
return range(1, len(EXPERIMENTS) + 1) if i is None else map(int, i.split(","))
|
||||
return range(1, len(EXPERIMENTS) + 1) if i is None else list(map(int, i.split(",")))
|
||||
|
||||
|
||||
def main(
|
||||
@@ -26,10 +27,10 @@ def main(
|
||||
bool, typer.Option(help="Whether or not to log via Weights & Biases")
|
||||
] = True,
|
||||
):
|
||||
experiment_indeces = (i - 1 for i in experiments)
|
||||
experiment_funcs = [EXPERIMENTS[i].run for i in experiment_indeces]
|
||||
|
||||
for experiment in experiment_funcs:
|
||||
for i, n in enumerate(experiments, start=1):
|
||||
j = n - 1
|
||||
experiment = EXPERIMENTS[j].run
|
||||
logger.info(f"Running Experiment {n} ({i}/{len(experiments)})...")
|
||||
experiment(tensorboard=tensorboard, wandb=wandb)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user