Files
Aconity_ML_Expt1/expt3_analysis.ipynb
2023-08-03 21:33:03 +01:00

1230 lines
28 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Standard library\n",
"from pathlib import Path\n",
"\n",
"# Third-party\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"from torchmetrics import MeanSquaredError\n",
"from tqdm.auto import tqdm\n",
"\n",
"# Local experiment modules (X, y, collate_fn and device from expt1; model from expt3)\n",
"from expt1 import X, y, collate_fn, device\n",
"from expt3 import model\n",
"\n",
"# Render DataFrame.plot(...) figures with plotly rather than matplotlib\n",
"pd.set_option(\"plotting.backend\", \"plotly\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Evaluation mode for inference (affects layers such as dropout and batch norm)\n",
"model.eval()\n",
"\n",
"# Path.glob yields files in arbitrary filesystem order; sort so checkpoints are\n",
"# evaluated, and written to the result CSVs, in a reproducible order.\n",
"checkpoints = tuple(sorted(Path(\"checkpoints\").glob(\"*.ckpt\")))\n",
"\n",
"# One sample per batch: the per-sample metrics and the [0] indexing further\n",
"# down rely on each batch holding exactly one sample.\n",
"# NOTE(review): zip pairs X and y by iteration order, so this assumes both\n",
"# containers list their samples in the same order - confirm in expt1.\n",
"test = DataLoader(\n",
"    list(zip(X.values(), y.values())),\n",
"    batch_size=1,\n",
"    collate_fn=collate_fn,\n",
"    shuffle=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf4c681d9d5b4bb5aa1aef4f544bf2ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluating checkpoints: 0%| | 0/10 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "462ba8b04a7b4e25a108bf31b9d358b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluating samples: 0%| | 0/81 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate all checkpoints from expt3 on all samples.\n",
"#   sample_df  - one row per sample (keys of X), one per-sample MSE column per checkpoint\n",
"#   overall_df - one row per checkpoint with the MSE over all samples\n",
"sample_df = pd.DataFrame(columns=[\"sample\"], data=X.keys()).set_index(\"sample\")\n",
"overall_rows = []\n",
"with torch.no_grad():\n",
"    for checkpoint in tqdm(checkpoints, desc=\"Evaluating checkpoints\"):\n",
"        # map_location loads the (possibly GPU-saved) weights onto `device` from expt1\n",
"        state = torch.load(checkpoint, map_location=device)\n",
"        model.load_state_dict(state[\"state_dict\"])\n",
"        y_ground = []\n",
"        y_eval = []\n",
"        for sample_X0, sample_X1, sample_y in tqdm(\n",
"            test, desc=\"Evaluating samples\", leave=False\n",
"        ):\n",
"            y_ground.append(sample_y.to(\"cpu\"))\n",
"            # Move to CPU because otherwise will run out of VRAM\n",
"            y_eval.append(model(sample_X0, sample_X1).to(\"cpu\"))\n",
"            torch.cuda.empty_cache()  # empty cache to ensure maximum VRAM available\n",
"        y_eval, y_ground = torch.stack(y_eval), torch.stack(y_ground)\n",
"        # torch.nn has no MeanSquaredError; use the torchmetrics metric,\n",
"        # as for the overall loss below.\n",
"        sample_df[checkpoint.name] = [\n",
"            MeanSquaredError()(ye, yg).item() for ye, yg in zip(y_eval, y_ground)\n",
"        ]\n",
"        overall_rows.append(\n",
"            (\n",
"                # nn.Module instances have no __name__ unless one is set explicitly\n",
"                getattr(model, \"__name__\", type(model).__name__),\n",
"                checkpoint.name,\n",
"                MeanSquaredError()(y_eval, y_ground).item(),\n",
"            )\n",
"        )\n",
"        torch.cuda.empty_cache()\n",
"\n",
"# Build the summary table once, after every checkpoint has been evaluated\n",
"overall_df = pd.DataFrame(columns=[\"model\", \"checkpoint\", \"loss\"], data=overall_rows)\n",
"\n",
"overall_df.to_csv(\"expt3_overall.csv\")\n",
"sample_df.to_csv(\"expt3_sample.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Detailed per-sample evaluation of a single set of weights (tabulated and\n",
"# plotted below). Set CHECKPOINT to a checkpoint path, e.g.\n",
"# \"checkpoints/checkpoint-epoch=12-val_loss=0.00.ckpt\", to load those weights\n",
"# first; None evaluates whatever weights the model currently holds (the last\n",
"# checkpoint loaded by the evaluation loop above, if it has been run).\n",
"CHECKPOINT = None\n",
"if CHECKPOINT is not None:\n",
"    model.load_state_dict(torch.load(CHECKPOINT, map_location=device)[\"state_dict\"])\n",
"y_ground = []\n",
"y_eval = []\n",
"with torch.no_grad():\n",
"    for sample_X0, sample_X1, sample_y in tqdm(\n",
"        test, desc=\"Evaluating samples\", leave=False\n",
"    ):\n",
"        y_ground.append(sample_y.to(\"cpu\"))\n",
"        # Move to CPU because otherwise will run out of VRAM\n",
"        y_eval.append(model(sample_X0, sample_X1).to(\"cpu\"))\n",
"        torch.cuda.empty_cache()  # empty cache to ensure maximum VRAM available\n",
"    y_eval, y_ground = torch.stack(y_eval), torch.stack(y_ground)\n",
"    # Despite the names, both are mean squared errors (lower is better):\n",
"    # one per sample, and one over all samples.\n",
"    sample_accuracy = [\n",
"        MeanSquaredError()(ye, yg).item() for ye, yg in zip(y_eval, y_ground)\n",
"    ]\n",
"    overall_accuracy = MeanSquaredError()(y_eval, y_ground).item()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Results table: an \"Overall\" row, then one row per sample (labelled by its\n",
"# position in the loader). The Overall row has no targets/predictions (None).\n",
"# Each target/prediction is indexed [0][0] for Ni and [0][1] for Ti, i.e.\n",
"# element 0 of a one-sample batch.\n",
"df = pd.DataFrame(\n",
"    {\n",
"        \"sample\": [\"Overall\"] + [str(idx) for idx in range(len(sample_accuracy))],\n",
"        \"mean_squared_error\": [overall_accuracy] + sample_accuracy,\n",
"        \"y_ground_Ni\": [None] + [target[0][0].item() for target in y_ground],\n",
"        \"y_ground_Ti\": [None] + [target[0][1].item() for target in y_ground],\n",
"        \"y_eval_Ni\": [None] + [pred[0][0].item() for pred in y_eval],\n",
"        \"y_eval_Ti\": [None] + [pred[0][1].item() for pred in y_eval],\n",
"    }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"alignmentgroup": "True",
"hovertemplate": "sample=%{x}<br>mean_squared_error=%{y}<extra></extra>",
"legendgroup": "",
"marker": {
"color": "#636efa",
"pattern": {
"shape": ""
}
},
"name": "",
"offsetgroup": "",
"orientation": "v",
"showlegend": false,
"textposition": "auto",
"type": "bar",
"x": [
"Overall",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"10",
"11",
"12",
"13",
"14",
"15",
"16",
"17",
"18",
"19",
"20",
"21",
"22",
"23",
"24",
"25",
"26",
"27",
"28",
"29",
"30",
"31",
"32",
"33",
"34",
"35",
"36",
"37",
"38",
"39",
"40",
"41",
"42",
"43",
"44",
"45",
"46",
"47",
"48",
"49",
"50",
"51",
"52",
"53",
"54",
"55",
"56",
"57",
"58",
"59",
"60",
"61",
"62",
"63",
"64",
"65",
"66",
"67",
"68",
"69",
"70",
"71",
"72",
"73",
"74",
"75",
"76",
"77",
"78",
"79",
"80"
],
"xaxis": "x",
"y": [
0.0026951939798891544,
0.0029656963888555765,
0.004948271904140711,
0.004154370632022619,
0.0028953494038432837,
0.003127720672637224,
0.0024840072728693485,
0.0024596236180514097,
0.004973412957042456,
0.0026887855492532253,
0.0025184503756463528,
0.002657898934558034,
0.004549721255898476,
0.002396129537373781,
0.0027176388539373875,
0.002620724495500326,
0.0025216771755367517,
0.0028491278644651175,
0.0028278320096433163,
0.002643025480210781,
0.0024697165936231613,
0.0026034615002572536,
0.002543746493756771,
0.002621922641992569,
0.0026970498729497194,
0.0027286007534712553,
0.0024804165586829185,
0.0025682011619210243,
0.0025432738475501537,
0.002615747507661581,
0.00269591948017478,
0.0029035978950560093,
0.002422402612864971,
0.0026250018272548914,
0.0024287384003400803,
0.0024327291175723076,
0.0032032637391239405,
0.002659817459061742,
0.0026294984854757786,
0.0028597083874046803,
0.002377850003540516,
0.002997367875650525,
0.0026343083009123802,
0.002686077496036887,
0.0025162948295474052,
0.0025714626535773277,
0.002468457445502281,
0.002512851729989052,
0.0025557084009051323,
0.0025269147008657455,
0.0025010996032506227,
0.002492662984877825,
0.0025910018011927605,
0.002502894029021263,
0.002440253272652626,
0.0024367240257561207,
0.0025100174825638533,
0.0026254504919052124,
0.002377192722633481,
0.00256515434011817,
0.0023771594278514385,
0.0024908073246479034,
0.0024112453684210777,
0.002494183834642172,
0.0025759506970643997,
0.002665249165147543,
0.0026045949198305607,
0.002568627707660198,
0.0025984200183302164,
0.002447066828608513,
0.0024844894651323557,
0.002440405311062932,
0.0027299586217850447,
0.002489563077688217,
0.0025099513586610556,
0.0025681699626147747,
0.002684804145246744,
0.002629152499139309,
0.002570347860455513,
0.00249234726652503,
0.0024778181686997414,
0.002378384815528989
],
"yaxis": "y"
}
],
"layout": {
"barmode": "relative",
"legend": {
"tracegroupgap": 0
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"fillpattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"title": {
"text": "Error by Sample"
},
"xaxis": {
"anchor": "y",
"domain": [
0,
1
],
"title": {
"text": "sample"
}
},
"yaxis": {
"anchor": "x",
"domain": [
0,
1
],
"title": {
"text": "mean_squared_error"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-sample MSE as a bar chart, drawn with the plotly plotting backend\n",
"df.plot.bar(x=\"sample\", y=\"mean_squared_error\", title=\"Error by Sample\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}