Moved to rye instead of poetry, to avoid linking problems

This commit is contained in:
2024-09-11 14:26:35 +01:00
parent 4036364ea0
commit 546e235b09
27 changed files with 1776 additions and 5739 deletions

1
.envrc Normal file
View File

@@ -0,0 +1 @@
use flake . --impure

1
.gitignore vendored
View File

@@ -162,6 +162,7 @@ cython_debug/
#.idea/ #.idea/
.jj .jj
.direnv
scratchpad.ipynb scratchpad.ipynb
datasets/ datasets/
explore/ explore/

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12.4

60
flake.lock generated Normal file
View File

@@ -0,0 +1,60 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1726053618,
"narHash": "sha256-Xu5EVNdrbJ4XCpVU4moytVTuqoo9LsmQgR8g2BQd9Qc=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "0e6a2434a572fd583cac02e142ab0689895e395a",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs",
"utils": "utils"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

49
flake.nix Normal file
View File

@@ -0,0 +1,49 @@
{
  inputs = {
    utils.url = "github:numtide/flake-utils";
  };
  inputs.nixpkgs.url = "github:NixOS/nixpkgs";
  outputs = { self, nixpkgs, utils }: utils.lib.eachDefaultSystem (system:
    let
      pkgs = import nixpkgs {
        # Use the system supplied by eachDefaultSystem rather than
        # hard-coding "x86_64-linux"; the hard-coded value silently
        # instantiated the wrong package set on every other platform.
        inherit system;
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
      };
    in
    {
      devShell = pkgs.mkShell
        {
          # Libraries exposed via nix-ld so prebuilt Python wheels
          # (torch etc.) can find their shared objects.
          NIX_LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
            pkgs.stdenv.cc.cc
            pkgs.zlib
            pkgs.libGL
            pkgs.libGLU
            pkgs.wlroots
            pkgs.ncurses5
            pkgs.linuxKernel.packages.linux_latest_libre.nvidia_x11
          ];
          NIX_LD = pkgs.lib.fileContents "${pkgs.stdenv.cc}/nix-support/dynamic-linker";
          buildInputs = with pkgs; [
            ruff
            ruff-lsp
            rye
            uv
          ];
          shellHook = ''
            export CUDA_PATH=${pkgs.cudatoolkit}
            export LD_LIBRARY_PATH=$(nix eval --raw nixpkgs#addOpenGLRunpath.driverLink)/lib
            export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib"
            export EXTRA_CCFLAGS="-I/usr/include"
          '';
        };
    }
  );
}

5703
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,41 +1,55 @@
[tool.poetry] [project]
name = "symbolic-nn-tests" name = "symbolic-nn-tests"
version = "0.1.0" version = "0.1.0"
description = "" description = "Add your description here"
authors = ["Cian Hughes <chughes000@gmail.com>"] authors = [
license = "MIT" { name = "Cian Hughes", email = "chughes000@gmail.com" }
]
dependencies = [
"torch>=2.4.1",
"lightning>=2.4.0",
"torchvision>=0.19.1",
"wandb>=0.17.9",
"optuna>=4.0.0",
"setuptools>=74.1.2",
"gdown>=5.2.0",
"bpython>=0.24",
"ipython>=8.27.0",
"matplotlib-backend-kitty>=2.1.2",
"euporie>=2.8.2",
"ipykernel>=6.29.5",
"tensorboard>=2.17.1",
"typer>=0.12.5",
"kaggle>=1.6.17",
"periodic-table-dataclasses>=1.0",
"polars>=1.6.0",
"jupyter>=1.1.1",
"safetensors>=0.4.5",
"alive-progress>=3.1.5",
"hvplot>=0.10.0",
"pyarrow>=17.0.0",
"loguru>=0.7.2",
"plotly>=5.24.0",
"snoop>=0.4.3",
"scikit-optimize>=0.10.2",
]
readme = "README.md" readme = "README.md"
requires-python = ">= 3.8"
classifiers = ["Private :: Do Not Upload"]
[tool.poetry.dependencies] [project.scripts]
python = "^3.11" "symbolic-nn-tests" = "symbolic_nn_tests:main"
torch = "^2.3.0"
lightning = "^2.2.4"
torchvision = "^0.18.0"
wandb = "^0.17.0"
optuna = "^3.6.1"
setuptools = "^69.5.1"
gdown = "^5.2.0"
bpython = "^0.24"
ipython = "^8.24.0"
matplotlib-backend-kitty = "^2.1.2"
euporie = "^2.8.2"
ipykernel = "^6.29.4"
tensorboard = "^2.16.2"
typer = "^0.12.3"
kaggle = "^1.6.14"
periodic-table-dataclasses = "^1.0"
polars = "^0.20.28"
jupyter = "^1.0.0"
safetensors = "^0.4.3"
alive-progress = "^3.1.5"
hvplot = "^0.10.0"
pyarrow = "^16.1.0"
loguru = "^0.7.2"
plotly = "^5.22.0"
snoop = "^0.4.3"
scikit-optimize = "^0.10.2"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["hatchling"]
build-backend = "poetry.core.masonry.api" build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = []
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/symbolic_nn_tests"]

675
requirements-dev.lock Normal file
View File

@@ -0,0 +1,675 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
-e file:.
about-time==4.2.1
# via alive-progress
absl-py==2.1.0
# via tensorboard
aenum==3.1.12
# via euporie
aiohappyeyeballs==2.4.0
# via aiohttp
aiohttp==3.10.5
# via fsspec
aiosignal==1.3.1
# via aiohttp
alembic==1.13.2
# via optuna
alive-progress==3.1.5
# via symbolic-nn-tests
anyio==4.4.0
# via httpx
# via jupyter-server
argon2-cffi==23.1.0
# via jupyter-server
argon2-cffi-bindings==21.2.0
# via argon2-cffi
arrow==1.3.0
# via isoduration
asttokens==2.4.1
# via snoop
# via stack-data
async-lru==2.0.4
# via jupyterlab
attrs==24.2.0
# via aiohttp
# via jsonschema
# via referencing
babel==2.16.0
# via jupyterlab-server
beautifulsoup4==4.12.3
# via gdown
# via nbconvert
bleach==6.1.0
# via kaggle
# via nbconvert
# via panel
blessed==1.20.0
# via curtsies
bokeh==3.4.3
# via holoviews
# via hvplot
# via panel
bpython==0.24
# via symbolic-nn-tests
certifi==2024.8.30
# via httpcore
# via httpx
# via kaggle
# via requests
# via sentry-sdk
cffi==1.17.1
# via argon2-cffi-bindings
charset-normalizer==3.3.2
# via requests
cheap-repr==0.5.2
# via snoop
click==8.1.7
# via typer
# via wandb
colorcet==3.1.0
# via holoviews
# via hvplot
colorlog==6.8.2
# via optuna
comm==0.2.2
# via ipykernel
# via ipywidgets
contourpy==1.3.0
# via bokeh
# via matplotlib
curtsies==0.4.2
# via bpython
cwcwidth==0.1.9
# via bpython
# via curtsies
cycler==0.12.1
# via matplotlib
dataclasses-json==0.6.7
# via periodic-table-dataclasses
debugpy==1.8.5
# via ipykernel
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via nbconvert
docker-pycreds==0.4.0
# via wandb
euporie==2.8.2
# via symbolic-nn-tests
executing==2.1.0
# via snoop
# via stack-data
fastjsonschema==2.20.0
# via euporie
# via nbformat
filelock==3.16.0
# via gdown
# via torch
# via triton
flatlatex==0.15
# via euporie
fonttools==4.53.1
# via matplotlib
fqdn==1.5.1
# via jsonschema
frozenlist==1.4.1
# via aiohttp
# via aiosignal
fsspec==2024.9.0
# via euporie
# via lightning
# via pytorch-lightning
# via torch
# via universal-pathlib
gdown==5.2.0
# via symbolic-nn-tests
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via wandb
grapheme==0.6.0
# via alive-progress
greenlet==3.1.0
# via bpython
# via sqlalchemy
grpcio==1.66.1
# via tensorboard
h11==0.14.0
# via httpcore
holoviews==1.19.1
# via hvplot
httpcore==1.0.5
# via httpx
httpx==0.27.2
# via jupyterlab
hvplot==0.10.0
# via symbolic-nn-tests
idna==3.8
# via anyio
# via httpx
# via jsonschema
# via requests
# via yarl
imagesize==1.4.1
# via euporie
ipykernel==6.29.5
# via jupyter
# via jupyter-console
# via jupyterlab
# via symbolic-nn-tests
ipython==8.27.0
# via ipykernel
# via ipywidgets
# via jupyter-console
# via symbolic-nn-tests
ipywidgets==8.1.5
# via jupyter
isoduration==20.11.0
# via jsonschema
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via bokeh
# via jupyter-server
# via jupyterlab
# via jupyterlab-server
# via nbconvert
# via torch
joblib==1.4.2
# via scikit-learn
# via scikit-optimize
json5==0.9.25
# via jupyterlab-server
jsonpointer==3.0.0
# via jsonschema
jsonschema==4.23.0
# via jupyter-events
# via jupyterlab-server
# via nbformat
jsonschema-specifications==2023.12.1
# via jsonschema
jupyter==1.1.1
# via symbolic-nn-tests
jupyter-client==8.6.2
# via euporie
# via ipykernel
# via jupyter-console
# via jupyter-server
# via nbclient
jupyter-console==6.6.3
# via jupyter
jupyter-core==5.7.2
# via ipykernel
# via jupyter-client
# via jupyter-console
# via jupyter-server
# via jupyterlab
# via nbclient
# via nbconvert
# via nbformat
jupyter-events==0.10.0
# via jupyter-server
jupyter-lsp==2.2.5
# via jupyterlab
jupyter-server==2.14.2
# via jupyter-lsp
# via jupyterlab
# via jupyterlab-server
# via notebook
# via notebook-shim
jupyter-server-terminals==0.5.3
# via jupyter-server
jupyterlab==4.2.5
# via jupyter
# via notebook
jupyterlab-pygments==0.3.0
# via nbconvert
jupyterlab-server==2.27.3
# via jupyterlab
# via notebook
jupyterlab-widgets==3.0.13
# via ipywidgets
jupytext==1.16.4
# via euporie
kaggle==1.6.17
# via symbolic-nn-tests
kiwisolver==1.4.7
# via matplotlib
lightning==2.4.0
# via symbolic-nn-tests
lightning-utilities==0.11.7
# via lightning
# via pytorch-lightning
# via torchmetrics
linkify-it-py==1.0.3
# via euporie
# via panel
loguru==0.7.2
# via symbolic-nn-tests
mako==1.3.5
# via alembic
markdown==3.7
# via panel
# via tensorboard
markdown-it-py==2.1.0
# via euporie
# via jupytext
# via mdit-py-plugins
# via panel
# via rich
markupsafe==2.1.5
# via jinja2
# via mako
# via nbconvert
# via werkzeug
marshmallow==3.22.0
# via dataclasses-json
matplotlib==3.9.2
# via matplotlib-backend-kitty
matplotlib-backend-kitty==2.1.2
# via symbolic-nn-tests
matplotlib-inline==0.1.7
# via ipykernel
# via ipython
mdit-py-plugins==0.3.5
# via euporie
# via jupytext
# via panel
mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via aiohttp
# via yarl
mypy-extensions==1.0.0
# via typing-inspect
nbclient==0.10.0
# via nbconvert
nbconvert==7.16.4
# via jupyter
# via jupyter-server
nbformat==5.10.4
# via euporie
# via jupyter-server
# via jupytext
# via nbclient
# via nbconvert
nest-asyncio==1.6.0
# via ipykernel
networkx==3.3
# via torch
notebook==7.2.2
# via jupyter
notebook-shim==0.2.4
# via jupyterlab
# via notebook
numpy==2.1.1
# via bokeh
# via contourpy
# via holoviews
# via hvplot
# via matplotlib
# via optuna
# via pandas
# via pyarrow
# via scikit-learn
# via scikit-optimize
# via scipy
# via tensorboard
# via torchmetrics
# via torchvision
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.6.68
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
optuna==4.0.0
# via symbolic-nn-tests
overrides==7.7.0
# via jupyter-server
packaging==24.1
# via bokeh
# via holoviews
# via hvplot
# via ipykernel
# via jupyter-server
# via jupyterlab
# via jupyterlab-server
# via jupytext
# via lightning
# via lightning-utilities
# via marshmallow
# via matplotlib
# via nbconvert
# via optuna
# via plotly
# via pytorch-lightning
# via scikit-optimize
# via tensorboard
# via torchmetrics
pandas==2.2.2
# via bokeh
# via holoviews
# via hvplot
# via panel
pandocfilters==1.5.1
# via nbconvert
panel==1.4.5
# via holoviews
# via hvplot
param==2.1.1
# via holoviews
# via hvplot
# via panel
# via pyviz-comms
parso==0.8.4
# via jedi
periodic-table-dataclasses==1.0
# via symbolic-nn-tests
pexpect==4.9.0
# via ipython
pillow==10.4.0
# via bokeh
# via euporie
# via matplotlib
# via timg
# via torchvision
platformdirs==3.11.0
# via euporie
# via jupyter-core
# via wandb
plotly==5.24.0
# via symbolic-nn-tests
polars==1.6.0
# via symbolic-nn-tests
prometheus-client==0.20.0
# via jupyter-server
prompt-toolkit==3.0.47
# via euporie
# via ipython
# via jupyter-console
protobuf==5.28.0
# via tensorboard
# via wandb
psutil==6.0.0
# via ipykernel
# via wandb
ptyprocess==0.7.0
# via pexpect
# via terminado
pure-eval==0.2.3
# via stack-data
pyaml==24.7.0
# via scikit-optimize
pyarrow==17.0.0
# via symbolic-nn-tests
pycparser==2.22
# via cffi
pygments==2.18.0
# via bpython
# via euporie
# via ipython
# via jupyter-console
# via nbconvert
# via rich
# via snoop
pyparsing==3.1.4
# via matplotlib
pyperclip==1.9.0
# via euporie
pysocks==1.7.1
# via requests
python-dateutil==2.9.0.post0
# via arrow
# via jupyter-client
# via kaggle
# via matplotlib
# via pandas
python-json-logger==2.0.7
# via jupyter-events
python-slugify==8.0.4
# via kaggle
pytorch-lightning==2.4.0
# via lightning
pytz==2024.2
# via pandas
pyviz-comms==3.0.3
# via holoviews
# via panel
pyxdg==0.28
# via bpython
pyyaml==6.0.2
# via bokeh
# via jupyter-events
# via jupytext
# via lightning
# via optuna
# via pyaml
# via pytorch-lightning
# via wandb
pyzmq==26.2.0
# via ipykernel
# via jupyter-client
# via jupyter-console
# via jupyter-server
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
# via jupyter-events
regex==2024.7.24
# via flatlatex
requests==2.32.3
# via bpython
# via gdown
# via jupyterlab-server
# via kaggle
# via panel
# via wandb
rfc3339-validator==0.1.4
# via jsonschema
# via jupyter-events
rfc3986-validator==0.1.1
# via jsonschema
# via jupyter-events
rich==13.3.1
# via typer
rpds-py==0.20.0
# via jsonschema
# via referencing
safetensors==0.4.5
# via symbolic-nn-tests
scikit-learn==1.5.1
# via scikit-optimize
scikit-optimize==0.10.2
# via symbolic-nn-tests
scipy==1.14.1
# via scikit-learn
# via scikit-optimize
send2trash==1.8.3
# via jupyter-server
sentry-sdk==2.14.0
# via wandb
setproctitle==1.3.3
# via wandb
setuptools==74.1.2
# via jupyterlab
# via lightning-utilities
# via symbolic-nn-tests
# via tensorboard
# via torch
# via wandb
shellingham==1.5.4
# via typer
six==1.16.0
# via asttokens
# via bleach
# via blessed
# via docker-pycreds
# via kaggle
# via python-dateutil
# via rfc3339-validator
# via snoop
# via tensorboard
sixelcrop==0.1.7
# via euporie
smmap==5.0.1
# via gitdb
sniffio==1.3.1
# via anyio
# via httpx
snoop==0.4.3
# via symbolic-nn-tests
soupsieve==2.6
# via beautifulsoup4
sqlalchemy==2.0.34
# via alembic
# via optuna
stack-data==0.6.3
# via ipython
sympy==1.13.2
# via torch
tenacity==9.0.0
# via plotly
tensorboard==2.17.1
# via symbolic-nn-tests
tensorboard-data-server==0.7.2
# via tensorboard
terminado==0.18.1
# via jupyter-server
# via jupyter-server-terminals
text-unidecode==1.3
# via python-slugify
threadpoolctl==3.5.0
# via scikit-learn
timg==1.1.6
# via euporie
tinycss2==1.3.0
# via nbconvert
torch==2.4.1
# via lightning
# via pytorch-lightning
# via symbolic-nn-tests
# via torchmetrics
# via torchvision
torchmetrics==1.4.1
# via lightning
# via pytorch-lightning
torchvision==0.19.1
# via symbolic-nn-tests
tornado==6.4.1
# via bokeh
# via ipykernel
# via jupyter-client
# via jupyter-server
# via jupyterlab
# via notebook
# via terminado
tqdm==4.66.5
# via gdown
# via kaggle
# via lightning
# via optuna
# via panel
# via pytorch-lightning
traitlets==5.14.3
# via comm
# via ipykernel
# via ipython
# via ipywidgets
# via jupyter-client
# via jupyter-console
# via jupyter-core
# via jupyter-events
# via jupyter-server
# via jupyterlab
# via matplotlib-inline
# via nbclient
# via nbconvert
# via nbformat
triton==3.0.0
# via torch
typer==0.12.5
# via symbolic-nn-tests
types-python-dateutil==2.9.0.20240906
# via arrow
typing-extensions==4.12.2
# via alembic
# via euporie
# via lightning
# via lightning-utilities
# via panel
# via pytorch-lightning
# via sqlalchemy
# via torch
# via typer
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1
# via pandas
uc-micro-py==1.0.3
# via linkify-it-py
universal-pathlib==0.2.5
# via euporie
uri-template==1.3.0
# via jsonschema
urllib3==2.2.2
# via kaggle
# via requests
# via sentry-sdk
wandb==0.17.9
# via symbolic-nn-tests
wcwidth==0.2.13
# via blessed
# via prompt-toolkit
webcolors==24.8.0
# via jsonschema
webencodings==0.5.1
# via bleach
# via tinycss2
websocket-client==1.8.0
# via jupyter-server
werkzeug==3.0.4
# via tensorboard
widgetsnbextension==4.0.13
# via ipywidgets
xyzservices==2024.9.0
# via bokeh
# via panel
yarl==1.11.1
# via aiohttp

675
requirements.lock Normal file
View File

@@ -0,0 +1,675 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
-e file:.
about-time==4.2.1
# via alive-progress
absl-py==2.1.0
# via tensorboard
aenum==3.1.12
# via euporie
aiohappyeyeballs==2.4.0
# via aiohttp
aiohttp==3.10.5
# via fsspec
aiosignal==1.3.1
# via aiohttp
alembic==1.13.2
# via optuna
alive-progress==3.1.5
# via symbolic-nn-tests
anyio==4.4.0
# via httpx
# via jupyter-server
argon2-cffi==23.1.0
# via jupyter-server
argon2-cffi-bindings==21.2.0
# via argon2-cffi
arrow==1.3.0
# via isoduration
asttokens==2.4.1
# via snoop
# via stack-data
async-lru==2.0.4
# via jupyterlab
attrs==24.2.0
# via aiohttp
# via jsonschema
# via referencing
babel==2.16.0
# via jupyterlab-server
beautifulsoup4==4.12.3
# via gdown
# via nbconvert
bleach==6.1.0
# via kaggle
# via nbconvert
# via panel
blessed==1.20.0
# via curtsies
bokeh==3.4.3
# via holoviews
# via hvplot
# via panel
bpython==0.24
# via symbolic-nn-tests
certifi==2024.8.30
# via httpcore
# via httpx
# via kaggle
# via requests
# via sentry-sdk
cffi==1.17.1
# via argon2-cffi-bindings
charset-normalizer==3.3.2
# via requests
cheap-repr==0.5.2
# via snoop
click==8.1.7
# via typer
# via wandb
colorcet==3.1.0
# via holoviews
# via hvplot
colorlog==6.8.2
# via optuna
comm==0.2.2
# via ipykernel
# via ipywidgets
contourpy==1.3.0
# via bokeh
# via matplotlib
curtsies==0.4.2
# via bpython
cwcwidth==0.1.9
# via bpython
# via curtsies
cycler==0.12.1
# via matplotlib
dataclasses-json==0.6.7
# via periodic-table-dataclasses
debugpy==1.8.5
# via ipykernel
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via nbconvert
docker-pycreds==0.4.0
# via wandb
euporie==2.8.2
# via symbolic-nn-tests
executing==2.1.0
# via snoop
# via stack-data
fastjsonschema==2.20.0
# via euporie
# via nbformat
filelock==3.16.0
# via gdown
# via torch
# via triton
flatlatex==0.15
# via euporie
fonttools==4.53.1
# via matplotlib
fqdn==1.5.1
# via jsonschema
frozenlist==1.4.1
# via aiohttp
# via aiosignal
fsspec==2024.9.0
# via euporie
# via lightning
# via pytorch-lightning
# via torch
# via universal-pathlib
gdown==5.2.0
# via symbolic-nn-tests
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via wandb
grapheme==0.6.0
# via alive-progress
greenlet==3.1.0
# via bpython
# via sqlalchemy
grpcio==1.66.1
# via tensorboard
h11==0.14.0
# via httpcore
holoviews==1.19.1
# via hvplot
httpcore==1.0.5
# via httpx
httpx==0.27.2
# via jupyterlab
hvplot==0.10.0
# via symbolic-nn-tests
idna==3.8
# via anyio
# via httpx
# via jsonschema
# via requests
# via yarl
imagesize==1.4.1
# via euporie
ipykernel==6.29.5
# via jupyter
# via jupyter-console
# via jupyterlab
# via symbolic-nn-tests
ipython==8.27.0
# via ipykernel
# via ipywidgets
# via jupyter-console
# via symbolic-nn-tests
ipywidgets==8.1.5
# via jupyter
isoduration==20.11.0
# via jsonschema
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via bokeh
# via jupyter-server
# via jupyterlab
# via jupyterlab-server
# via nbconvert
# via torch
joblib==1.4.2
# via scikit-learn
# via scikit-optimize
json5==0.9.25
# via jupyterlab-server
jsonpointer==3.0.0
# via jsonschema
jsonschema==4.23.0
# via jupyter-events
# via jupyterlab-server
# via nbformat
jsonschema-specifications==2023.12.1
# via jsonschema
jupyter==1.1.1
# via symbolic-nn-tests
jupyter-client==8.6.2
# via euporie
# via ipykernel
# via jupyter-console
# via jupyter-server
# via nbclient
jupyter-console==6.6.3
# via jupyter
jupyter-core==5.7.2
# via ipykernel
# via jupyter-client
# via jupyter-console
# via jupyter-server
# via jupyterlab
# via nbclient
# via nbconvert
# via nbformat
jupyter-events==0.10.0
# via jupyter-server
jupyter-lsp==2.2.5
# via jupyterlab
jupyter-server==2.14.2
# via jupyter-lsp
# via jupyterlab
# via jupyterlab-server
# via notebook
# via notebook-shim
jupyter-server-terminals==0.5.3
# via jupyter-server
jupyterlab==4.2.5
# via jupyter
# via notebook
jupyterlab-pygments==0.3.0
# via nbconvert
jupyterlab-server==2.27.3
# via jupyterlab
# via notebook
jupyterlab-widgets==3.0.13
# via ipywidgets
jupytext==1.16.4
# via euporie
kaggle==1.6.17
# via symbolic-nn-tests
kiwisolver==1.4.7
# via matplotlib
lightning==2.4.0
# via symbolic-nn-tests
lightning-utilities==0.11.7
# via lightning
# via pytorch-lightning
# via torchmetrics
linkify-it-py==1.0.3
# via euporie
# via panel
loguru==0.7.2
# via symbolic-nn-tests
mako==1.3.5
# via alembic
markdown==3.7
# via panel
# via tensorboard
markdown-it-py==2.1.0
# via euporie
# via jupytext
# via mdit-py-plugins
# via panel
# via rich
markupsafe==2.1.5
# via jinja2
# via mako
# via nbconvert
# via werkzeug
marshmallow==3.22.0
# via dataclasses-json
matplotlib==3.9.2
# via matplotlib-backend-kitty
matplotlib-backend-kitty==2.1.2
# via symbolic-nn-tests
matplotlib-inline==0.1.7
# via ipykernel
# via ipython
mdit-py-plugins==0.3.5
# via euporie
# via jupytext
# via panel
mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via aiohttp
# via yarl
mypy-extensions==1.0.0
# via typing-inspect
nbclient==0.10.0
# via nbconvert
nbconvert==7.16.4
# via jupyter
# via jupyter-server
nbformat==5.10.4
# via euporie
# via jupyter-server
# via jupytext
# via nbclient
# via nbconvert
nest-asyncio==1.6.0
# via ipykernel
networkx==3.3
# via torch
notebook==7.2.2
# via jupyter
notebook-shim==0.2.4
# via jupyterlab
# via notebook
numpy==2.1.1
# via bokeh
# via contourpy
# via holoviews
# via hvplot
# via matplotlib
# via optuna
# via pandas
# via pyarrow
# via scikit-learn
# via scikit-optimize
# via scipy
# via tensorboard
# via torchmetrics
# via torchvision
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.6.68
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
optuna==4.0.0
# via symbolic-nn-tests
overrides==7.7.0
# via jupyter-server
packaging==24.1
# via bokeh
# via holoviews
# via hvplot
# via ipykernel
# via jupyter-server
# via jupyterlab
# via jupyterlab-server
# via jupytext
# via lightning
# via lightning-utilities
# via marshmallow
# via matplotlib
# via nbconvert
# via optuna
# via plotly
# via pytorch-lightning
# via scikit-optimize
# via tensorboard
# via torchmetrics
pandas==2.2.2
# via bokeh
# via holoviews
# via hvplot
# via panel
pandocfilters==1.5.1
# via nbconvert
panel==1.4.5
# via holoviews
# via hvplot
param==2.1.1
# via holoviews
# via hvplot
# via panel
# via pyviz-comms
parso==0.8.4
# via jedi
periodic-table-dataclasses==1.0
# via symbolic-nn-tests
pexpect==4.9.0
# via ipython
pillow==10.4.0
# via bokeh
# via euporie
# via matplotlib
# via timg
# via torchvision
platformdirs==3.11.0
# via euporie
# via jupyter-core
# via wandb
plotly==5.24.0
# via symbolic-nn-tests
polars==1.6.0
# via symbolic-nn-tests
prometheus-client==0.20.0
# via jupyter-server
prompt-toolkit==3.0.47
# via euporie
# via ipython
# via jupyter-console
protobuf==5.28.0
# via tensorboard
# via wandb
psutil==6.0.0
# via ipykernel
# via wandb
ptyprocess==0.7.0
# via pexpect
# via terminado
pure-eval==0.2.3
# via stack-data
pyaml==24.7.0
# via scikit-optimize
pyarrow==17.0.0
# via symbolic-nn-tests
pycparser==2.22
# via cffi
pygments==2.18.0
# via bpython
# via euporie
# via ipython
# via jupyter-console
# via nbconvert
# via rich
# via snoop
pyparsing==3.1.4
# via matplotlib
pyperclip==1.9.0
# via euporie
pysocks==1.7.1
# via requests
python-dateutil==2.9.0.post0
# via arrow
# via jupyter-client
# via kaggle
# via matplotlib
# via pandas
python-json-logger==2.0.7
# via jupyter-events
python-slugify==8.0.4
# via kaggle
pytorch-lightning==2.4.0
# via lightning
pytz==2024.2
# via pandas
pyviz-comms==3.0.3
# via holoviews
# via panel
pyxdg==0.28
# via bpython
pyyaml==6.0.2
# via bokeh
# via jupyter-events
# via jupytext
# via lightning
# via optuna
# via pyaml
# via pytorch-lightning
# via wandb
pyzmq==26.2.0
# via ipykernel
# via jupyter-client
# via jupyter-console
# via jupyter-server
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
# via jupyter-events
regex==2024.7.24
# via flatlatex
requests==2.32.3
# via bpython
# via gdown
# via jupyterlab-server
# via kaggle
# via panel
# via wandb
rfc3339-validator==0.1.4
# via jsonschema
# via jupyter-events
rfc3986-validator==0.1.1
# via jsonschema
# via jupyter-events
rich==13.3.1
# via typer
rpds-py==0.20.0
# via jsonschema
# via referencing
safetensors==0.4.5
# via symbolic-nn-tests
scikit-learn==1.5.1
# via scikit-optimize
scikit-optimize==0.10.2
# via symbolic-nn-tests
scipy==1.14.1
# via scikit-learn
# via scikit-optimize
send2trash==1.8.3
# via jupyter-server
sentry-sdk==2.14.0
# via wandb
setproctitle==1.3.3
# via wandb
setuptools==74.1.2
# via jupyterlab
# via lightning-utilities
# via symbolic-nn-tests
# via tensorboard
# via torch
# via wandb
shellingham==1.5.4
# via typer
six==1.16.0
# via asttokens
# via bleach
# via blessed
# via docker-pycreds
# via kaggle
# via python-dateutil
# via rfc3339-validator
# via snoop
# via tensorboard
sixelcrop==0.1.7
# via euporie
smmap==5.0.1
# via gitdb
sniffio==1.3.1
# via anyio
# via httpx
snoop==0.4.3
# via symbolic-nn-tests
soupsieve==2.6
# via beautifulsoup4
sqlalchemy==2.0.34
# via alembic
# via optuna
stack-data==0.6.3
# via ipython
sympy==1.13.2
# via torch
tenacity==9.0.0
# via plotly
tensorboard==2.17.1
# via symbolic-nn-tests
tensorboard-data-server==0.7.2
# via tensorboard
terminado==0.18.1
# via jupyter-server
# via jupyter-server-terminals
text-unidecode==1.3
# via python-slugify
threadpoolctl==3.5.0
# via scikit-learn
timg==1.1.6
# via euporie
tinycss2==1.3.0
# via nbconvert
torch==2.4.1
# via lightning
# via pytorch-lightning
# via symbolic-nn-tests
# via torchmetrics
# via torchvision
torchmetrics==1.4.1
# via lightning
# via pytorch-lightning
torchvision==0.19.1
# via symbolic-nn-tests
tornado==6.4.1
# via bokeh
# via ipykernel
# via jupyter-client
# via jupyter-server
# via jupyterlab
# via notebook
# via terminado
tqdm==4.66.5
# via gdown
# via kaggle
# via lightning
# via optuna
# via panel
# via pytorch-lightning
traitlets==5.14.3
# via comm
# via ipykernel
# via ipython
# via ipywidgets
# via jupyter-client
# via jupyter-console
# via jupyter-core
# via jupyter-events
# via jupyter-server
# via jupyterlab
# via matplotlib-inline
# via nbclient
# via nbconvert
# via nbformat
triton==3.0.0
# via torch
typer==0.12.5
# via symbolic-nn-tests
types-python-dateutil==2.9.0.20240906
# via arrow
typing-extensions==4.12.2
# via alembic
# via euporie
# via lightning
# via lightning-utilities
# via panel
# via pytorch-lightning
# via sqlalchemy
# via torch
# via typer
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1
# via pandas
uc-micro-py==1.0.3
# via linkify-it-py
universal-pathlib==0.2.5
# via euporie
uri-template==1.3.0
# via jsonschema
urllib3==2.2.2
# via kaggle
# via requests
# via sentry-sdk
wandb==0.17.9
# via symbolic-nn-tests
wcwidth==0.2.13
# via blessed
# via prompt-toolkit
webcolors==24.8.0
# via jsonschema
webencodings==0.5.1
# via bleach
# via tinycss2
websocket-client==1.8.0
# via jupyter-server
werkzeug==3.0.4
# via tensorboard
widgetsnbextension==4.0.13
# via ipywidgets
xyzservices==2024.9.0
# via bokeh
# via panel
yarl==1.11.1
# via aiohttp

View File

@@ -0,0 +1,11 @@
from . import __main__
import typer
import ssl

# WARNING: this globally disables TLS certificate verification for every
# HTTPS request made by this process (all urllib-based downloads included).
# Presumably a workaround for dataset downloads behind a broken certificate
# chain -- TODO confirm and scope this more narrowly if possible.
ssl._create_default_https_context = ssl._create_unverified_context


def main():
    # Console-script entry point: dispatch __main__.main through typer's CLI.
    typer.run(__main__.main)

View File

@@ -4,7 +4,7 @@ from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).parent.parent.parent
DATASET_DIR = PROJECT_ROOT / "datasets/" DATASET_DIR = PROJECT_ROOT / "datasets/"

View File

@@ -0,0 +1,93 @@
LEARNING_RATE = 10e-5


def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
    """Train and evaluate one comparison-model configuration.

    Builds the requested loggers (TensorBoard and/or Weights & Biases),
    runs the model with the supplied train/val/test loss functions, then
    closes the wandb run if one was opened.
    """
    from .model import main as test_model

    loggers = []
    if tensorboard:
        from lightning.pytorch.loggers import TensorBoardLogger

        loggers.append(
            TensorBoardLogger(
                save_dir=".",
                name="logs/comparison",
                version=version,
            )
        )
    if wandb:
        import wandb as _wandb
        from lightning.pytorch.loggers import WandbLogger

        loggers.append(
            WandbLogger(
                project="Symbolic_NN_Tests",
                name=version,
                dir="wandb",
            )
        )

    test_model(
        logger=loggers,
        train_loss=train_loss,
        val_loss=val_loss,
        test_loss=test_loss,
        lr=LEARNING_RATE,
    )

    if wandb:
        _wandb.finish()
def run(tensorboard: bool = True, wandb: bool = True):
    """Run the full loss-function comparison suite.

    Trains one model per loss variant (plain cross-entropy, each semantic
    cross-entropy, and the random "garbage" baseline); validation and test
    always use plain cross-entropy so results stay comparable.

    The original body repeated the same six-line ``test(...)`` call six
    times; the configurations are now data-driven.
    """
    from . import semantic_loss
    from .model import oh_vs_cat_cross_entropy

    # (training loss, logger/run version name) for every experiment.
    experiments = [
        (oh_vs_cat_cross_entropy, "cross_entropy"),
        (semantic_loss.similarity_cross_entropy, "similarity_cross_entropy"),
        (semantic_loss.hasline_cross_entropy, "hasline_cross_entropy"),
        (semantic_loss.hasloop_cross_entropy, "hasloop_cross_entropy"),
        (semantic_loss.multisemantic_cross_entropy, "multisemantic_cross_entropy"),
        (semantic_loss.garbage_cross_entropy, "garbage_cross_entropy"),
    ]
    for train_loss, version in experiments:
        test(
            train_loss=train_loss,
            val_loss=oh_vs_cat_cross_entropy,
            test_loss=oh_vs_cat_cross_entropy,
            version=version,
            tensorboard=tensorboard,
            wandb=wandb,
        )

View File

@@ -0,0 +1,72 @@
from functools import lru_cache
import torch
from torch import nn
# Minimal MNIST-style classifier: flatten the image to a 784-vector, one
# linear layer to 10 class scores, softmax over classes.
# NOTE(review): the losses in this file feed these outputs to
# cross_entropy, which applies log_softmax internally -- so probabilities
# are effectively softmaxed twice. Appears consistent across the file,
# but worth confirming it is intentional.
model = nn.Sequential(
    nn.Flatten(1, -1),
    nn.Linear(784, 10),
    nn.Softmax(dim=1),
)
def collate(batch):
    """Collate a list of (sample, label) pairs into batched tensors.

    Each sample is itself a tuple whose first element is the image tensor
    (extra elements, e.g. QMNIST metadata, are dropped); labels are ints.
    Returns (stacked images, 1-D label tensor).
    """
    x, y = zip(*batch)
    # Keep only the image tensor from each sample tuple.
    x = torch.stack([i[0] for i in x])
    # Build the label tensor directly from the ints. The original wrapped
    # each label in torch.tensor() and then called torch.tensor() on that
    # list of tensors, which is deprecated/erroring in modern PyTorch.
    y = torch.tensor(y)
    return x, y
# This is just a quick, lazy way to ensure all models are trained on the same dataset
@lru_cache(maxsize=1)
def get_singleton_dataset():
    """Build the QMNIST train/val/test dataloaders, memoised so every
    model in a process sees the exact same dataset splits."""
    from symbolic_nn_tests.dataloader import create_dataset
    from torchvision.datasets import QMNIST

    return create_dataset(
        dataset=QMNIST,
        collate_fn=collate,
        batch_size=128,
        shuffle_train=True,
        num_workers=11,
    )
def oh_vs_cat_cross_entropy(y_bin, y_cat):
    """Cross-entropy between predicted scores and integer class labels,
    expanding the labels to one-hot float targets (10 classes) first."""
    target = nn.functional.one_hot(y_cat, num_classes=10).float()
    return nn.functional.cross_entropy(y_bin, target)
def main(
    train_loss=oh_vs_cat_cross_entropy,
    val_loss=oh_vs_cat_cross_entropy,
    test_loss=oh_vs_cat_cross_entropy,
    logger=None,
    **kwargs,
):
    """Train and evaluate the softmax classifier on the shared QMNIST
    dataset for 20 epochs.

    Extra keyword arguments are forwarded to the wrapper's
    configure_optimizers (e.g. lr). Falls back to a TensorBoard logger
    under logs/ffnn when none is supplied.
    """
    import lightning as L
    from symbolic_nn_tests.train import TrainingWrapper

    if logger is None:
        from lightning.pytorch.loggers import TensorBoardLogger

        logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")

    train, val, test = get_singleton_dataset()
    wrapped = TrainingWrapper(
        model,
        train_loss=train_loss,
        val_loss=val_loss,
        test_loss=test_loss,
    )
    wrapped.configure_optimizers(**kwargs)

    trainer = L.Trainer(max_epochs=20, logger=logger)
    trainer.fit(model=wrapped, train_dataloaders=train, val_dataloaders=val)
    trainer.test(dataloaders=test)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,88 @@
import torch
def create_semantic_cross_entropy(semantic_matrix):
    """Build a loss function that scales cross-entropy by a semantic
    penalty looked up from ``semantic_matrix``.

    For each sample, the row of ``semantic_matrix`` belonging to the true
    class weights the absolute prediction error; the summed weighted error
    multiplies the plain cross-entropy loss. Returns a callable taking
    (one-hot-style predictions, integer class labels).
    """

    def semantic_cross_entropy(input, target):
        ce_loss = torch.nn.functional.cross_entropy(input, target)
        # One penalty row per sample, selected by its true class.
        penalty_tensor = semantic_matrix[target.argmax(dim=1)]
        semantic_penalty = ((target - input).abs() * penalty_tensor).sum()
        return ce_loss * semantic_penalty

    def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
        one_hot = torch.nn.functional.one_hot(target_cat, num_classes=10).float()
        return semantic_cross_entropy(input_oh, one_hot)

    return oh_vs_cat_semantic_cross_entropy
# NOTE: This similarity matrix defines loss scaling factors for misclassification
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
# penalised less harshly than visually distinct numbers as this mistake is "less
# mistaken" given our understanding of the visual characteristics of numerals.
# By using this scaling matrix we can inject human knowledge into the model via
# the loss function, making this an example of a "semantic loss function".
# Diagonal entries (true class) are 2.0; the off-diagonal 1.5 entries mark the
# visually confusable pairs 1/7, 3/8 and 5/6.
# NOTE(review): the 1.5 entries for similar pairs are *larger* than the 1.0
# default, which appears to weight similar-digit errors more, not less, under
# the penalty formula in create_semantic_cross_entropy -- confirm intent.
# NOTE(review): .to("cuda") at import time makes this module fail to import on
# CPU-only machines; consider deferring device placement. TODO confirm.
SIMILARITY_MATRIX = torch.tensor(
    [
        [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
        [1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
        [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
        [1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
    ]
).to("cuda")
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum()  # Normalized to sum of 1
similarity_cross_entropy = create_semantic_cross_entropy(SIMILARITY_MATRIX)
# NOTE: The following matrix encodes a simpler semantic penalty for correctly/incorrectly
# identifying shapes with straight lines in their representation. This can be a bit fuzzy
# in cases like "9" though.
HASLINE_MATRIX = torch.tensor(
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
    [False, True, False, False, True, True, False, True, False, True]
).to("cuda")
# Outer XOR expands the per-digit flags into a 10x10 matrix: True wherever two
# digits disagree on "has a straight line".
HASLINE_MATRIX = torch.stack([i ^ HASLINE_MATRIX for i in HASLINE_MATRIX]).type(
    torch.float64
)
# Shift {0,1} -> {1,2} so agreeing digit pairs still carry weight 1 rather
# than zeroing out their contribution entirely.
HASLINE_MATRIX += 1
HASLINE_MATRIX /= HASLINE_MATRIX.sum()  # Normalize to sum of 1
hasline_cross_entropy = create_semantic_cross_entropy(HASLINE_MATRIX)
# NOTE: Similarly, we can do the same for closed circular loops in a numeric character
HASLOOP_MATRIX = torch.tensor(
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
    [True, False, False, False, False, False, True, False, True, True]
).to("cuda")
# Same construction as HASLINE_MATRIX: outer XOR marks pairs that disagree on
# loop-ness, then +1 keeps agreeing pairs at weight 1 before normalisation.
# NOTE(review): these matrices end up float64 while SIMILARITY_MATRIX is
# float32 -- the later elementwise product promotes to float64; confirm that
# is intended.
HASLOOP_MATRIX = torch.stack([i ^ HASLOOP_MATRIX for i in HASLOOP_MATRIX]).type(
    torch.float64
)
HASLOOP_MATRIX += 1
HASLOOP_MATRIX /= HASLOOP_MATRIX.sum()  # Normalize to sum of 1
hasloop_cross_entropy = create_semantic_cross_entropy(HASLOOP_MATRIX)
# NOTE: We can also combine all of these semantic matrices
# (elementwise product of the three penalty matrices, renormalised to sum 1).
MULTISEMANTIC_MATRIX = SIMILARITY_MATRIX * HASLINE_MATRIX * HASLOOP_MATRIX
MULTISEMANTIC_MATRIX /= MULTISEMANTIC_MATRIX.sum()
multisemantic_cross_entropy = create_semantic_cross_entropy(MULTISEMANTIC_MATRIX)
# NOTE: As a final test, let's make something similar to these but where there's
# no knowledge, just random data. This will create a benchmark for the effects
# of this process without the "knowledge" component.
# NOTE(review): torch.rand makes this matrix different on every import; seed it
# if the garbage baseline should be reproducible across runs. TODO confirm.
GARBAGE_MATRIX = torch.rand(10, 10).to("cuda")
GARBAGE_MATRIX /= GARBAGE_MATRIX.sum()
garbage_cross_entropy = create_semantic_cross_entropy(GARBAGE_MATRIX)