Aconity_ML_Expt1/expt3.ipynb
2023-08-03 21:33:03 +01:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h1>Experiment 3</h1>\n",
"<h3>Optimised model training</h3>\n",
"<p>In experiment 3 the model was trained using the optimised hyperparameters. By examining the results of expt2, it was noticed that trials #1, #10, and #16 all resulted in quite low losses while also showing clear downward trends resembling a clearly discernible training curve. Of these, trial #16 was ultimately selected as the model to be tested, as the data suggests that <code>`in_act=Mish`</code> tends to give the lowest losses in most models tested. The parameters for trial #16 were as follows:</p>\n",
"<ul>\n",
"<li><b>in_act</b> = Mish</li>\n",
"<li><b>compressor_kernel_size</b> = 128</li>\n",
"<li><b>compressor_chunk_size</b> = 128</li>\n",
"<li><b>compressor_act</b> = SoftExp</li>\n",
"<li><b>conv_kernel_size</b> = 128</li>\n",
"<li><b>conv_act</b> = Sigmoid</li>\n",
"<li><b>channel_combine_act</b> = GELU</li>\n",
"<li><b>ff_width</b> = 512</li>\n",
"<li><b>ff_depth</b> = 2</li>\n",
"<li><b>ff_act</b> = CELU</li>\n",
"<li><b>out_act</b> = Tanhshrink</li>\n",
"</ul>\n",
"<p>\n",
"Because most of the training curves in expt2 appeared to be unstable, a learning rate scheduler was used to reduce the learning rate by 20% if the validation loss did not improve for 5 epochs. The model was checkpointed, with the best 10 iterations of the model being retained for testing after training.\n",
"</p>\n",
"<h3>Modified optimal model training</h3>\n",
"<p>\n",
"Following the first attempt at training the optimised model (Model 1, Test 1), it was noted that training curves were clearly discernible, but still quite unstable and noisy. To try and further improve the stability of the training, a modified version of the model was prepared and trained (Model 2, Test 2). The modified model was the same as Model 1, but with the addition of a LayerNormalization layer to the convolutional layer of the <code>`DaskCompressor`</code> submodule. This change was made because highly recurrent submodules such as the compressor are known to be especially prone to instability caused by vanishing or exploding gradients. It was reasoned that by normalizing at each iteration the gradients would be less likely to vanish or explode, making the training more stable.\n",
"</p>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Data handling imports\n",
"from dask.distributed import Client, LocalCluster\n",
"import dask\n",
"import dask.dataframe as dd\n",
"import dask.array as da\n",
"import numpy as np\n",
"import pickle\n",
"import random\n",
"from itertools import chain\n",
"from tqdm.auto import tqdm\n",
"\n",
"# Deep learning imports\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch import optim\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning import Trainer\n",
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"# Suppress some warning messages from pytorch_lightning,\n",
"# It really doesn't like that i've forced it to handle a dask array!\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning, module=pl.__name__)\n",
"\n",
"# Also, set up a log to record debug messages for failed trials\n",
"import logging\n",
"\n",
"logging.basicConfig(filename=\"debug.log\", encoding=\"utf-8\", level=logging.DEBUG)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from expt1 import (\n",
" Model,\n",
" device,\n",
" X_train,\n",
" y_train,\n",
" X_val,\n",
" y_val,\n",
" create_collate_fn,\n",
")\n",
"from custom_activations import SoftExp"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cianh/Programming/Git_Projects/Aconity_ML_Test/.venv/lib/python3.11/site-packages/distributed/node.py:182: UserWarning: Port 8787 is already in use.\n",
"Perhaps you already have a cluster running?\n",
"Hosting the HTTP server on port 34477 instead\n",
" warnings.warn(\n"
]
}
],
"source": [
"cluster = LocalCluster(n_workers=8, threads_per_worker=1)\n",
"client = Client(cluster)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Monkey patch to allow pytorch lightning to accept a dask array as a model input\n",
"from typing import Any, Generator, Iterable, Mapping, Optional, Union\n",
"\n",
"BType = Union[da.Array, torch.Tensor, str, Mapping[Any, \"BType\"], Iterable[\"BType\"]]\n",
"\n",
"unpatched = pl.utilities.data._extract_batch_size\n",
"\n",
"\n",
"def patch(batch: BType) -> Generator[Optional[int], None, None]:\n",
" if isinstance(batch, da.core.Array):\n",
" if len(batch.shape) == 0:\n",
" yield 1\n",
" else:\n",
" yield batch.shape[0]\n",
" else:\n",
" yield from unpatched(batch)\n",
"\n",
"\n",
"pl.utilities.data._extract_batch_size = patch"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Prepare datasets\n",
"train = DataLoader(\n",
" list(zip(X_train.values(), y_train.values())),\n",
" collate_fn=create_collate_fn(),\n",
" shuffle=True,\n",
")\n",
"valid = DataLoader(\n",
" list(zip(X_val.values(), y_val.values())),\n",
" shuffle=True,\n",
" collate_fn=create_collate_fn(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Set up the model architecture and other necessary components\n",
"model = Model(\n",
" # Training parameters\n",
" optimizer=optim.Adam,\n",
" scheduler=optim.lr_scheduler.ReduceLROnPlateau,\n",
" scheduler_kwargs={\"factor\": 0.8, \"patience\": 5},\n",
" # Model parameters\n",
" in_act=(nn.Mish, list(), dict()),\n",
" compressor_kernel_size=128,\n",
" compressor_chunk_size=128,\n",
" compressor_act=(SoftExp, list(), dict()),\n",
" conv_kernel_size=128,\n",
" conv_act=(nn.Sigmoid, list(), dict()),\n",
" channel_combine_act=(nn.GELU, list(), dict()),\n",
" ff_width=512,\n",
" ff_depth=2,\n",
" ff_act=(nn.CELU, list(), dict()),\n",
" out_size=len(list(next(iter(y_train.values())).keys())),\n",
" out_act=(nn.Tanhshrink, list(), dict()),\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchughes000\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d1624339b4c4aaeb195b5ebc3b3e69e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016669258750092317, max=1.0…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"wandb version 0.15.8 is available! To upgrade, please run:\n",
" $ pip install wandb --upgrade"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.15.7"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>./wandb/run-20230801_233841-q70oibx2</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/chughes000/Aconity_ML_Test_DryRun/runs/q70oibx2' target=\"_blank\">Test 2</a></strong> to <a href='https://wandb.ai/chughes000/Aconity_ML_Test_DryRun' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/chughes000/Aconity_ML_Test_DryRun' target=\"_blank\">https://wandb.ai/chughes000/Aconity_ML_Test_DryRun</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/chughes000/Aconity_ML_Test_DryRun/runs/q70oibx2' target=\"_blank\">https://wandb.ai/chughes000/Aconity_ML_Test_DryRun/runs/q70oibx2</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"early_stop_callback = EarlyStopping(\n",
" monitor=\"val_loss\", patience=15, verbose=False, mode=\"min\"\n",
")\n",
"\n",
"checkpoint_callback = ModelCheckpoint(\n",
" monitor=\"val_loss\",\n",
" dirpath=\"./checkpoints\",\n",
" filename=\"checkpoint-{epoch:02d}-{val_loss:.2f}\",\n",
" save_top_k=10,\n",
" mode=\"min\",\n",
")\n",
"\n",
"logger = WandbLogger(project=\"Aconity_ML_Test_DryRun\", name=f\"Test 1\")\n",
"logger.experiment.watch(model, log=\"all\", log_freq=1)\n",
"\n",
"trainer = Trainer(\n",
" accelerator=\"gpu\",\n",
" max_epochs=-1,\n",
" devices=\"auto\",\n",
" strategy=\"auto\",\n",
" logger=logger,\n",
" callbacks=[checkpoint_callback, early_stop_callback],\n",
" num_sanity_val_steps=0, # Needs to be disabled or else we get an error because X is dask array\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"/home/cianh/Programming/Git_Projects/Aconity_ML_Test/.venv/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:361: RuntimeWarning: Found unsupported keys in the optimizer configuration: {'scheduler'}\n",
" rank_zero_warn(\n",
"\n",
" | Name | Type | Params\n",
"--------------------------------------------------------------\n",
"0 | loss | MSELoss | 0 \n",
"1 | in_act | Mish | 0 \n",
"2 | convolutional_compressor | DaskCompression | 3.2 K \n",
"3 | compressor_act | SoftExp | 1 \n",
"4 | conv | Conv1d | 3.2 K \n",
"5 | conv_act | Sigmoid | 0 \n",
"6 | combine_channels | Conv1d | 6 \n",
"7 | channel_combine_act | GELU | 0 \n",
"8 | ff | Sequential | 525 K \n",
"9 | out_dense | Linear | 11.8 K\n",
"10 | out_act | Tanhshrink | 0 \n",
"--------------------------------------------------------------\n",
"543 K Trainable params\n",
"0 Non-trainable params\n",
"543 K Total params\n",
"2.174 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "25d7ba2f5e3c4f68a55fdafed5a5b092",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Finally, train the model\n",
"trainer.fit(model, train, valid)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}