<h1>Experiment 2</h1>
<h3>Targeted hyperparameter tuning</h3>
<p>
By examining the results of expt1, a smaller range of hyperparameters was chosen for expt2. This allowed for a more targeted search of the hyperparameter space to find an optimal configuration. The selected parameters for expt2 were as follows:
</p>
<ul>
<li>in_act = Linear, Mish, PBessel, or Tanhshrink</li>
<li>compressor_kernel_size = 128</li>
<li>compressor_act = Softshrink, SoftExp, or PReLU</li>
<li>conv_kernel_size = 128</li>
<li>conv_act = Sigmoid or PBessel</li>
<li>channel_combine_act = HardSigmoid or GELU</li>
<li>ff_width = 512</li>
<li>ff_depth = 2, 4, or 6</li>
<li>ff_act = CELU</li>
<li>out_act = Tanhshrink or Mish</li>
</ul>
<p>
Several of the parameters could be fixed to a single value, and the remaining parameters (with the exception of <code>in_act</code>) were reduced to only 2 or 3 possible values, dramatically shrinking the parameter space. For this reason, a significantly less aggressive pruning algorithm was used, allowing for a more thorough search of the parameter space.
</p>

In [6]:
# Data handling imports
from dask.distributed import Client, LocalCluster
import dask
import dask.dataframe as dd
import dask.array as da
import numpy as np
import pickle
import random
from itertools import chain
from tqdm.auto import tqdm

# Deep learning imports
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import functional as F
from torch import optim
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import optuna
from optuna.pruners import HyperbandPruner
from optuna.integration import PyTorchLightningPruningCallback

# Suppress some warning messages from pytorch_lightning;
# it really doesn't like that I've forced it to handle a dask array!
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module=pl.__name__)

# Also, set up a log to record debug messages for failed trials
import logging

logging.basicConfig(filename="debug_test.log", encoding="utf-8", level=logging.DEBUG)

In [7]:
from expt1 import (
    Model,
    Linear,
    device,
    activation_dispatcher,
    X_train,
    y_train,
    X_val,
    y_val,
    create_collate_fn,
)

In [None]:
# Start a local Dask cluster with 8 workers, one thread each, and connect a
# client to it.
cluster = LocalCluster(n_workers=8, threads_per_worker=1)
client = Client(cluster)

In [8]:
# Monkey patch to allow pytorch lightning to accept a dask array as a model input
from typing import Any, Generator, Iterable, Mapping, Optional, Union

# Recursive alias for anything a batch may contain, extended with dask arrays.
BType = Union[da.Array, torch.Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

# Keep a handle on Lightning's own implementation. If this cell is re-run, the
# module attribute already points at a previous `patch`; recover the original
# from the marker stored on it instead of wrapping the patch in itself (which
# would make `patch` call itself forever for non-dask batches).
_installed = pl.utilities.data._extract_batch_size
unpatched = getattr(_installed, "_unpatched", _installed)


def patch(batch: BType) -> Generator[Optional[int], None, None]:
    """Yield the batch size of ``batch``, with support for dask arrays.

    For a dask array this yields the length of its first dimension, or 1 when
    the array is zero-dimensional. Every other input is delegated to the
    original Lightning implementation saved in ``unpatched``.
    """
    if isinstance(batch, da.core.Array):
        if len(batch.shape) == 0:
            yield 1
        else:
            yield batch.shape[0]
    else:
        yield from unpatched(batch)


# Remember the original on the patch itself so that re-running this cell stays
# idempotent, then install the patch.
patch._unpatched = unpatched
pl.utilities.data._extract_batch_size = patch

In [4]:
# Test parameters
# Maximum number of epochs each trial is trained for.
n_epochs = 10
# Keys of the first entry in y_train; their count is used as the model's out_size.
output_keys = list(next(iter(y_train.values())).keys())
# Every key of the activation dispatcher, offered as the categorical choices.
activation_vals = list(activation_dispatcher.keys())


# Next we define the objective function for the hyperparameter optimization
def objective(trial):
    """Train one sampled model configuration and return its validation loss.

    Args:
        trial: The Optuna trial used to sample hyperparameters and to report
            intermediate values to the pruner.

    Returns:
        The final ``val_loss`` recorded by the trainer, as a float.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stops the trial, if
            training fails (out-of-memory or any other exception, which is
            logged), or if no finite ``val_loss`` was produced.
    """
    torch.cuda.empty_cache()
    objective_value = torch.inf
    model = None
    logger = None

    def suggest_activation(name):
        # Each activation is handed to Model as a 3-tuple: the dispatcher entry
        # plus empty positional and keyword arguments (presumably constructor
        # arguments - confirm against expt1.Model).
        choice = trial.suggest_categorical(name, activation_vals)
        return (activation_dispatcher[choice], (), {})

    try:
        # Select hyperparameters for testing.
        # NOTE(review): the accompanying write-up describes a narrowed search
        # space for this experiment, but the ranges below still cover every
        # dispatcher activation and the wide integer ranges - confirm which is
        # intended before changing them (an existing study rejects changed
        # distributions for the same parameter name).
        in_act = suggest_activation("in_act")
        # `step` is passed by keyword; newer Optuna releases deprecate passing
        # it positionally.
        compressor_kernel_size = trial.suggest_int(
            "compressor_kernel_size", 64, 257, step=64
        )
        compressor_chunk_size = 128
        compressor_act = suggest_activation("compressor_act")
        conv_kernel_size = trial.suggest_int("conv_kernel_size", 64, 257, step=64)
        conv_act = suggest_activation("conv_act")
        channel_combine_act = suggest_activation("channel_combine_act")
        ff_width = trial.suggest_int("ff_width", 256, 1025, step=256)
        ff_depth = trial.suggest_int("ff_depth", 2, 8, step=2)
        ff_act = suggest_activation("ff_act")
        out_size = len(output_keys)
        out_act = suggest_activation("out_act")

        # Set up the model architecture and other necessary components
        model = Model(
            in_act=in_act,
            compressor_kernel_size=compressor_kernel_size,
            compressor_chunk_size=compressor_chunk_size,
            compressor_act=compressor_act,
            conv_kernel_size=conv_kernel_size,
            conv_act=conv_act,
            channel_combine_act=channel_combine_act,
            ff_width=ff_width,
            ff_depth=ff_depth,
            ff_act=ff_act,
            out_size=out_size,
            out_act=out_act,
        ).to(device)

        trainer = Trainer(
            accelerator="gpu",
            max_epochs=n_epochs,
            devices=1,
            logger=logger,
            num_sanity_val_steps=0,  # Needs to be disabled or else we get an error because X is dask array
            # precision="16-mixed",
            callbacks=[
                PyTorchLightningPruningCallback(trial, monitor="val_loss"),
            ],
        )
        # Prepare datasets
        train = DataLoader(
            list(zip(X_train.values(), y_train.values())),
            collate_fn=create_collate_fn(),
            shuffle=True,
        )
        valid = DataLoader(
            list(zip(X_val.values(), y_val.values())),
            shuffle=True,
            collate_fn=create_collate_fn(),
        )
        # Finally, train the model
        trainer.fit(model, train, valid)

        # Read back the metric the pruning callback monitors. Without this the
        # objective stays at infinity and every trial is reported as pruned.
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None:
            objective_value = float(val_loss)
    except optuna.exceptions.TrialPruned:
        # A genuine pruning decision from the callback: pass it through as-is
        # rather than logging it as an error in the generic handler below.
        raise
    except torch.cuda.OutOfMemoryError as e:
        logging.warning(f"Ran out of memory in trial {trial.number}!")
        raise optuna.exceptions.TrialPruned() from e
    except Exception as e:
        # Deliberate best-effort: a broken configuration is logged and pruned
        # so that the study keeps running.
        logging.exception(f"An exception occurred in trial {trial.number}: {e}")
        raise optuna.exceptions.TrialPruned() from e
    finally:
        if logger is not None:
            logger.experiment.unwatch(model)
            logger.experiment.finish()
        # Release the model on failure as well as on success.
        model = None
        torch.cuda.empty_cache()
    # No usable metric (missing, infinite or NaN): treat the trial as pruned.
    if not np.isfinite(objective_value):
        raise optuna.exceptions.TrialPruned()
    return objective_value

In [5]:
import os

# Where Optuna stores the study. The URL (including any database credentials)
# is read from the OPTUNA_STORAGE_URL environment variable so that no password
# is kept in the notebook; without it, a local SQLite file is used.
# NOTE(review): a database password was previously hard-coded on this line and
# should be rotated if this file was ever shared or committed.
storage_name = os.environ.get("OPTUNA_STORAGE_URL", "sqlite:///optuna.sql")
study_name = "Experiment 2"
# load_if_exists=True resumes the study of the same name if the storage already
# contains one.
study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    direction="minimize",
    pruner=HyperbandPruner(),
    load_if_exists=True,
)
# No trial or time limit is set here: the search is stopped manually.
study.optimize(
    objective,
    n_trials=None,
    timeout=None,
)

[I 2023-07-31 23:49:15,744] Using an existing study with name 'Experiment 2' instead of creating a new one.
[I 2023-07-31 23:49:16,553] Trial 221 pruned. 
[I 2023-07-31 23:49:16,928] Trial 222 pruned. 
[I 2023-07-31 23:49:17,318] Trial 223 pruned. 
[I 2023-07-31 23:49:17,682] Trial 224 pruned. 
[W 2023-07-31 23:49:18,028] Trial 225 failed with parameters: {} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/cianh/Programming/Git_Projects/Aconity_ML_Test/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_562333/3392796582.py", line 16, in objective
    activation_dispatcher[trial.suggest_categorical("in_act", activation_vals)],
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cianh/Programming/Git_Projects/Aconity_ML_Test/.venv/lib/python3.11/site-packages/optuna/tria

KeyboardInterrupt: 