Mirror of https://github.com/Cian-H/symbolic_nn_tests.git
Completely implemented baseline for expt2
.gitignore (vendored): 2 additions

@@ -161,6 +161,8 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
+.jj
+scratchpad.ipynb
 datasets/
 lightning_logs/
 logs/
poetry.lock (generated): 1206 changes (diff suppressed because it is too large)
pyproject.toml

@@ -22,6 +22,12 @@ euporie = "^2.8.2"
 ipykernel = "^6.29.4"
 tensorboard = "^2.16.2"
 typer = "^0.12.3"
+kaggle = "^1.6.14"
+periodic-table-dataclasses = "^1.0"
+polars = "^0.20.28"
+jupyter = "^1.0.0"
+safetensors = "^0.4.3"
+alive-progress = "^3.1.5"
 
 
 [build-system]
symbolic_nn_tests/dataloader.py

@@ -1,26 +1,18 @@
 from pathlib import Path
 from torchvision.datasets import Caltech256
 from torchvision.transforms import ToTensor
-from torch.utils.data import random_split
-from torch.utils.data import BatchSampler
+from torch.utils.data import DataLoader, random_split
 
 
 PROJECT_ROOT = Path(__file__).parent.parent
+DATASET_DIR = PROJECT_ROOT / "datasets/"
 
 
-def get_dataset(
+def create_dataset(
     split: (float, float, float) = (0.7, 0.1, 0.2),
     dataset=Caltech256,
-    batch_size: int = 128,
-    drop_last: bool = False,
     **kwargs,
 ):
-    _kwargs = {
-        "transform": ToTensor(),
-    }
-    _kwargs.update(kwargs)
-    ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
-    train, val, test = (
-        BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
-    )
+    ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
+    train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
    return train, val, test
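A hedged usage sketch (not part of the commit): after this change the three splits come back already wrapped in DataLoaders, so standard DataLoader options such as batch_size, shuffle, and collate_fn flow through **kwargs instead of being handled by a BatchSampler.

    from torchvision.datasets import QMNIST
    from symbolic_nn_tests.dataloader import create_dataset

    # QMNIST mirrors how get_singleton_dataset() calls this later in the commit
    train, val, test = create_dataset(dataset=QMNIST, batch_size=128, shuffle=True)
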
@@ -1,4 +1,5 @@
 from functools import lru_cache
+import torch
 from torch import nn
 
 
@@ -9,17 +10,35 @@ model = nn.Sequential(
 )
 
 
+def collate(batch):
+    x, y = zip(*batch)
+    x = [i[0] for i in x]
+    y = [torch.tensor(i) for i in y]
+    x = torch.stack(x).to("cuda")
+    y = torch.tensor(y).to("cuda")
+    return x, y
+
+
 # This is just a quick, lazy way to ensure all models are trained on the same dataset
 @lru_cache(maxsize=1)
 def get_singleton_dataset():
     from torchvision.datasets import QMNIST
 
-    from symbolic_nn_tests.dataloader import get_dataset
+    from symbolic_nn_tests.dataloader import create_dataset
 
-    return get_dataset(dataset=QMNIST)
+    return create_dataset(
+        dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
+    )
+
+
+def oh_vs_cat_cross_entropy(y_bin, y_cat):
+    return nn.functional.cross_entropy(
+        y_bin,
+        nn.functional.one_hot(y_cat),
+    )
 
 
-def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
+def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
     import lightning as L
 
     from symbolic_nn_tests.train import TrainingWrapper
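A minimal illustration (my addition, not from the commit) of the one-hot vs. categorical relationship that oh_vs_cat_cross_entropy relies on: with one-hot targets cast to float, F.cross_entropy gives the same value as the class-index form.

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)                      # batch of 4, 10 classes
    y_cat = torch.randint(0, 10, (4,))               # categorical (index) targets
    y_oh = F.one_hot(y_cat, num_classes=10).float()  # one-hot, as float probabilities

    assert torch.allclose(F.cross_entropy(logits, y_cat), F.cross_entropy(logits, y_oh))
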
@@ -10,7 +10,10 @@ def create_semantic_cross_entropy(semantic_matrix):
         semantic_penalty = (abs_diff * penalty_tensor).sum()
         return ce_loss * semantic_penalty
 
-    return semantic_cross_entropy
+    def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
+        return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
+
+    return oh_vs_cat_semantic_cross_entropy
 
 
 # NOTE: This similarity matrix defines loss scaling factors for misclassification
symbolic_nn_tests/experiment2/__init__.py (new file, 75 lines)

@@ -0,0 +1,75 @@
+LEARNING_RATE = 10e-5
+
+
+def test(loss_func, version, tensorboard=True, wandb=True):
+    from .model import main as test_model
+
+    logger = []
+
+    if tensorboard:
+        from lightning.pytorch.loggers import TensorBoardLogger
+
+        tb_logger = TensorBoardLogger(
+            save_dir=".",
+            name="logs/comparison",
+            version=version,
+        )
+        logger.append(tb_logger)
+
+    if wandb:
+        import wandb as _wandb
+        from lightning.pytorch.loggers import WandbLogger
+
+        wandb_logger = WandbLogger(
+            project="Symbolic_NN_Tests",
+            name=version,
+            dir="wandb",
+        )
+        logger.append(wandb_logger)
+
+    test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)
+
+    if wandb:
+        _wandb.finish()
+
+
+def run(tensorboard: bool = True, wandb: bool = True):
+    from . import semantic_loss
+    from torch import nn
+
+    test(
+        nn.functional.cross_entropy,
+        "cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
+    test(
+        semantic_loss.similarity_cross_entropy,
+        "similarity_cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
+    test(
+        semantic_loss.hasline_cross_entropy,
+        "hasline_cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
+    test(
+        semantic_loss.hasloop_cross_entropy,
+        "hasloop_cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
+    test(
+        semantic_loss.multisemantic_cross_entropy,
+        "multisemantic_cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
+    test(
+        semantic_loss.garbage_cross_entropy,
+        "garbage_cross_entropy",
+        tensorboard=tensorboard,
+        wandb=wandb,
+    )
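A hedged usage sketch (not part of the commit) of kicking off the comparison sweep from this entry point. Note that in this commit semantic_loss.py is still a stub (it only contains "import torch"), so the semantic losses referenced above are not implemented yet.

    from symbolic_nn_tests.experiment2 import run

    # Runs each loss variant through test(), logging locally to TensorBoard only
    run(tensorboard=True, wandb=False)
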
symbolic_nn_tests/experiment2/dataset.py (new file, 244 lines)

@@ -0,0 +1,244 @@
+import kaggle
+import polars as pl
+import shutil
+from functools import lru_cache
+from periodic_table import PeriodicTable
+import numpy as np
+import torch
+from torch.utils.data import TensorDataset
+import pickle
+from multiprocessing import Pool
+from symbolic_nn_tests.dataloader import DATASET_DIR
+import warnings
+from tqdm.auto import tqdm
+
+
+warnings.filterwarnings(action="ignore", category=UserWarning)
+
+
+PUBCHEM_DIR = DATASET_DIR / "PubChem"
+
+ORBITALS = {
+    "s": (0, 2),
+    "p": (1, 6),
+    "d": (2, 10),
+    "f": (3, 14),
+}
+
+
+def collate(batch):
+    x0_in, x1_in, y_in = list(zip(*batch))
+    x0_out = torch.nested.as_nested_tensor(list(x0_in))
+    x1_out = torch.nested.as_nested_tensor(list(x1_in))
+    y_out = torch.as_tensor(y_in)
+    return (x0_out, x1_out), y_out
+
+
+def pubchem(*args, **kwargs):
+    PUBCHEM_DIR.mkdir(exist_ok=True, parents=True)
+    return get_dataset()
+
+
+def get_dataset():
+    if not (
+        "pubchem_x0.pickle"
+        and "pubchem_x1.pickle"
+        and "pubchem_y.pickle" in (x.name for x in PUBCHEM_DIR.iterdir())
+    ):
+        construct_dataset("pubchem")
+    else:
+        print("Pre-existing dataset detected!")
+    print("Dataset loaded!")
+    return TensorDataset(*load_dataset("pubchem"))
+
+
+def construct_dataset(filename):
+    print("Constructing dataset...")
+    df = construct_ds_dataframe(filename)
+    save_dataframe_to_dataset(df, PUBCHEM_DIR / f"{filename}.pickle")
+    print("Dataset constructed!")
+
+
+def construct_ds_dataframe(filename):
+    print("Constructing dataset dataframe...")
+    df = add_molecule_encodings(construct_raw_dataset(filename))
+    # NOTE: This kind of checkpointing will be used throughout the construction process It doesn't
+    # take much disk space, it lets the GC collect out-of-scope data from the construction process
+    # and it makes it easier to debug if construction fails
+    parquet_file = PUBCHEM_DIR / f"{filename}.parquet"
+    df.write_parquet(parquet_file)
+    print("Dataset dataframe constructed!")
+    return pl.read_parquet(parquet_file)
+
+
+def construct_raw_dataset(filename):
+    print("Constructing raw dataset...")
+    df = collate_dataset()
+    parquet_file = PUBCHEM_DIR / f"{filename}_raw.parquet"
+    df.write_parquet(parquet_file)
+    print("Raw dataset constructed!")
+    return pl.read_parquet(parquet_file)
+
+
+def collate_dataset():
+    print("Collating dataset...")
+    if not (PUBCHEM_DIR.exists() and len(tuple(PUBCHEM_DIR.glob("*.json")))):
+        fetch_dataset()
+
+    df = pl.concat(
+        map(pl.read_json, PUBCHEM_DIR.glob("*.json")),
+    ).drop("id")
+    print("dataset collated!")
+    return df
+
+
+def fetch_dataset():
+    print("Fetching dataset...")
+    kaggle.api.dataset_download_files(
+        "burakhmmtgl/predict-molecular-properties", quiet=False, path=DATASET_DIR
+    )
+    shutil.unpack_archive(DATASET_DIR / "predict-molecular-properties.zip", PUBCHEM_DIR)
+    print("Dataset fetched!")
+
+
+@lru_cache(maxsize=1)
+def get_periodic_table():
+    return PeriodicTable()
+
+
+def add_molecule_encodings(df):
+    atom_properties, atom_electrons = encode_molecules(df["atoms"])
+    return df.with_columns(
+        atom_properties=atom_properties, atom_electrons=atom_electrons
+    )
+
+
+def encode_molecules(series):
+    # Yes, it is gross and RAM inefficient to do it this way but i dont have all day...
+    with Pool() as p:
+        molecules = p.map(encode_molecule, series)
+    properties, electrons = zip(*molecules)
+    return pl.Series(properties), pl.Series(electrons)
+
+
+def encode_molecule(molecule):
+    properties, electrons = zip(*_encode_molecule(molecule))
+    properties = pl.Series(properties)
+    return properties, electrons
+
+
+def _encode_molecule(molecule):
+    for atom in molecule:
+        properties, electrons = encode_atom(atom["type"])
+        yield np.array([*properties, *atom["xyz"]]), pl.Series(electrons)
+
+
+def encode_atom(atom):
+    element = get_periodic_table().search_symbol(atom)
+    return (
+        np.array(
+            [
+                # n and z need to be scaled somehow to normalize to approximately 1
+                # Because this is kind arbitrary i've decided to scale relative to
+                # Fermium (n = 100)
+                element.atomic / 100.0,
+                element.atomic_mass / 257.0,
+                element.electron_affinity / 350.0,  # Highest known is just below 350
+                element.electronegativity_pauling
+                / 4.0,  # Max theoretical val is 4.0 here
+            ],
+        ),
+        encode_electron_config(element.electron_configuration),
+    )
+
+
+def encode_electron_config(config):
+    return np.array([encode_orbital(x) for x in config.split()])
+
+
+def encode_orbital(orbital):
+    shell, subshell, *n = orbital
+    shell = int(shell)
+    n = int("".join(n))
+    azimuthal, capacity = ORBITALS[subshell]
+    return np.array(
+        [
+            1.0
+            / shell,  # This is the simplest way to normalize shell, as shells become less distinct as n increases
+            azimuthal / 4.0,  # This is simply normalizing the azimuthal quantum number
+            n / capacity,  # Basically encoding this as a proportion of "fullness"
+        ],
+    )
+
+
+def save_dataframe_to_dataset(df, filename):
+    print("Saving dataset to tensors...")
+    with (filename.parent / f"{filename.stem}_x0{filename.suffix}").open("wb") as f:
+        pickle.dump(properties_to_tensor(df).float(), f)
+    with (filename.parent / f"{filename.stem}_x1{filename.suffix}").open("wb") as f:
+        pickle.dump(electrons_to_tensor(df).float(), f)
+    with (filename.parent / f"{filename.stem}_y{filename.suffix}").open("wb") as f:
+        pickle.dump(df["En"].to_torch().float(), f)
+    del df
+    print("Tensors saved!")
+
+
+def chunked_df(df, n):
+    chunk_size = (len(df) // n) + 1
+    chunk_boundaries = [*range(0, len(df), chunk_size), len(df)]
+    chunk_ranges = list(zip(chunk_boundaries[:-1], chunk_boundaries[1:]))
+    yield from (df[i:j] for i, j in chunk_ranges)
+
+
+def properties_to_tensor(df):
+    with Pool() as p:
+        out = torch.cat(
+            p.map(
+                property_chunk_to_torch, chunked_df(df["atom_properties"], p._processes)
+            )
+        )
+    return out
+
+
+def property_chunk_to_torch(chunk):
+    return torch.nested.nested_tensor([properties_to_torch(x) for x in chunk])
+
+
+def properties_to_torch(p):
+    return torch.stack(tuple(map(pl.Series.to_torch, p)))
+
+
+def electrons_to_tensor(df):
+    return torch.nested.nested_tensor(
+        [
+            molecule_electrons_to_torch(e)
+            for e in tqdm(df["atom_electrons"], desc="Converting molecules to orbitals")
+        ]
+    )
+
+
+def molecule_electrons_to_torch(e):
+    return torch.stack([atom_electrons_to_torch(x) for x in e])
+
+
+def atom_electrons_to_torch(e):
+    # pytorch doesn't like doubly nested tensors, and the unnocupied orbitals still exist here even if
+    # they're empty, so it makes sense to pad here instead. No elements in the dataset exceed an
+    # azimuthal of 3, so we only need to pad to length 10. Also: i'm realising here that the orbital
+    # info will be unecessary if we have to pad here anyway
+    return pad_tensor_to(torch.tensor(tuple(x[-1] for x in e)), 10)
+
+
+def pad_tensor_to(t, length):
+    return torch.nn.functional.pad(t, (0, length - t.shape[0]))
+
+
+def load_dataset(filename):
+    filepath = PUBCHEM_DIR / f"{filename}.pickle"
+    with (filepath.parent / f"{filepath.stem}_x0{filepath.suffix}").open("rb") as f:
+        x0 = pickle.load(f)
+    with (filepath.parent / f"{filepath.stem}_x1{filepath.suffix}").open("rb") as f:
+        x1 = pickle.load(f)
+    with (filepath.parent / f"{filepath.stem}_y{filepath.suffix}").open("rb") as f:
+        y = pickle.load(f)
+    return x0, x1, y
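A worked example (my addition, for illustration) of the orbital encoding above: each token of an electron-configuration string is mapped to three roughly unit-scaled features.

    from symbolic_nn_tests.experiment2.dataset import encode_orbital

    # "2p4": shell 2 -> 1/2 = 0.5, "p" subshell -> azimuthal 1/4 = 0.25,
    # 4 electrons of a 6-electron capacity -> 4/6 ~ 0.667
    print(encode_orbital("2p4"))  # ~[0.5, 0.25, 0.667]
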
symbolic_nn_tests/experiment2/model.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+from functools import lru_cache
+import torch
+from torch import nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
+        self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
+        self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
+        self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
+        self.ff = nn.Sequential(
+            nn.Linear(17, 256),
+            nn.ReLU(),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 64),
+            nn.ReLU(),
+            nn.Linear(64, 32),
+            nn.ReLU(),
+            nn.Linear(32, 16),
+            nn.ReLU(),
+            nn.Linear(16, 8),
+            nn.ReLU(),
+            nn.Linear(8, 1),
+        )
+
+    @staticmethod
+    def create_xval_encoding_fn(layer):
+        def encoding_fn(xbatch):
+            return torch.stack([layer(x)[-1] for x in xbatch])
+
+        return encoding_fn
+
+    def forward(self, x):
+        x0, x1 = x
+        x0 = self.encode_x0(x0)
+        x1 = self.encode_x1(x1)
+        x = torch.cat([x0, x1], dim=1)
+        y = self.ff(x)
+        return y
+
+
+# This is just a quick, lazy way to ensure all models are trained on the same dataset
+@lru_cache(maxsize=1)
+def get_singleton_dataset():
+    from symbolic_nn_tests.dataloader import create_dataset
+    from symbolic_nn_tests.experiment2.dataset import collate, pubchem
+
+    return create_dataset(
+        dataset=pubchem, collate_fn=collate, batch_size=128, shuffle=True
+    )
+
+
+def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
+    import lightning as L
+
+    from symbolic_nn_tests.train import TrainingWrapper
+
+    if logger is None:
+        from lightning.pytorch.loggers import TensorBoardLogger
+
+        logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
+
+    train, val, test = get_singleton_dataset()
+    lmodel = TrainingWrapper(Model(), loss_func=loss_func)
+    lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
+    trainer = L.Trainer(max_epochs=20, logger=logger)
+    trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
+    trainer.test(dataloaders=test)
+
+
+if __name__ == "__main__":
+    main()
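A hedged sketch (my addition, not part of the commit) of the shapes this model expects, assuming a PyTorch build with nested-tensor support: x0 carries per-atom rows of 4 scaled properties plus xyz (7 features), x1 carries per-atom orbital-occupancy rows padded to length 10, and the output is one regression value per molecule.

    import torch
    from symbolic_nn_tests.experiment2.model import Model

    model = Model()
    # Two molecules with 3 and 5 atoms respectively, batched the way collate() does
    x0 = torch.nested.as_nested_tensor([torch.randn(3, 7), torch.randn(5, 7)])
    x1 = torch.nested.as_nested_tensor([torch.randn(3, 10), torch.randn(5, 10)])
    y = model((x0, x1))
    print(y.shape)  # torch.Size([2, 1])
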
symbolic_nn_tests/experiment2/semantic_loss.py (new file, 1 line)

@@ -0,0 +1 @@
+import torch
symbolic_nn_tests/train.py

@@ -1,37 +1,26 @@
-import torch
 from torch import nn, optim
 import lightning as L
 
 
-def collate_batch(batch):
-    x, y = zip(*batch)
-    x = [i[0] for i in x]
-    y = [torch.tensor(i) for i in y]
-    x = torch.stack(x).to("cuda")
-    y = torch.tensor(y).to("cuda")
-    return x, y
-
-
 class TrainingWrapper(L.LightningModule):
-    def __init__(self, model, loss_func=nn.functional.cross_entropy):
+    def __init__(self, model, loss_func=nn.functional.mse_loss, accuracy=None):
         super().__init__()
         self.model = model
         self.loss_func = loss_func
+        self.accuracy = accuracy
 
     def _forward_step(self, batch, batch_idx, label=""):
-        x, y = collate_batch(batch)
+        x, y = batch
         y_pred = self.model(x)
-        batch_size = x.shape[0]
-        one_hot_y = nn.functional.one_hot(y).type(torch.float64)
-        loss = self.loss_func(y_pred, one_hot_y)
-        acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
-        self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
-        self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
-        return loss, acc
+        loss = self.loss_func(y_pred, y)
+        self.log(f"{label}{'_' if label else ''}loss", loss)
+        if self.accuracy is not None:
+            acc = self.accuracy(y_pred, y)
+            self.log(f"{label}{'_' if label else ''}acc", acc)
+        return loss
 
     def training_step(self, batch, batch_idx):
-        loss, _ = self._forward_step(batch, batch_idx, label="train")
-        return loss
+        return self._forward_step(batch, batch_idx, label="train")
 
     def validation_step(self, batch, batch_idx):
         self._forward_step(batch, batch_idx, label="val")
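A hedged sketch (my addition, not from the commit): TrainingWrapper now assumes batches arrive already collated by the DataLoader, defaults to a regression loss, and only logs accuracy when an accuracy callable is supplied. top1_accuracy below is a hypothetical helper for a classification run, not something in the repo.

    from torch import nn
    from symbolic_nn_tests.train import TrainingWrapper

    def top1_accuracy(y_pred, y):  # hypothetical metric, not part of the repo
        return (y_pred.argmax(dim=1) == y).float().mean()

    wrapper = TrainingWrapper(
        nn.Linear(28 * 28, 10),
        loss_func=nn.functional.cross_entropy,
        accuracy=top1_accuracy,
    )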