Mirror of https://github.com/Cian-H/symbolic_nn_tests.git (synced 2025-12-22 14:11:59 +00:00)
Completely implemented baseline for expt2
.gitignore (vendored, 2 changes)
@@ -161,6 +161,8 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

 .jj
 scratchpad.ipynb
 datasets/
+lightning_logs/
+logs/
poetry.lock (generated, 1206 changes)
File diff suppressed because it is too large.
pyproject.toml
@@ -22,6 +22,12 @@ euporie = "^2.8.2"
ipykernel = "^6.29.4"
tensorboard = "^2.16.2"
typer = "^0.12.3"
kaggle = "^1.6.14"
periodic-table-dataclasses = "^1.0"
polars = "^0.20.28"
jupyter = "^1.0.0"
safetensors = "^0.4.3"
alive-progress = "^3.1.5"


[build-system]
symbolic_nn_tests/dataloader.py
@@ -1,26 +1,18 @@
 from pathlib import Path
 from torchvision.datasets import Caltech256
 from torchvision.transforms import ToTensor
-from torch.utils.data import random_split
-from torch.utils.data import BatchSampler
+from torch.utils.data import DataLoader, random_split


 PROJECT_ROOT = Path(__file__).parent.parent
+DATASET_DIR = PROJECT_ROOT / "datasets/"


-def get_dataset(
+def create_dataset(
     split: (float, float, float) = (0.7, 0.1, 0.2),
     dataset=Caltech256,
-    batch_size: int = 128,
-    drop_last: bool = False,
     **kwargs,
 ):
-    _kwargs = {
-        "transform": ToTensor(),
-    }
-    _kwargs.update(kwargs)
-    ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs)
-    train, val, test = (
-        BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split)
-    )
+    ds = dataset(DATASET_DIR, download=True, transform=ToTensor())
+    train, val, test = (DataLoader(i, **kwargs) for i in random_split(ds, split))
     return train, val, test
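For reference, a minimal usage sketch of the reworked loader (not part of the commit); every keyword argument is forwarded to torch.utils.data.DataLoader, so batch size, shuffling, and any collate_fn are now chosen by the caller:

    from symbolic_nn_tests.dataloader import create_dataset

    # kwargs such as batch_size, shuffle, and collate_fn go straight to DataLoader.
    train, val, test = create_dataset(split=(0.7, 0.1, 0.2), batch_size=128, shuffle=True)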
@@ -1,4 +1,5 @@
+from functools import lru_cache
 import torch
 from torch import nn

@@ -9,17 +10,35 @@ model = nn.Sequential(
 )


+def collate(batch):
+    x, y = zip(*batch)
+    x = [i[0] for i in x]
+    y = [torch.tensor(i) for i in y]
+    x = torch.stack(x).to("cuda")
+    y = torch.tensor(y).to("cuda")
+    return x, y
+
+
 # This is just a quick, lazy way to ensure all models are trained on the same dataset
 @lru_cache(maxsize=1)
 def get_singleton_dataset():
     from torchvision.datasets import QMNIST

-    from symbolic_nn_tests.dataloader import get_dataset
+    from symbolic_nn_tests.dataloader import create_dataset

-    return get_dataset(dataset=QMNIST)
+    return create_dataset(
+        dataset=QMNIST, collate_fn=collate, batch_size=128, shuffle=True
+    )


-def main(loss_func=nn.functional.cross_entropy, logger=None, **kwargs):
+def oh_vs_cat_cross_entropy(y_bin, y_cat):
+    return nn.functional.cross_entropy(
+        y_bin,
+        nn.functional.one_hot(y_cat),
+    )
+
+
+def main(loss_func=oh_vs_cat_cross_entropy, logger=None, **kwargs):
     import lightning as L

     from symbolic_nn_tests.train import TrainingWrapper
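The new oh_vs_cat_cross_entropy helper bridges the model's per-class logits and the categorical QMNIST targets by one-hot-encoding the targets. A standalone sketch of that equivalence (illustrative only; torch.nn.functional.cross_entropy expects floating-point targets when they are given as class probabilities):

    import torch
    from torch import nn

    logits = torch.randn(4, 10)           # batch of 4, 10 classes
    targets = torch.randint(0, 10, (4,))  # categorical labels

    ce_categorical = nn.functional.cross_entropy(logits, targets)
    ce_one_hot = nn.functional.cross_entropy(
        logits, nn.functional.one_hot(targets, num_classes=10).float()
    )
    assert torch.allclose(ce_categorical, ce_one_hot)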
@@ -10,7 +10,10 @@ def create_semantic_cross_entropy(semantic_matrix):
         semantic_penalty = (abs_diff * penalty_tensor).sum()
         return ce_loss * semantic_penalty

-    return semantic_cross_entropy
+    def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
+        return semantic_cross_entropy(input_oh, torch.nn.functional.one_hot(target_cat))
+
+    return oh_vs_cat_semantic_cross_entropy


 # NOTE: This similarity matrix defines loss scaling factors for misclassification
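For orientation, a toy illustration of the "loss scaling factors for misclassification" idea noted above (the matrix values and the abs_diff definition here are placeholders, not the module's own):

    import torch
    from torch import nn

    # Hypothetical 3-class penalty matrix: confusing class 0 with class 2
    # should hurt more than confusing class 0 with class 1.
    penalty_tensor = torch.tensor([[0.0, 1.0, 3.0],
                                   [1.0, 0.0, 2.0],
                                   [3.0, 2.0, 0.0]])

    def toy_semantic_cross_entropy(logits, target_oh):
        ce_loss = nn.functional.cross_entropy(logits, target_oh.float())
        abs_diff = (logits.softmax(dim=-1) - target_oh).abs()  # placeholder definition
        semantic_penalty = (abs_diff * penalty_tensor[target_oh.argmax(dim=-1)]).sum()
        return ce_loss * semantic_penalty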
symbolic_nn_tests/experiment2/__init__.py (new file, 75 lines)
@@ -0,0 +1,75 @@
LEARNING_RATE = 10e-5


def test(loss_func, version, tensorboard=True, wandb=True):
    from .model import main as test_model

    logger = []

    if tensorboard:
        from lightning.pytorch.loggers import TensorBoardLogger

        tb_logger = TensorBoardLogger(
            save_dir=".",
            name="logs/comparison",
            version=version,
        )
        logger.append(tb_logger)

    if wandb:
        import wandb as _wandb
        from lightning.pytorch.loggers import WandbLogger

        wandb_logger = WandbLogger(
            project="Symbolic_NN_Tests",
            name=version,
            dir="wandb",
        )
        logger.append(wandb_logger)

    test_model(logger=logger, loss_func=loss_func, lr=LEARNING_RATE)

    if wandb:
        _wandb.finish()


def run(tensorboard: bool = True, wandb: bool = True):
    from . import semantic_loss
    from torch import nn

    test(
        nn.functional.cross_entropy,
        "cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
    test(
        semantic_loss.similarity_cross_entropy,
        "similarity_cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
    test(
        semantic_loss.hasline_cross_entropy,
        "hasline_cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
    test(
        semantic_loss.hasloop_cross_entropy,
        "hasloop_cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
    test(
        semantic_loss.multisemantic_cross_entropy,
        "multisemantic_cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
    test(
        semantic_loss.garbage_cross_entropy,
        "garbage_cross_entropy",
        tensorboard=tensorboard,
        wandb=wandb,
    )
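A minimal way to launch this comparison once the loss functions referenced from semantic_loss are implemented (the module is still a stub in this commit), here with Weights & Biases disabled:

    from symbolic_nn_tests.experiment2 import run

    run(tensorboard=True, wandb=False)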
symbolic_nn_tests/experiment2/dataset.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import kaggle
import polars as pl
import shutil
from functools import lru_cache
from periodic_table import PeriodicTable
import numpy as np
import torch
from torch.utils.data import TensorDataset
import pickle
from multiprocessing import Pool
from symbolic_nn_tests.dataloader import DATASET_DIR
import warnings
from tqdm.auto import tqdm


warnings.filterwarnings(action="ignore", category=UserWarning)


PUBCHEM_DIR = DATASET_DIR / "PubChem"

ORBITALS = {
    "s": (0, 2),
    "p": (1, 6),
    "d": (2, 10),
    "f": (3, 14),
}


def collate(batch):
    x0_in, x1_in, y_in = list(zip(*batch))
    x0_out = torch.nested.as_nested_tensor(list(x0_in))
    x1_out = torch.nested.as_nested_tensor(list(x1_in))
    y_out = torch.as_tensor(y_in)
    return (x0_out, x1_out), y_out


def pubchem(*args, **kwargs):
    PUBCHEM_DIR.mkdir(exist_ok=True, parents=True)
    return get_dataset()


def get_dataset():
    if not (
        "pubchem_x0.pickle"
        and "pubchem_x1.pickle"
        and "pubchem_y.pickle" in (x.name for x in PUBCHEM_DIR.iterdir())
    ):
        construct_dataset("pubchem")
    else:
        print("Pre-existing dataset detected!")
    print("Dataset loaded!")
    return TensorDataset(*load_dataset("pubchem"))


def construct_dataset(filename):
    print("Constructing dataset...")
    df = construct_ds_dataframe(filename)
    save_dataframe_to_dataset(df, PUBCHEM_DIR / f"{filename}.pickle")
    print("Dataset constructed!")


def construct_ds_dataframe(filename):
    print("Constructing dataset dataframe...")
    df = add_molecule_encodings(construct_raw_dataset(filename))
    # NOTE: This kind of checkpointing will be used throughout the construction process. It doesn't
    # take much disk space, it lets the GC collect out-of-scope data from the construction process,
    # and it makes it easier to debug if construction fails.
    parquet_file = PUBCHEM_DIR / f"{filename}.parquet"
    df.write_parquet(parquet_file)
    print("Dataset dataframe constructed!")
    return pl.read_parquet(parquet_file)


def construct_raw_dataset(filename):
    print("Constructing raw dataset...")
    df = collate_dataset()
    parquet_file = PUBCHEM_DIR / f"{filename}_raw.parquet"
    df.write_parquet(parquet_file)
    print("Raw dataset constructed!")
    return pl.read_parquet(parquet_file)


def collate_dataset():
    print("Collating dataset...")
    if not (PUBCHEM_DIR.exists() and len(tuple(PUBCHEM_DIR.glob("*.json")))):
        fetch_dataset()

    df = pl.concat(
        map(pl.read_json, PUBCHEM_DIR.glob("*.json")),
    ).drop("id")
    print("dataset collated!")
    return df


def fetch_dataset():
    print("Fetching dataset...")
    kaggle.api.dataset_download_files(
        "burakhmmtgl/predict-molecular-properties", quiet=False, path=DATASET_DIR
    )
    shutil.unpack_archive(DATASET_DIR / "predict-molecular-properties.zip", PUBCHEM_DIR)
    print("Dataset fetched!")


@lru_cache(maxsize=1)
def get_periodic_table():
    return PeriodicTable()


def add_molecule_encodings(df):
    atom_properties, atom_electrons = encode_molecules(df["atoms"])
    return df.with_columns(
        atom_properties=atom_properties, atom_electrons=atom_electrons
    )


def encode_molecules(series):
    # Yes, it is gross and RAM-inefficient to do it this way, but I don't have all day...
    with Pool() as p:
        molecules = p.map(encode_molecule, series)
    properties, electrons = zip(*molecules)
    return pl.Series(properties), pl.Series(electrons)


def encode_molecule(molecule):
    properties, electrons = zip(*_encode_molecule(molecule))
    properties = pl.Series(properties)
    return properties, electrons


def _encode_molecule(molecule):
    for atom in molecule:
        properties, electrons = encode_atom(atom["type"])
        yield np.array([*properties, *atom["xyz"]]), pl.Series(electrons)


def encode_atom(atom):
    element = get_periodic_table().search_symbol(atom)
    return (
        np.array(
            [
                # n and z need to be scaled somehow to normalize to approximately 1
                # Because this is kind of arbitrary, I've decided to scale relative to
                # Fermium (n = 100)
                element.atomic / 100.0,
                element.atomic_mass / 257.0,
                element.electron_affinity / 350.0,  # Highest known is just below 350
                element.electronegativity_pauling
                / 4.0,  # Max theoretical val is 4.0 here
            ],
        ),
        encode_electron_config(element.electron_configuration),
    )


def encode_electron_config(config):
    return np.array([encode_orbital(x) for x in config.split()])


def encode_orbital(orbital):
    shell, subshell, *n = orbital
    shell = int(shell)
    n = int("".join(n))
    azimuthal, capacity = ORBITALS[subshell]
    return np.array(
        [
            1.0
            / shell,  # This is the simplest way to normalize shell, as shells become less distinct as n increases
            azimuthal / 4.0,  # This is simply normalizing the azimuthal quantum number
            n / capacity,  # Basically encoding this as a proportion of "fullness"
        ],
    )


def save_dataframe_to_dataset(df, filename):
    print("Saving dataset to tensors...")
    with (filename.parent / f"{filename.stem}_x0{filename.suffix}").open("wb") as f:
        pickle.dump(properties_to_tensor(df).float(), f)
    with (filename.parent / f"{filename.stem}_x1{filename.suffix}").open("wb") as f:
        pickle.dump(electrons_to_tensor(df).float(), f)
    with (filename.parent / f"{filename.stem}_y{filename.suffix}").open("wb") as f:
        pickle.dump(df["En"].to_torch().float(), f)
    del df
    print("Tensors saved!")


def chunked_df(df, n):
    chunk_size = (len(df) // n) + 1
    chunk_boundaries = [*range(0, len(df), chunk_size), len(df)]
    chunk_ranges = list(zip(chunk_boundaries[:-1], chunk_boundaries[1:]))
    yield from (df[i:j] for i, j in chunk_ranges)


def properties_to_tensor(df):
    with Pool() as p:
        out = torch.cat(
            p.map(
                property_chunk_to_torch, chunked_df(df["atom_properties"], p._processes)
            )
        )
    return out


def property_chunk_to_torch(chunk):
    return torch.nested.nested_tensor([properties_to_torch(x) for x in chunk])


def properties_to_torch(p):
    return torch.stack(tuple(map(pl.Series.to_torch, p)))


def electrons_to_tensor(df):
    return torch.nested.nested_tensor(
        [
            molecule_electrons_to_torch(e)
            for e in tqdm(df["atom_electrons"], desc="Converting molecules to orbitals")
        ]
    )


def molecule_electrons_to_torch(e):
    return torch.stack([atom_electrons_to_torch(x) for x in e])


def atom_electrons_to_torch(e):
    # pytorch doesn't like doubly nested tensors, and the unoccupied orbitals still exist here even if
    # they're empty, so it makes sense to pad here instead. No elements in the dataset exceed an
    # azimuthal of 3, so we only need to pad to length 10. Also: I'm realising here that the orbital
    # info will be unnecessary if we have to pad here anyway.
    return pad_tensor_to(torch.tensor(tuple(x[-1] for x in e)), 10)


def pad_tensor_to(t, length):
    return torch.nn.functional.pad(t, (0, length - t.shape[0]))


def load_dataset(filename):
    filepath = PUBCHEM_DIR / f"{filename}.pickle"
    with (filepath.parent / f"{filepath.stem}_x0{filepath.suffix}").open("rb") as f:
        x0 = pickle.load(f)
    with (filepath.parent / f"{filepath.stem}_x1{filepath.suffix}").open("rb") as f:
        x1 = pickle.load(f)
    with (filepath.parent / f"{filepath.stem}_y{filepath.suffix}").open("rb") as f:
        y = pickle.load(f)
    return x0, x1, y
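To make the encoding comments in encode_atom, encode_orbital, and atom_electrons_to_torch concrete, a small worked sketch (not part of the commit; "2p4" is just an example electron-configuration token, e.g. from oxygen's 1s2 2s2 2p4):

    import torch

    ORBITALS = {"s": (0, 2), "p": (1, 6), "d": (2, 10), "f": (3, 14)}

    # encode_orbital("2p4"): shell 2, subshell p (azimuthal 1, capacity 6), 4 electrons
    shell, subshell, *n = "2p4"
    azimuthal, capacity = ORBITALS[subshell]
    encoded = [1.0 / int(shell), azimuthal / 4.0, int("".join(n)) / capacity]
    # encoded == [0.5, 0.25, 0.666...]

    # atom_electrons_to_torch then keeps only the occupancy fraction (the last
    # component) per orbital and right-pads each atom to a fixed length of 10:
    occupancies = torch.tensor([1.0, 1.0, 0.666])
    padded = torch.nn.functional.pad(occupancies, (0, 10 - occupancies.shape[0]))
    # padded: tensor of length 10, zeros appended for the unoccupied slots

Note that fetch_dataset goes through the kaggle API client, so Kaggle credentials (a ~/.kaggle/kaggle.json token or the KAGGLE_USERNAME/KAGGLE_KEY environment variables) must be configured before the first run.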
symbolic_nn_tests/experiment2/model.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from functools import lru_cache
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
        self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
        self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
        self.encode_x1 = self.create_xval_encoding_fn(self.x1_encoder)
        self.ff = nn.Sequential(
            nn.Linear(17, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        )

    @staticmethod
    def create_xval_encoding_fn(layer):
        def encoding_fn(xbatch):
            return torch.stack([layer(x)[-1] for x in xbatch])

        return encoding_fn

    def forward(self, x):
        x0, x1 = x
        x0 = self.encode_x0(x0)
        x1 = self.encode_x1(x1)
        x = torch.cat([x0, x1], dim=1)
        y = self.ff(x)
        return y


# This is just a quick, lazy way to ensure all models are trained on the same dataset
@lru_cache(maxsize=1)
def get_singleton_dataset():
    from symbolic_nn_tests.dataloader import create_dataset
    from symbolic_nn_tests.experiment2.dataset import collate, pubchem

    return create_dataset(
        dataset=pubchem, collate_fn=collate, batch_size=128, shuffle=True
    )


def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
    import lightning as L

    from symbolic_nn_tests.train import TrainingWrapper

    if logger is None:
        from lightning.pytorch.loggers import TensorBoardLogger

        logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")

    train, val, test = get_singleton_dataset()
    lmodel = TrainingWrapper(Model(), loss_func=loss_func)
    lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
    trainer = L.Trainer(max_epochs=20, logger=logger)
    trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
    trainer.test(dataloaders=test)


if __name__ == "__main__":
    main()
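A rough shape walkthrough of Model.forward, mirroring what dataset.collate produces (illustrative only; atom counts and values are made up):

    import torch
    from symbolic_nn_tests.experiment2.model import Model

    model = Model()
    # Two toy "molecules" with different atom counts: x0 carries 7 per-atom
    # features (4 scaled element properties + xyz), x1 carries 10 padded
    # orbital-occupancy features per atom.
    x0 = torch.nested.as_nested_tensor([torch.randn(3, 7), torch.randn(5, 7)])
    x1 = torch.nested.as_nested_tensor([torch.randn(3, 10), torch.randn(5, 10)])
    # Each encoder keeps the last token of its sequence, giving (2, 7) and (2, 10);
    # concatenation yields (2, 17), matching nn.Linear(17, 256), and the final
    # nn.Linear(8, 1) produces one scalar per molecule (the "En" regression target).
    y_pred = model((x0, x1))  # shape (2, 1)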
symbolic_nn_tests/experiment2/semantic_loss.py (new file, 1 line)
@@ -0,0 +1 @@
import torch
symbolic_nn_tests/train.py
@@ -1,37 +1,26 @@
 import torch
 from torch import nn, optim
 import lightning as L


-def collate_batch(batch):
-    x, y = zip(*batch)
-    x = [i[0] for i in x]
-    y = [torch.tensor(i) for i in y]
-    x = torch.stack(x).to("cuda")
-    y = torch.tensor(y).to("cuda")
-    return x, y
-
-
 class TrainingWrapper(L.LightningModule):
-    def __init__(self, model, loss_func=nn.functional.cross_entropy):
+    def __init__(self, model, loss_func=nn.functional.mse_loss, accuracy=None):
         super().__init__()
         self.model = model
         self.loss_func = loss_func
+        self.accuracy = accuracy

     def _forward_step(self, batch, batch_idx, label=""):
-        x, y = collate_batch(batch)
+        x, y = batch
         y_pred = self.model(x)
-        batch_size = x.shape[0]
-        one_hot_y = nn.functional.one_hot(y).type(torch.float64)
-        loss = self.loss_func(y_pred, one_hot_y)
-        acc = torch.sum(y_pred.argmax(dim=1) == y) / batch_size
-        self.log(f"{label}{'_' if label else ''}loss", loss, batch_size=batch_size)
-        self.log(f"{label}{'_' if label else ''}acc", acc, batch_size=batch_size)
-        return loss, acc
+        loss = self.loss_func(y_pred, y)
+        self.log(f"{label}{'_' if label else ''}loss", loss)
+        if self.accuracy is not None:
+            acc = self.accuracy(y_pred, y)
+            self.log(f"{label}{'_' if label else ''}acc", acc)
+        return loss

     def training_step(self, batch, batch_idx):
-        loss, _ = self._forward_step(batch, batch_idx, label="train")
-        return loss
+        return self._forward_step(batch, batch_idx, label="train")

     def validation_step(self, batch, batch_idx):
         self._forward_step(batch, batch_idx, label="val")
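The reworked TrainingWrapper leaves device placement and collation to the DataLoader and makes accuracy optional; a minimal sketch of the new signature (the metric below is a stand-in, any callable taking (y_pred, y) and returning a scalar works):

    from torch import nn
    from symbolic_nn_tests.train import TrainingWrapper

    def top1_accuracy(y_pred, y):
        # Stand-in metric, not part of the commit.
        return (y_pred.argmax(dim=1) == y).float().mean()

    wrapper = TrainingWrapper(
        nn.Linear(10, 10),  # any nn.Module
        loss_func=nn.functional.cross_entropy,
        accuracy=top1_accuracy,
    )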