In [1]:
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import MeanSquaredError
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from expt1 import X, y, collate_fn, device
from expt3 import model

# Route DataFrame.plot(...) calls in this notebook through plotly.
pd.options.plotting.backend = "plotly"

In [2]:
# This notebook only runs inference, so switch the model to evaluation mode.
model.eval()

# Path.glob yields matches in arbitrary order; sort so the checkpoints are
# always processed (and written to the result tables) in the same order.
checkpoints = tuple(sorted(Path("checkpoints").glob("*.ckpt")))

# Pair each entry of X with the entry of y at the same position
# (assumes both mappings share the same key order -- TODO confirm).
# shuffle=False keeps the samples in that order.
test = DataLoader(
    list(zip(X.values(), y.values())),
    collate_fn=collate_fn,
    shuffle=False,
)

In [11]:
# Evaluate all checkpoints from expt3 on all samples.
# Results:
#   sample_df  -- one row per key of X, one mean-squared-error column per checkpoint
#   overall_df -- one row per checkpoint with the error over all samples
# NOTE(review): the per-sample rows line up with X.keys() only if `test` yields
# one sample per batch in the order of X -- confirm against the loader setup.
sample_df = pd.DataFrame(columns=["sample"], data=X.keys()).set_index("sample")
overall_records = []  # (model class name, checkpoint file name, overall MSE)
with torch.no_grad():
    for checkpoint in tqdm(checkpoints, desc="Evaluating checkpoints"):
        # Deserialize onto the CPU so the checkpoint itself does not occupy VRAM;
        # load_state_dict then copies the values into the model's own parameters.
        state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
        model.load_state_dict(state_dict)
        y_ground = []
        y_eval = []
        for sample_X0, sample_X1, sample_y in tqdm(
            test, desc="Evaluating samples", leave=False
        ):
            y_ground.append(sample_y.to("cpu"))
            y_eval.append(
                model(sample_X0, sample_X1).to("cpu")
            )  # Move to CPU because otherwise will run out of VRAM
            torch.cuda.empty_cache()  # empty cache to ensure maximum VRAM available
        y_eval, y_ground = torch.stack(y_eval), torch.stack(y_ground)
        # torch.nn has no MeanSquaredError; use the torchmetrics metric that is
        # also used for the overall score below.
        sample_df[checkpoint.name] = [
            MeanSquaredError()(ye, yg).item() for ye, yg in zip(y_eval, y_ground)
        ]
        overall_records.append(
            (
                # `model` is an instance, which has no __name__; report its class.
                type(model).__name__,
                checkpoint.name,
                MeanSquaredError()(y_eval, y_ground).item(),
            )
        )
    torch.cuda.empty_cache()

overall_df = pd.DataFrame(
    columns=["model", "checkpoint", "loss"], data=overall_records
)

overall_df.to_csv("expt3_overall.csv")
sample_df.to_csv("expt3_sample.csv")

Evaluating checkpoints:   0%|          | 0/10 [00:00<?, ?it/s]

Evaluating samples:   0%|          | 0/81 [00:00<?, ?it/s]

In [None]:
# Evaluate the model on every sample with whatever weights it currently holds.
# To evaluate one specific checkpoint instead, load it first:
# checkpoint = "checkpoints/checkpoint-epoch=12-val_loss=0.00.ckpt"
# model.load_state_dict(torch.load(checkpoint)["state_dict"])
y_ground, y_eval = [], []
with torch.no_grad():
    sample_bar = tqdm(test, desc="Evaluating samples", leave=False)
    for sample_X0, sample_X1, sample_y in sample_bar:
        y_ground.append(sample_y.to("cpu"))
        # Keep predictions on the CPU because otherwise VRAM runs out.
        prediction = model(sample_X0, sample_X1)
        y_eval.append(prediction.to("cpu"))
        # Empty the cache to keep as much VRAM available as possible.
        torch.cuda.empty_cache()
    y_eval = torch.stack(y_eval)
    y_ground = torch.stack(y_ground)
    # One mean squared error per sample, then one over all samples together.
    sample_accuracy = [
        MeanSquaredError()(predicted, target).item()
        for predicted, target in zip(y_eval, y_ground)
    ]
    overall_accuracy = MeanSquaredError()(y_eval, y_ground).item()

In [10]:
# Collect the errors and the raw values in one table. The first row is the
# "Overall" error and has no ground-truth / predicted values of its own.
# Index [0][0] feeds the "Ni" columns and [0][1] the "Ti" columns.
df = pd.DataFrame(
    {
        "sample": ["Overall"] + [str(i) for i in range(len(sample_accuracy))],
        "mean_squared_error": [overall_accuracy] + sample_accuracy,
        "y_ground_Ni": [None] + [target[0][0].item() for target in y_ground],
        "y_ground_Ti": [None] + [target[0][1].item() for target in y_ground],
        "y_eval_Ni": [None] + [predicted[0][0].item() for predicted in y_eval],
        "y_eval_Ti": [None] + [predicted[0][1].item() for predicted in y_eval],
    }
)

In [11]:
# Bar chart of the mean squared error per row of df ("Overall" plus each sample).
df.plot(
    x="sample",
    y="mean_squared_error",
    kind="bar",
    title="Error by Sample",
)