mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Moved to rye instead of poetry, to avoid linking problems
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -162,6 +162,7 @@ cython_debug/
|
|||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
.jj
|
.jj
|
||||||
|
.direnv
|
||||||
scratchpad.ipynb
|
scratchpad.ipynb
|
||||||
datasets/
|
datasets/
|
||||||
explore/
|
explore/
|
||||||
|
|||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.12.4
|
||||||
60
flake.lock
generated
Normal file
60
flake.lock
generated
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1726053618,
|
||||||
|
"narHash": "sha256-Xu5EVNdrbJ4XCpVU4moytVTuqoo9LsmQgR8g2BQd9Qc=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "0e6a2434a572fd583cac02e142ab0689895e395a",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": "nixpkgs",
|
||||||
|
"utils": "utils"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1681028828,
|
||||||
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"utils": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1710146030,
|
||||||
|
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
||||||
49
flake.nix
Normal file
49
flake.nix
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
inputs = {
|
||||||
|
utils.url = "github:numtide/flake-utils";
|
||||||
|
};
|
||||||
|
|
||||||
|
inputs.nixpkgs.url = "github:NixOS/nixpkgs";
|
||||||
|
|
||||||
|
outputs = { self, nixpkgs, utils }: utils.lib.eachDefaultSystem (system:
|
||||||
|
let
|
||||||
|
pkgs = import nixpkgs {
|
||||||
|
system = "x86_64-linux";
|
||||||
|
config = {
|
||||||
|
allowUnfree = true;
|
||||||
|
cudaSupport = true;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
in
|
||||||
|
{
|
||||||
|
devShell = pkgs.mkShell
|
||||||
|
{
|
||||||
|
NIX_LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
|
||||||
|
pkgs.stdenv.cc.cc
|
||||||
|
pkgs.zlib
|
||||||
|
pkgs.libGL
|
||||||
|
pkgs.libGLU
|
||||||
|
pkgs.wlroots
|
||||||
|
pkgs.ncurses5
|
||||||
|
pkgs.linuxKernel.packages.linux_latest_libre.nvidia_x11
|
||||||
|
];
|
||||||
|
NIX_LD = pkgs.lib.fileContents "${pkgs.stdenv.cc}/nix-support/dynamic-linker";
|
||||||
|
|
||||||
|
buildInputs = with pkgs; [
|
||||||
|
ruff
|
||||||
|
ruff-lsp
|
||||||
|
rye
|
||||||
|
uv
|
||||||
|
];
|
||||||
|
|
||||||
|
shellHook = ''
|
||||||
|
export CUDA_PATH=${pkgs.cudatoolkit}
|
||||||
|
export LD_LIBRARY_PATH=$(nix eval --raw nixpkgs#addOpenGLRunpath.driverLink)/lib
|
||||||
|
export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib"
|
||||||
|
export EXTRA_CCFLAGS="-I/usr/include"
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
5703
poetry.lock
generated
5703
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,41 +1,55 @@
|
|||||||
[tool.poetry]
|
[project]
|
||||||
name = "symbolic-nn-tests"
|
name = "symbolic-nn-tests"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = ""
|
description = "Add your description here"
|
||||||
authors = ["Cian Hughes <chughes000@gmail.com>"]
|
authors = [
|
||||||
license = "MIT"
|
{ name = "Cian Hughes", email = "chughes000@gmail.com" }
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"torch>=2.4.1",
|
||||||
|
"lightning>=2.4.0",
|
||||||
|
"torchvision>=0.19.1",
|
||||||
|
"wandb>=0.17.9",
|
||||||
|
"optuna>=4.0.0",
|
||||||
|
"setuptools>=74.1.2",
|
||||||
|
"gdown>=5.2.0",
|
||||||
|
"bpython>=0.24",
|
||||||
|
"ipython>=8.27.0",
|
||||||
|
"matplotlib-backend-kitty>=2.1.2",
|
||||||
|
"euporie>=2.8.2",
|
||||||
|
"ipykernel>=6.29.5",
|
||||||
|
"tensorboard>=2.17.1",
|
||||||
|
"typer>=0.12.5",
|
||||||
|
"kaggle>=1.6.17",
|
||||||
|
"periodic-table-dataclasses>=1.0",
|
||||||
|
"polars>=1.6.0",
|
||||||
|
"jupyter>=1.1.1",
|
||||||
|
"safetensors>=0.4.5",
|
||||||
|
"alive-progress>=3.1.5",
|
||||||
|
"hvplot>=0.10.0",
|
||||||
|
"pyarrow>=17.0.0",
|
||||||
|
"loguru>=0.7.2",
|
||||||
|
"plotly>=5.24.0",
|
||||||
|
"snoop>=0.4.3",
|
||||||
|
"scikit-optimize>=0.10.2",
|
||||||
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
requires-python = ">= 3.8"
|
||||||
|
classifiers = ["Private :: Do Not Upload"]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[project.scripts]
|
||||||
python = "^3.11"
|
"symbolic-nn-tests" = "symbolic_nn_tests:main"
|
||||||
torch = "^2.3.0"
|
|
||||||
lightning = "^2.2.4"
|
|
||||||
torchvision = "^0.18.0"
|
|
||||||
wandb = "^0.17.0"
|
|
||||||
optuna = "^3.6.1"
|
|
||||||
setuptools = "^69.5.1"
|
|
||||||
gdown = "^5.2.0"
|
|
||||||
bpython = "^0.24"
|
|
||||||
ipython = "^8.24.0"
|
|
||||||
matplotlib-backend-kitty = "^2.1.2"
|
|
||||||
euporie = "^2.8.2"
|
|
||||||
ipykernel = "^6.29.4"
|
|
||||||
tensorboard = "^2.16.2"
|
|
||||||
typer = "^0.12.3"
|
|
||||||
kaggle = "^1.6.14"
|
|
||||||
periodic-table-dataclasses = "^1.0"
|
|
||||||
polars = "^0.20.28"
|
|
||||||
jupyter = "^1.0.0"
|
|
||||||
safetensors = "^0.4.3"
|
|
||||||
alive-progress = "^3.1.5"
|
|
||||||
hvplot = "^0.10.0"
|
|
||||||
pyarrow = "^16.1.0"
|
|
||||||
loguru = "^0.7.2"
|
|
||||||
plotly = "^5.22.0"
|
|
||||||
snoop = "^0.4.3"
|
|
||||||
scikit-optimize = "^0.10.2"
|
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["hatchling"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.rye]
|
||||||
|
managed = true
|
||||||
|
dev-dependencies = []
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/symbolic_nn_tests"]
|
||||||
|
|||||||
675
requirements-dev.lock
Normal file
675
requirements-dev.lock
Normal file
@@ -0,0 +1,675 @@
|
|||||||
|
# generated by rye
|
||||||
|
# use `rye lock` or `rye sync` to update this lockfile
|
||||||
|
#
|
||||||
|
# last locked with the following flags:
|
||||||
|
# pre: false
|
||||||
|
# features: []
|
||||||
|
# all-features: false
|
||||||
|
# with-sources: false
|
||||||
|
# generate-hashes: false
|
||||||
|
|
||||||
|
-e file:.
|
||||||
|
about-time==4.2.1
|
||||||
|
# via alive-progress
|
||||||
|
absl-py==2.1.0
|
||||||
|
# via tensorboard
|
||||||
|
aenum==3.1.12
|
||||||
|
# via euporie
|
||||||
|
aiohappyeyeballs==2.4.0
|
||||||
|
# via aiohttp
|
||||||
|
aiohttp==3.10.5
|
||||||
|
# via fsspec
|
||||||
|
aiosignal==1.3.1
|
||||||
|
# via aiohttp
|
||||||
|
alembic==1.13.2
|
||||||
|
# via optuna
|
||||||
|
alive-progress==3.1.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
anyio==4.4.0
|
||||||
|
# via httpx
|
||||||
|
# via jupyter-server
|
||||||
|
argon2-cffi==23.1.0
|
||||||
|
# via jupyter-server
|
||||||
|
argon2-cffi-bindings==21.2.0
|
||||||
|
# via argon2-cffi
|
||||||
|
arrow==1.3.0
|
||||||
|
# via isoduration
|
||||||
|
asttokens==2.4.1
|
||||||
|
# via snoop
|
||||||
|
# via stack-data
|
||||||
|
async-lru==2.0.4
|
||||||
|
# via jupyterlab
|
||||||
|
attrs==24.2.0
|
||||||
|
# via aiohttp
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
|
babel==2.16.0
|
||||||
|
# via jupyterlab-server
|
||||||
|
beautifulsoup4==4.12.3
|
||||||
|
# via gdown
|
||||||
|
# via nbconvert
|
||||||
|
bleach==6.1.0
|
||||||
|
# via kaggle
|
||||||
|
# via nbconvert
|
||||||
|
# via panel
|
||||||
|
blessed==1.20.0
|
||||||
|
# via curtsies
|
||||||
|
bokeh==3.4.3
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
bpython==0.24
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
certifi==2024.8.30
|
||||||
|
# via httpcore
|
||||||
|
# via httpx
|
||||||
|
# via kaggle
|
||||||
|
# via requests
|
||||||
|
# via sentry-sdk
|
||||||
|
cffi==1.17.1
|
||||||
|
# via argon2-cffi-bindings
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
# via requests
|
||||||
|
cheap-repr==0.5.2
|
||||||
|
# via snoop
|
||||||
|
click==8.1.7
|
||||||
|
# via typer
|
||||||
|
# via wandb
|
||||||
|
colorcet==3.1.0
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
colorlog==6.8.2
|
||||||
|
# via optuna
|
||||||
|
comm==0.2.2
|
||||||
|
# via ipykernel
|
||||||
|
# via ipywidgets
|
||||||
|
contourpy==1.3.0
|
||||||
|
# via bokeh
|
||||||
|
# via matplotlib
|
||||||
|
curtsies==0.4.2
|
||||||
|
# via bpython
|
||||||
|
cwcwidth==0.1.9
|
||||||
|
# via bpython
|
||||||
|
# via curtsies
|
||||||
|
cycler==0.12.1
|
||||||
|
# via matplotlib
|
||||||
|
dataclasses-json==0.6.7
|
||||||
|
# via periodic-table-dataclasses
|
||||||
|
debugpy==1.8.5
|
||||||
|
# via ipykernel
|
||||||
|
decorator==5.1.1
|
||||||
|
# via ipython
|
||||||
|
defusedxml==0.7.1
|
||||||
|
# via nbconvert
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
# via wandb
|
||||||
|
euporie==2.8.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
executing==2.1.0
|
||||||
|
# via snoop
|
||||||
|
# via stack-data
|
||||||
|
fastjsonschema==2.20.0
|
||||||
|
# via euporie
|
||||||
|
# via nbformat
|
||||||
|
filelock==3.16.0
|
||||||
|
# via gdown
|
||||||
|
# via torch
|
||||||
|
# via triton
|
||||||
|
flatlatex==0.15
|
||||||
|
# via euporie
|
||||||
|
fonttools==4.53.1
|
||||||
|
# via matplotlib
|
||||||
|
fqdn==1.5.1
|
||||||
|
# via jsonschema
|
||||||
|
frozenlist==1.4.1
|
||||||
|
# via aiohttp
|
||||||
|
# via aiosignal
|
||||||
|
fsspec==2024.9.0
|
||||||
|
# via euporie
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via torch
|
||||||
|
# via universal-pathlib
|
||||||
|
gdown==5.2.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
gitdb==4.0.11
|
||||||
|
# via gitpython
|
||||||
|
gitpython==3.1.43
|
||||||
|
# via wandb
|
||||||
|
grapheme==0.6.0
|
||||||
|
# via alive-progress
|
||||||
|
greenlet==3.1.0
|
||||||
|
# via bpython
|
||||||
|
# via sqlalchemy
|
||||||
|
grpcio==1.66.1
|
||||||
|
# via tensorboard
|
||||||
|
h11==0.14.0
|
||||||
|
# via httpcore
|
||||||
|
holoviews==1.19.1
|
||||||
|
# via hvplot
|
||||||
|
httpcore==1.0.5
|
||||||
|
# via httpx
|
||||||
|
httpx==0.27.2
|
||||||
|
# via jupyterlab
|
||||||
|
hvplot==0.10.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
idna==3.8
|
||||||
|
# via anyio
|
||||||
|
# via httpx
|
||||||
|
# via jsonschema
|
||||||
|
# via requests
|
||||||
|
# via yarl
|
||||||
|
imagesize==1.4.1
|
||||||
|
# via euporie
|
||||||
|
ipykernel==6.29.5
|
||||||
|
# via jupyter
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyterlab
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
ipython==8.27.0
|
||||||
|
# via ipykernel
|
||||||
|
# via ipywidgets
|
||||||
|
# via jupyter-console
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
ipywidgets==8.1.5
|
||||||
|
# via jupyter
|
||||||
|
isoduration==20.11.0
|
||||||
|
# via jsonschema
|
||||||
|
jedi==0.19.1
|
||||||
|
# via ipython
|
||||||
|
jinja2==3.1.4
|
||||||
|
# via bokeh
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via nbconvert
|
||||||
|
# via torch
|
||||||
|
joblib==1.4.2
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
json5==0.9.25
|
||||||
|
# via jupyterlab-server
|
||||||
|
jsonpointer==3.0.0
|
||||||
|
# via jsonschema
|
||||||
|
jsonschema==4.23.0
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via nbformat
|
||||||
|
jsonschema-specifications==2023.12.1
|
||||||
|
# via jsonschema
|
||||||
|
jupyter==1.1.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
jupyter-client==8.6.2
|
||||||
|
# via euporie
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
# via nbclient
|
||||||
|
jupyter-console==6.6.3
|
||||||
|
# via jupyter
|
||||||
|
jupyter-core==5.7.2
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
# via nbformat
|
||||||
|
jupyter-events==0.10.0
|
||||||
|
# via jupyter-server
|
||||||
|
jupyter-lsp==2.2.5
|
||||||
|
# via jupyterlab
|
||||||
|
jupyter-server==2.14.2
|
||||||
|
# via jupyter-lsp
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via notebook
|
||||||
|
# via notebook-shim
|
||||||
|
jupyter-server-terminals==0.5.3
|
||||||
|
# via jupyter-server
|
||||||
|
jupyterlab==4.2.5
|
||||||
|
# via jupyter
|
||||||
|
# via notebook
|
||||||
|
jupyterlab-pygments==0.3.0
|
||||||
|
# via nbconvert
|
||||||
|
jupyterlab-server==2.27.3
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
jupyterlab-widgets==3.0.13
|
||||||
|
# via ipywidgets
|
||||||
|
jupytext==1.16.4
|
||||||
|
# via euporie
|
||||||
|
kaggle==1.6.17
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
kiwisolver==1.4.7
|
||||||
|
# via matplotlib
|
||||||
|
lightning==2.4.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
lightning-utilities==0.11.7
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via torchmetrics
|
||||||
|
linkify-it-py==1.0.3
|
||||||
|
# via euporie
|
||||||
|
# via panel
|
||||||
|
loguru==0.7.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
mako==1.3.5
|
||||||
|
# via alembic
|
||||||
|
markdown==3.7
|
||||||
|
# via panel
|
||||||
|
# via tensorboard
|
||||||
|
markdown-it-py==2.1.0
|
||||||
|
# via euporie
|
||||||
|
# via jupytext
|
||||||
|
# via mdit-py-plugins
|
||||||
|
# via panel
|
||||||
|
# via rich
|
||||||
|
markupsafe==2.1.5
|
||||||
|
# via jinja2
|
||||||
|
# via mako
|
||||||
|
# via nbconvert
|
||||||
|
# via werkzeug
|
||||||
|
marshmallow==3.22.0
|
||||||
|
# via dataclasses-json
|
||||||
|
matplotlib==3.9.2
|
||||||
|
# via matplotlib-backend-kitty
|
||||||
|
matplotlib-backend-kitty==2.1.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
# via ipykernel
|
||||||
|
# via ipython
|
||||||
|
mdit-py-plugins==0.3.5
|
||||||
|
# via euporie
|
||||||
|
# via jupytext
|
||||||
|
# via panel
|
||||||
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
mistune==3.0.2
|
||||||
|
# via nbconvert
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
|
multidict==6.1.0
|
||||||
|
# via aiohttp
|
||||||
|
# via yarl
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
# via typing-inspect
|
||||||
|
nbclient==0.10.0
|
||||||
|
# via nbconvert
|
||||||
|
nbconvert==7.16.4
|
||||||
|
# via jupyter
|
||||||
|
# via jupyter-server
|
||||||
|
nbformat==5.10.4
|
||||||
|
# via euporie
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupytext
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
# via ipykernel
|
||||||
|
networkx==3.3
|
||||||
|
# via torch
|
||||||
|
notebook==7.2.2
|
||||||
|
# via jupyter
|
||||||
|
notebook-shim==0.2.4
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
numpy==2.1.1
|
||||||
|
# via bokeh
|
||||||
|
# via contourpy
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via matplotlib
|
||||||
|
# via optuna
|
||||||
|
# via pandas
|
||||||
|
# via pyarrow
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
# via scipy
|
||||||
|
# via tensorboard
|
||||||
|
# via torchmetrics
|
||||||
|
# via torchvision
|
||||||
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
# via nvidia-cudnn-cu12
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cudnn-cu12==9.1.0.70
|
||||||
|
# via torch
|
||||||
|
nvidia-cufft-cu12==11.0.2.54
|
||||||
|
# via torch
|
||||||
|
nvidia-curand-cu12==10.3.2.106
|
||||||
|
# via torch
|
||||||
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
# via torch
|
||||||
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via torch
|
||||||
|
nvidia-nccl-cu12==2.20.5
|
||||||
|
# via torch
|
||||||
|
nvidia-nvjitlink-cu12==12.6.68
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via nvidia-cusparse-cu12
|
||||||
|
nvidia-nvtx-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
optuna==4.0.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
overrides==7.7.0
|
||||||
|
# via jupyter-server
|
||||||
|
packaging==24.1
|
||||||
|
# via bokeh
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via jupytext
|
||||||
|
# via lightning
|
||||||
|
# via lightning-utilities
|
||||||
|
# via marshmallow
|
||||||
|
# via matplotlib
|
||||||
|
# via nbconvert
|
||||||
|
# via optuna
|
||||||
|
# via plotly
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via scikit-optimize
|
||||||
|
# via tensorboard
|
||||||
|
# via torchmetrics
|
||||||
|
pandas==2.2.2
|
||||||
|
# via bokeh
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
pandocfilters==1.5.1
|
||||||
|
# via nbconvert
|
||||||
|
panel==1.4.5
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
param==2.1.1
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
# via pyviz-comms
|
||||||
|
parso==0.8.4
|
||||||
|
# via jedi
|
||||||
|
periodic-table-dataclasses==1.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
pexpect==4.9.0
|
||||||
|
# via ipython
|
||||||
|
pillow==10.4.0
|
||||||
|
# via bokeh
|
||||||
|
# via euporie
|
||||||
|
# via matplotlib
|
||||||
|
# via timg
|
||||||
|
# via torchvision
|
||||||
|
platformdirs==3.11.0
|
||||||
|
# via euporie
|
||||||
|
# via jupyter-core
|
||||||
|
# via wandb
|
||||||
|
plotly==5.24.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
polars==1.6.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
prometheus-client==0.20.0
|
||||||
|
# via jupyter-server
|
||||||
|
prompt-toolkit==3.0.47
|
||||||
|
# via euporie
|
||||||
|
# via ipython
|
||||||
|
# via jupyter-console
|
||||||
|
protobuf==5.28.0
|
||||||
|
# via tensorboard
|
||||||
|
# via wandb
|
||||||
|
psutil==6.0.0
|
||||||
|
# via ipykernel
|
||||||
|
# via wandb
|
||||||
|
ptyprocess==0.7.0
|
||||||
|
# via pexpect
|
||||||
|
# via terminado
|
||||||
|
pure-eval==0.2.3
|
||||||
|
# via stack-data
|
||||||
|
pyaml==24.7.0
|
||||||
|
# via scikit-optimize
|
||||||
|
pyarrow==17.0.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
pycparser==2.22
|
||||||
|
# via cffi
|
||||||
|
pygments==2.18.0
|
||||||
|
# via bpython
|
||||||
|
# via euporie
|
||||||
|
# via ipython
|
||||||
|
# via jupyter-console
|
||||||
|
# via nbconvert
|
||||||
|
# via rich
|
||||||
|
# via snoop
|
||||||
|
pyparsing==3.1.4
|
||||||
|
# via matplotlib
|
||||||
|
pyperclip==1.9.0
|
||||||
|
# via euporie
|
||||||
|
pysocks==1.7.1
|
||||||
|
# via requests
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via arrow
|
||||||
|
# via jupyter-client
|
||||||
|
# via kaggle
|
||||||
|
# via matplotlib
|
||||||
|
# via pandas
|
||||||
|
python-json-logger==2.0.7
|
||||||
|
# via jupyter-events
|
||||||
|
python-slugify==8.0.4
|
||||||
|
# via kaggle
|
||||||
|
pytorch-lightning==2.4.0
|
||||||
|
# via lightning
|
||||||
|
pytz==2024.2
|
||||||
|
# via pandas
|
||||||
|
pyviz-comms==3.0.3
|
||||||
|
# via holoviews
|
||||||
|
# via panel
|
||||||
|
pyxdg==0.28
|
||||||
|
# via bpython
|
||||||
|
pyyaml==6.0.2
|
||||||
|
# via bokeh
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupytext
|
||||||
|
# via lightning
|
||||||
|
# via optuna
|
||||||
|
# via pyaml
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via wandb
|
||||||
|
pyzmq==26.2.0
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
referencing==0.35.1
|
||||||
|
# via jsonschema
|
||||||
|
# via jsonschema-specifications
|
||||||
|
# via jupyter-events
|
||||||
|
regex==2024.7.24
|
||||||
|
# via flatlatex
|
||||||
|
requests==2.32.3
|
||||||
|
# via bpython
|
||||||
|
# via gdown
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via kaggle
|
||||||
|
# via panel
|
||||||
|
# via wandb
|
||||||
|
rfc3339-validator==0.1.4
|
||||||
|
# via jsonschema
|
||||||
|
# via jupyter-events
|
||||||
|
rfc3986-validator==0.1.1
|
||||||
|
# via jsonschema
|
||||||
|
# via jupyter-events
|
||||||
|
rich==13.3.1
|
||||||
|
# via typer
|
||||||
|
rpds-py==0.20.0
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
|
safetensors==0.4.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
scikit-learn==1.5.1
|
||||||
|
# via scikit-optimize
|
||||||
|
scikit-optimize==0.10.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
scipy==1.14.1
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
send2trash==1.8.3
|
||||||
|
# via jupyter-server
|
||||||
|
sentry-sdk==2.14.0
|
||||||
|
# via wandb
|
||||||
|
setproctitle==1.3.3
|
||||||
|
# via wandb
|
||||||
|
setuptools==74.1.2
|
||||||
|
# via jupyterlab
|
||||||
|
# via lightning-utilities
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
# via tensorboard
|
||||||
|
# via torch
|
||||||
|
# via wandb
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
|
six==1.16.0
|
||||||
|
# via asttokens
|
||||||
|
# via bleach
|
||||||
|
# via blessed
|
||||||
|
# via docker-pycreds
|
||||||
|
# via kaggle
|
||||||
|
# via python-dateutil
|
||||||
|
# via rfc3339-validator
|
||||||
|
# via snoop
|
||||||
|
# via tensorboard
|
||||||
|
sixelcrop==0.1.7
|
||||||
|
# via euporie
|
||||||
|
smmap==5.0.1
|
||||||
|
# via gitdb
|
||||||
|
sniffio==1.3.1
|
||||||
|
# via anyio
|
||||||
|
# via httpx
|
||||||
|
snoop==0.4.3
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
soupsieve==2.6
|
||||||
|
# via beautifulsoup4
|
||||||
|
sqlalchemy==2.0.34
|
||||||
|
# via alembic
|
||||||
|
# via optuna
|
||||||
|
stack-data==0.6.3
|
||||||
|
# via ipython
|
||||||
|
sympy==1.13.2
|
||||||
|
# via torch
|
||||||
|
tenacity==9.0.0
|
||||||
|
# via plotly
|
||||||
|
tensorboard==2.17.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
tensorboard-data-server==0.7.2
|
||||||
|
# via tensorboard
|
||||||
|
terminado==0.18.1
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyter-server-terminals
|
||||||
|
text-unidecode==1.3
|
||||||
|
# via python-slugify
|
||||||
|
threadpoolctl==3.5.0
|
||||||
|
# via scikit-learn
|
||||||
|
timg==1.1.6
|
||||||
|
# via euporie
|
||||||
|
tinycss2==1.3.0
|
||||||
|
# via nbconvert
|
||||||
|
torch==2.4.1
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
# via torchmetrics
|
||||||
|
# via torchvision
|
||||||
|
torchmetrics==1.4.1
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
torchvision==0.19.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
tornado==6.4.1
|
||||||
|
# via bokeh
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
# via terminado
|
||||||
|
tqdm==4.66.5
|
||||||
|
# via gdown
|
||||||
|
# via kaggle
|
||||||
|
# via lightning
|
||||||
|
# via optuna
|
||||||
|
# via panel
|
||||||
|
# via pytorch-lightning
|
||||||
|
traitlets==5.14.3
|
||||||
|
# via comm
|
||||||
|
# via ipykernel
|
||||||
|
# via ipython
|
||||||
|
# via ipywidgets
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-core
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via matplotlib-inline
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
# via nbformat
|
||||||
|
triton==3.0.0
|
||||||
|
# via torch
|
||||||
|
typer==0.12.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
types-python-dateutil==2.9.0.20240906
|
||||||
|
# via arrow
|
||||||
|
typing-extensions==4.12.2
|
||||||
|
# via alembic
|
||||||
|
# via euporie
|
||||||
|
# via lightning
|
||||||
|
# via lightning-utilities
|
||||||
|
# via panel
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via sqlalchemy
|
||||||
|
# via torch
|
||||||
|
# via typer
|
||||||
|
# via typing-inspect
|
||||||
|
typing-inspect==0.9.0
|
||||||
|
# via dataclasses-json
|
||||||
|
tzdata==2024.1
|
||||||
|
# via pandas
|
||||||
|
uc-micro-py==1.0.3
|
||||||
|
# via linkify-it-py
|
||||||
|
universal-pathlib==0.2.5
|
||||||
|
# via euporie
|
||||||
|
uri-template==1.3.0
|
||||||
|
# via jsonschema
|
||||||
|
urllib3==2.2.2
|
||||||
|
# via kaggle
|
||||||
|
# via requests
|
||||||
|
# via sentry-sdk
|
||||||
|
wandb==0.17.9
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
wcwidth==0.2.13
|
||||||
|
# via blessed
|
||||||
|
# via prompt-toolkit
|
||||||
|
webcolors==24.8.0
|
||||||
|
# via jsonschema
|
||||||
|
webencodings==0.5.1
|
||||||
|
# via bleach
|
||||||
|
# via tinycss2
|
||||||
|
websocket-client==1.8.0
|
||||||
|
# via jupyter-server
|
||||||
|
werkzeug==3.0.4
|
||||||
|
# via tensorboard
|
||||||
|
widgetsnbextension==4.0.13
|
||||||
|
# via ipywidgets
|
||||||
|
xyzservices==2024.9.0
|
||||||
|
# via bokeh
|
||||||
|
# via panel
|
||||||
|
yarl==1.11.1
|
||||||
|
# via aiohttp
|
||||||
675
requirements.lock
Normal file
675
requirements.lock
Normal file
@@ -0,0 +1,675 @@
|
|||||||
|
# generated by rye
|
||||||
|
# use `rye lock` or `rye sync` to update this lockfile
|
||||||
|
#
|
||||||
|
# last locked with the following flags:
|
||||||
|
# pre: false
|
||||||
|
# features: []
|
||||||
|
# all-features: false
|
||||||
|
# with-sources: false
|
||||||
|
# generate-hashes: false
|
||||||
|
|
||||||
|
-e file:.
|
||||||
|
about-time==4.2.1
|
||||||
|
# via alive-progress
|
||||||
|
absl-py==2.1.0
|
||||||
|
# via tensorboard
|
||||||
|
aenum==3.1.12
|
||||||
|
# via euporie
|
||||||
|
aiohappyeyeballs==2.4.0
|
||||||
|
# via aiohttp
|
||||||
|
aiohttp==3.10.5
|
||||||
|
# via fsspec
|
||||||
|
aiosignal==1.3.1
|
||||||
|
# via aiohttp
|
||||||
|
alembic==1.13.2
|
||||||
|
# via optuna
|
||||||
|
alive-progress==3.1.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
anyio==4.4.0
|
||||||
|
# via httpx
|
||||||
|
# via jupyter-server
|
||||||
|
argon2-cffi==23.1.0
|
||||||
|
# via jupyter-server
|
||||||
|
argon2-cffi-bindings==21.2.0
|
||||||
|
# via argon2-cffi
|
||||||
|
arrow==1.3.0
|
||||||
|
# via isoduration
|
||||||
|
asttokens==2.4.1
|
||||||
|
# via snoop
|
||||||
|
# via stack-data
|
||||||
|
async-lru==2.0.4
|
||||||
|
# via jupyterlab
|
||||||
|
attrs==24.2.0
|
||||||
|
# via aiohttp
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
|
babel==2.16.0
|
||||||
|
# via jupyterlab-server
|
||||||
|
beautifulsoup4==4.12.3
|
||||||
|
# via gdown
|
||||||
|
# via nbconvert
|
||||||
|
bleach==6.1.0
|
||||||
|
# via kaggle
|
||||||
|
# via nbconvert
|
||||||
|
# via panel
|
||||||
|
blessed==1.20.0
|
||||||
|
# via curtsies
|
||||||
|
bokeh==3.4.3
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
bpython==0.24
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
certifi==2024.8.30
|
||||||
|
# via httpcore
|
||||||
|
# via httpx
|
||||||
|
# via kaggle
|
||||||
|
# via requests
|
||||||
|
# via sentry-sdk
|
||||||
|
cffi==1.17.1
|
||||||
|
# via argon2-cffi-bindings
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
# via requests
|
||||||
|
cheap-repr==0.5.2
|
||||||
|
# via snoop
|
||||||
|
click==8.1.7
|
||||||
|
# via typer
|
||||||
|
# via wandb
|
||||||
|
colorcet==3.1.0
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
colorlog==6.8.2
|
||||||
|
# via optuna
|
||||||
|
comm==0.2.2
|
||||||
|
# via ipykernel
|
||||||
|
# via ipywidgets
|
||||||
|
contourpy==1.3.0
|
||||||
|
# via bokeh
|
||||||
|
# via matplotlib
|
||||||
|
curtsies==0.4.2
|
||||||
|
# via bpython
|
||||||
|
cwcwidth==0.1.9
|
||||||
|
# via bpython
|
||||||
|
# via curtsies
|
||||||
|
cycler==0.12.1
|
||||||
|
# via matplotlib
|
||||||
|
dataclasses-json==0.6.7
|
||||||
|
# via periodic-table-dataclasses
|
||||||
|
debugpy==1.8.5
|
||||||
|
# via ipykernel
|
||||||
|
decorator==5.1.1
|
||||||
|
# via ipython
|
||||||
|
defusedxml==0.7.1
|
||||||
|
# via nbconvert
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
# via wandb
|
||||||
|
euporie==2.8.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
executing==2.1.0
|
||||||
|
# via snoop
|
||||||
|
# via stack-data
|
||||||
|
fastjsonschema==2.20.0
|
||||||
|
# via euporie
|
||||||
|
# via nbformat
|
||||||
|
filelock==3.16.0
|
||||||
|
# via gdown
|
||||||
|
# via torch
|
||||||
|
# via triton
|
||||||
|
flatlatex==0.15
|
||||||
|
# via euporie
|
||||||
|
fonttools==4.53.1
|
||||||
|
# via matplotlib
|
||||||
|
fqdn==1.5.1
|
||||||
|
# via jsonschema
|
||||||
|
frozenlist==1.4.1
|
||||||
|
# via aiohttp
|
||||||
|
# via aiosignal
|
||||||
|
fsspec==2024.9.0
|
||||||
|
# via euporie
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via torch
|
||||||
|
# via universal-pathlib
|
||||||
|
gdown==5.2.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
gitdb==4.0.11
|
||||||
|
# via gitpython
|
||||||
|
gitpython==3.1.43
|
||||||
|
# via wandb
|
||||||
|
grapheme==0.6.0
|
||||||
|
# via alive-progress
|
||||||
|
greenlet==3.1.0
|
||||||
|
# via bpython
|
||||||
|
# via sqlalchemy
|
||||||
|
grpcio==1.66.1
|
||||||
|
# via tensorboard
|
||||||
|
h11==0.14.0
|
||||||
|
# via httpcore
|
||||||
|
holoviews==1.19.1
|
||||||
|
# via hvplot
|
||||||
|
httpcore==1.0.5
|
||||||
|
# via httpx
|
||||||
|
httpx==0.27.2
|
||||||
|
# via jupyterlab
|
||||||
|
hvplot==0.10.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
idna==3.8
|
||||||
|
# via anyio
|
||||||
|
# via httpx
|
||||||
|
# via jsonschema
|
||||||
|
# via requests
|
||||||
|
# via yarl
|
||||||
|
imagesize==1.4.1
|
||||||
|
# via euporie
|
||||||
|
ipykernel==6.29.5
|
||||||
|
# via jupyter
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyterlab
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
ipython==8.27.0
|
||||||
|
# via ipykernel
|
||||||
|
# via ipywidgets
|
||||||
|
# via jupyter-console
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
ipywidgets==8.1.5
|
||||||
|
# via jupyter
|
||||||
|
isoduration==20.11.0
|
||||||
|
# via jsonschema
|
||||||
|
jedi==0.19.1
|
||||||
|
# via ipython
|
||||||
|
jinja2==3.1.4
|
||||||
|
# via bokeh
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via nbconvert
|
||||||
|
# via torch
|
||||||
|
joblib==1.4.2
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
json5==0.9.25
|
||||||
|
# via jupyterlab-server
|
||||||
|
jsonpointer==3.0.0
|
||||||
|
# via jsonschema
|
||||||
|
jsonschema==4.23.0
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via nbformat
|
||||||
|
jsonschema-specifications==2023.12.1
|
||||||
|
# via jsonschema
|
||||||
|
jupyter==1.1.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
jupyter-client==8.6.2
|
||||||
|
# via euporie
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
# via nbclient
|
||||||
|
jupyter-console==6.6.3
|
||||||
|
# via jupyter
|
||||||
|
jupyter-core==5.7.2
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
# via nbformat
|
||||||
|
jupyter-events==0.10.0
|
||||||
|
# via jupyter-server
|
||||||
|
jupyter-lsp==2.2.5
|
||||||
|
# via jupyterlab
|
||||||
|
jupyter-server==2.14.2
|
||||||
|
# via jupyter-lsp
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via notebook
|
||||||
|
# via notebook-shim
|
||||||
|
jupyter-server-terminals==0.5.3
|
||||||
|
# via jupyter-server
|
||||||
|
jupyterlab==4.2.5
|
||||||
|
# via jupyter
|
||||||
|
# via notebook
|
||||||
|
jupyterlab-pygments==0.3.0
|
||||||
|
# via nbconvert
|
||||||
|
jupyterlab-server==2.27.3
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
jupyterlab-widgets==3.0.13
|
||||||
|
# via ipywidgets
|
||||||
|
jupytext==1.16.4
|
||||||
|
# via euporie
|
||||||
|
kaggle==1.6.17
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
kiwisolver==1.4.7
|
||||||
|
# via matplotlib
|
||||||
|
lightning==2.4.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
lightning-utilities==0.11.7
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via torchmetrics
|
||||||
|
linkify-it-py==1.0.3
|
||||||
|
# via euporie
|
||||||
|
# via panel
|
||||||
|
loguru==0.7.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
mako==1.3.5
|
||||||
|
# via alembic
|
||||||
|
markdown==3.7
|
||||||
|
# via panel
|
||||||
|
# via tensorboard
|
||||||
|
markdown-it-py==2.1.0
|
||||||
|
# via euporie
|
||||||
|
# via jupytext
|
||||||
|
# via mdit-py-plugins
|
||||||
|
# via panel
|
||||||
|
# via rich
|
||||||
|
markupsafe==2.1.5
|
||||||
|
# via jinja2
|
||||||
|
# via mako
|
||||||
|
# via nbconvert
|
||||||
|
# via werkzeug
|
||||||
|
marshmallow==3.22.0
|
||||||
|
# via dataclasses-json
|
||||||
|
matplotlib==3.9.2
|
||||||
|
# via matplotlib-backend-kitty
|
||||||
|
matplotlib-backend-kitty==2.1.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
# via ipykernel
|
||||||
|
# via ipython
|
||||||
|
mdit-py-plugins==0.3.5
|
||||||
|
# via euporie
|
||||||
|
# via jupytext
|
||||||
|
# via panel
|
||||||
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
mistune==3.0.2
|
||||||
|
# via nbconvert
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
|
multidict==6.1.0
|
||||||
|
# via aiohttp
|
||||||
|
# via yarl
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
# via typing-inspect
|
||||||
|
nbclient==0.10.0
|
||||||
|
# via nbconvert
|
||||||
|
nbconvert==7.16.4
|
||||||
|
# via jupyter
|
||||||
|
# via jupyter-server
|
||||||
|
nbformat==5.10.4
|
||||||
|
# via euporie
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupytext
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
# via ipykernel
|
||||||
|
networkx==3.3
|
||||||
|
# via torch
|
||||||
|
notebook==7.2.2
|
||||||
|
# via jupyter
|
||||||
|
notebook-shim==0.2.4
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
numpy==2.1.1
|
||||||
|
# via bokeh
|
||||||
|
# via contourpy
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via matplotlib
|
||||||
|
# via optuna
|
||||||
|
# via pandas
|
||||||
|
# via pyarrow
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
# via scipy
|
||||||
|
# via tensorboard
|
||||||
|
# via torchmetrics
|
||||||
|
# via torchvision
|
||||||
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
|
# via nvidia-cudnn-cu12
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
nvidia-cudnn-cu12==9.1.0.70
|
||||||
|
# via torch
|
||||||
|
nvidia-cufft-cu12==11.0.2.54
|
||||||
|
# via torch
|
||||||
|
nvidia-curand-cu12==10.3.2.106
|
||||||
|
# via torch
|
||||||
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
# via torch
|
||||||
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via torch
|
||||||
|
nvidia-nccl-cu12==2.20.5
|
||||||
|
# via torch
|
||||||
|
nvidia-nvjitlink-cu12==12.6.68
|
||||||
|
# via nvidia-cusolver-cu12
|
||||||
|
# via nvidia-cusparse-cu12
|
||||||
|
nvidia-nvtx-cu12==12.1.105
|
||||||
|
# via torch
|
||||||
|
optuna==4.0.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
overrides==7.7.0
|
||||||
|
# via jupyter-server
|
||||||
|
packaging==24.1
|
||||||
|
# via bokeh
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via jupytext
|
||||||
|
# via lightning
|
||||||
|
# via lightning-utilities
|
||||||
|
# via marshmallow
|
||||||
|
# via matplotlib
|
||||||
|
# via nbconvert
|
||||||
|
# via optuna
|
||||||
|
# via plotly
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via scikit-optimize
|
||||||
|
# via tensorboard
|
||||||
|
# via torchmetrics
|
||||||
|
pandas==2.2.2
|
||||||
|
# via bokeh
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
pandocfilters==1.5.1
|
||||||
|
# via nbconvert
|
||||||
|
panel==1.4.5
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
param==2.1.1
|
||||||
|
# via holoviews
|
||||||
|
# via hvplot
|
||||||
|
# via panel
|
||||||
|
# via pyviz-comms
|
||||||
|
parso==0.8.4
|
||||||
|
# via jedi
|
||||||
|
periodic-table-dataclasses==1.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
pexpect==4.9.0
|
||||||
|
# via ipython
|
||||||
|
pillow==10.4.0
|
||||||
|
# via bokeh
|
||||||
|
# via euporie
|
||||||
|
# via matplotlib
|
||||||
|
# via timg
|
||||||
|
# via torchvision
|
||||||
|
platformdirs==3.11.0
|
||||||
|
# via euporie
|
||||||
|
# via jupyter-core
|
||||||
|
# via wandb
|
||||||
|
plotly==5.24.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
polars==1.6.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
prometheus-client==0.20.0
|
||||||
|
# via jupyter-server
|
||||||
|
prompt-toolkit==3.0.47
|
||||||
|
# via euporie
|
||||||
|
# via ipython
|
||||||
|
# via jupyter-console
|
||||||
|
protobuf==5.28.0
|
||||||
|
# via tensorboard
|
||||||
|
# via wandb
|
||||||
|
psutil==6.0.0
|
||||||
|
# via ipykernel
|
||||||
|
# via wandb
|
||||||
|
ptyprocess==0.7.0
|
||||||
|
# via pexpect
|
||||||
|
# via terminado
|
||||||
|
pure-eval==0.2.3
|
||||||
|
# via stack-data
|
||||||
|
pyaml==24.7.0
|
||||||
|
# via scikit-optimize
|
||||||
|
pyarrow==17.0.0
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
pycparser==2.22
|
||||||
|
# via cffi
|
||||||
|
pygments==2.18.0
|
||||||
|
# via bpython
|
||||||
|
# via euporie
|
||||||
|
# via ipython
|
||||||
|
# via jupyter-console
|
||||||
|
# via nbconvert
|
||||||
|
# via rich
|
||||||
|
# via snoop
|
||||||
|
pyparsing==3.1.4
|
||||||
|
# via matplotlib
|
||||||
|
pyperclip==1.9.0
|
||||||
|
# via euporie
|
||||||
|
pysocks==1.7.1
|
||||||
|
# via requests
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via arrow
|
||||||
|
# via jupyter-client
|
||||||
|
# via kaggle
|
||||||
|
# via matplotlib
|
||||||
|
# via pandas
|
||||||
|
python-json-logger==2.0.7
|
||||||
|
# via jupyter-events
|
||||||
|
python-slugify==8.0.4
|
||||||
|
# via kaggle
|
||||||
|
pytorch-lightning==2.4.0
|
||||||
|
# via lightning
|
||||||
|
pytz==2024.2
|
||||||
|
# via pandas
|
||||||
|
pyviz-comms==3.0.3
|
||||||
|
# via holoviews
|
||||||
|
# via panel
|
||||||
|
pyxdg==0.28
|
||||||
|
# via bpython
|
||||||
|
pyyaml==6.0.2
|
||||||
|
# via bokeh
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupytext
|
||||||
|
# via lightning
|
||||||
|
# via optuna
|
||||||
|
# via pyaml
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via wandb
|
||||||
|
pyzmq==26.2.0
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-server
|
||||||
|
referencing==0.35.1
|
||||||
|
# via jsonschema
|
||||||
|
# via jsonschema-specifications
|
||||||
|
# via jupyter-events
|
||||||
|
regex==2024.7.24
|
||||||
|
# via flatlatex
|
||||||
|
requests==2.32.3
|
||||||
|
# via bpython
|
||||||
|
# via gdown
|
||||||
|
# via jupyterlab-server
|
||||||
|
# via kaggle
|
||||||
|
# via panel
|
||||||
|
# via wandb
|
||||||
|
rfc3339-validator==0.1.4
|
||||||
|
# via jsonschema
|
||||||
|
# via jupyter-events
|
||||||
|
rfc3986-validator==0.1.1
|
||||||
|
# via jsonschema
|
||||||
|
# via jupyter-events
|
||||||
|
rich==13.3.1
|
||||||
|
# via typer
|
||||||
|
rpds-py==0.20.0
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
|
safetensors==0.4.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
scikit-learn==1.5.1
|
||||||
|
# via scikit-optimize
|
||||||
|
scikit-optimize==0.10.2
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
scipy==1.14.1
|
||||||
|
# via scikit-learn
|
||||||
|
# via scikit-optimize
|
||||||
|
send2trash==1.8.3
|
||||||
|
# via jupyter-server
|
||||||
|
sentry-sdk==2.14.0
|
||||||
|
# via wandb
|
||||||
|
setproctitle==1.3.3
|
||||||
|
# via wandb
|
||||||
|
setuptools==74.1.2
|
||||||
|
# via jupyterlab
|
||||||
|
# via lightning-utilities
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
# via tensorboard
|
||||||
|
# via torch
|
||||||
|
# via wandb
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
|
six==1.16.0
|
||||||
|
# via asttokens
|
||||||
|
# via bleach
|
||||||
|
# via blessed
|
||||||
|
# via docker-pycreds
|
||||||
|
# via kaggle
|
||||||
|
# via python-dateutil
|
||||||
|
# via rfc3339-validator
|
||||||
|
# via snoop
|
||||||
|
# via tensorboard
|
||||||
|
sixelcrop==0.1.7
|
||||||
|
# via euporie
|
||||||
|
smmap==5.0.1
|
||||||
|
# via gitdb
|
||||||
|
sniffio==1.3.1
|
||||||
|
# via anyio
|
||||||
|
# via httpx
|
||||||
|
snoop==0.4.3
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
soupsieve==2.6
|
||||||
|
# via beautifulsoup4
|
||||||
|
sqlalchemy==2.0.34
|
||||||
|
# via alembic
|
||||||
|
# via optuna
|
||||||
|
stack-data==0.6.3
|
||||||
|
# via ipython
|
||||||
|
sympy==1.13.2
|
||||||
|
# via torch
|
||||||
|
tenacity==9.0.0
|
||||||
|
# via plotly
|
||||||
|
tensorboard==2.17.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
tensorboard-data-server==0.7.2
|
||||||
|
# via tensorboard
|
||||||
|
terminado==0.18.1
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyter-server-terminals
|
||||||
|
text-unidecode==1.3
|
||||||
|
# via python-slugify
|
||||||
|
threadpoolctl==3.5.0
|
||||||
|
# via scikit-learn
|
||||||
|
timg==1.1.6
|
||||||
|
# via euporie
|
||||||
|
tinycss2==1.3.0
|
||||||
|
# via nbconvert
|
||||||
|
torch==2.4.1
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
# via torchmetrics
|
||||||
|
# via torchvision
|
||||||
|
torchmetrics==1.4.1
|
||||||
|
# via lightning
|
||||||
|
# via pytorch-lightning
|
||||||
|
torchvision==0.19.1
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
tornado==6.4.1
|
||||||
|
# via bokeh
|
||||||
|
# via ipykernel
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via notebook
|
||||||
|
# via terminado
|
||||||
|
tqdm==4.66.5
|
||||||
|
# via gdown
|
||||||
|
# via kaggle
|
||||||
|
# via lightning
|
||||||
|
# via optuna
|
||||||
|
# via panel
|
||||||
|
# via pytorch-lightning
|
||||||
|
traitlets==5.14.3
|
||||||
|
# via comm
|
||||||
|
# via ipykernel
|
||||||
|
# via ipython
|
||||||
|
# via ipywidgets
|
||||||
|
# via jupyter-client
|
||||||
|
# via jupyter-console
|
||||||
|
# via jupyter-core
|
||||||
|
# via jupyter-events
|
||||||
|
# via jupyter-server
|
||||||
|
# via jupyterlab
|
||||||
|
# via matplotlib-inline
|
||||||
|
# via nbclient
|
||||||
|
# via nbconvert
|
||||||
|
# via nbformat
|
||||||
|
triton==3.0.0
|
||||||
|
# via torch
|
||||||
|
typer==0.12.5
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
types-python-dateutil==2.9.0.20240906
|
||||||
|
# via arrow
|
||||||
|
typing-extensions==4.12.2
|
||||||
|
# via alembic
|
||||||
|
# via euporie
|
||||||
|
# via lightning
|
||||||
|
# via lightning-utilities
|
||||||
|
# via panel
|
||||||
|
# via pytorch-lightning
|
||||||
|
# via sqlalchemy
|
||||||
|
# via torch
|
||||||
|
# via typer
|
||||||
|
# via typing-inspect
|
||||||
|
typing-inspect==0.9.0
|
||||||
|
# via dataclasses-json
|
||||||
|
tzdata==2024.1
|
||||||
|
# via pandas
|
||||||
|
uc-micro-py==1.0.3
|
||||||
|
# via linkify-it-py
|
||||||
|
universal-pathlib==0.2.5
|
||||||
|
# via euporie
|
||||||
|
uri-template==1.3.0
|
||||||
|
# via jsonschema
|
||||||
|
urllib3==2.2.2
|
||||||
|
# via kaggle
|
||||||
|
# via requests
|
||||||
|
# via sentry-sdk
|
||||||
|
wandb==0.17.9
|
||||||
|
# via symbolic-nn-tests
|
||||||
|
wcwidth==0.2.13
|
||||||
|
# via blessed
|
||||||
|
# via prompt-toolkit
|
||||||
|
webcolors==24.8.0
|
||||||
|
# via jsonschema
|
||||||
|
webencodings==0.5.1
|
||||||
|
# via bleach
|
||||||
|
# via tinycss2
|
||||||
|
websocket-client==1.8.0
|
||||||
|
# via jupyter-server
|
||||||
|
werkzeug==3.0.4
|
||||||
|
# via tensorboard
|
||||||
|
widgetsnbextension==4.0.13
|
||||||
|
# via ipywidgets
|
||||||
|
xyzservices==2024.9.0
|
||||||
|
# via bokeh
|
||||||
|
# via panel
|
||||||
|
yarl==1.11.1
|
||||||
|
# via aiohttp
|
||||||
11
src/symbolic_nn_tests/__init__.py
Normal file
11
src/symbolic_nn_tests/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from . import __main__
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
|
||||||
|
ssl._create_default_https_context = ssl._create_unverified_context
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
typer.run(__main__.main)
|
||||||
@@ -4,7 +4,7 @@ from torchvision.transforms import ToTensor
|
|||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||||
DATASET_DIR = PROJECT_ROOT / "datasets/"
|
DATASET_DIR = PROJECT_ROOT / "datasets/"
|
||||||
|
|
||||||
|
|
||||||
93
src/symbolic_nn_tests/local/__init__.py
Normal file
93
src/symbolic_nn_tests/local/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
LEARNING_RATE = 10e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
|
||||||
|
from .model import main as test_model
|
||||||
|
|
||||||
|
logger = []
|
||||||
|
|
||||||
|
if tensorboard:
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
|
tb_logger = TensorBoardLogger(
|
||||||
|
save_dir=".",
|
||||||
|
name="logs/comparison",
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
logger.append(tb_logger)
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
import wandb as _wandb
|
||||||
|
from lightning.pytorch.loggers import WandbLogger
|
||||||
|
|
||||||
|
wandb_logger = WandbLogger(
|
||||||
|
project="Symbolic_NN_Tests",
|
||||||
|
name=version,
|
||||||
|
dir="wandb",
|
||||||
|
)
|
||||||
|
logger.append(wandb_logger)
|
||||||
|
|
||||||
|
test_model(
|
||||||
|
logger=logger,
|
||||||
|
train_loss=train_loss,
|
||||||
|
val_loss=val_loss,
|
||||||
|
test_loss=test_loss,
|
||||||
|
lr=LEARNING_RATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
_wandb.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def run(tensorboard: bool = True, wandb: bool = True):
|
||||||
|
from . import semantic_loss
|
||||||
|
from .model import oh_vs_cat_cross_entropy
|
||||||
|
|
||||||
|
test(
|
||||||
|
train_loss=oh_vs_cat_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
train_loss=semantic_loss.similarity_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="similarity_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
train_loss=semantic_loss.hasline_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="hasline_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
train_loss=semantic_loss.hasloop_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="hasloop_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
train_loss=semantic_loss.multisemantic_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="multisemantic_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
|
test(
|
||||||
|
train_loss=semantic_loss.garbage_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
version="garbage_cross_entropy",
|
||||||
|
tensorboard=tensorboard,
|
||||||
|
wandb=wandb,
|
||||||
|
)
|
||||||
72
src/symbolic_nn_tests/local/model.py
Normal file
72
src/symbolic_nn_tests/local/model.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Flatten(1, -1),
|
||||||
|
nn.Linear(784, 10),
|
||||||
|
nn.Softmax(dim=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def collate(batch):
|
||||||
|
x, y = zip(*batch)
|
||||||
|
x = [i[0] for i in x]
|
||||||
|
y = [torch.tensor(i) for i in y]
|
||||||
|
x = torch.stack(x)
|
||||||
|
y = torch.tensor(y)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
# This is just a quick, lazy way to ensure all models are trained on the same dataset
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_singleton_dataset():
|
||||||
|
from torchvision.datasets import QMNIST
|
||||||
|
|
||||||
|
from symbolic_nn_tests.dataloader import create_dataset
|
||||||
|
|
||||||
|
return create_dataset(
|
||||||
|
dataset=QMNIST,
|
||||||
|
collate_fn=collate,
|
||||||
|
batch_size=128,
|
||||||
|
shuffle_train=True,
|
||||||
|
num_workers=11,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def oh_vs_cat_cross_entropy(y_bin, y_cat):
|
||||||
|
return nn.functional.cross_entropy(
|
||||||
|
y_bin,
|
||||||
|
nn.functional.one_hot(y_cat, num_classes=10).float(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
train_loss=oh_vs_cat_cross_entropy,
|
||||||
|
val_loss=oh_vs_cat_cross_entropy,
|
||||||
|
test_loss=oh_vs_cat_cross_entropy,
|
||||||
|
logger=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
import lightning as L
|
||||||
|
|
||||||
|
from symbolic_nn_tests.train import TrainingWrapper
|
||||||
|
|
||||||
|
if logger is None:
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
|
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
|
||||||
|
|
||||||
|
train, val, test = get_singleton_dataset()
|
||||||
|
lmodel = TrainingWrapper(
|
||||||
|
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
||||||
|
)
|
||||||
|
lmodel.configure_optimizers(**kwargs)
|
||||||
|
trainer = L.Trainer(max_epochs=20, logger=logger)
|
||||||
|
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||||
|
trainer.test(dataloaders=test)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
88
src/symbolic_nn_tests/local/semantic_loss.py
Normal file
88
src/symbolic_nn_tests/local/semantic_loss.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_semantic_cross_entropy(semantic_matrix):
|
||||||
|
def semantic_cross_entropy(input, target):
|
||||||
|
ce_loss = torch.nn.functional.cross_entropy(input, target)
|
||||||
|
|
||||||
|
penalty_tensor = semantic_matrix[target.argmax(dim=1)]
|
||||||
|
abs_diff = (target - input).abs()
|
||||||
|
semantic_penalty = (abs_diff * penalty_tensor).sum()
|
||||||
|
return ce_loss * semantic_penalty
|
||||||
|
|
||||||
|
def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
|
||||||
|
return semantic_cross_entropy(
|
||||||
|
input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
|
||||||
|
)
|
||||||
|
|
||||||
|
return oh_vs_cat_semantic_cross_entropy
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: This similarity matrix defines loss scaling factors for misclassification
|
||||||
|
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
|
||||||
|
# penalised less harshly than visually distinct numbers as this mistake is "less
|
||||||
|
# mistaken" given our understanding of the visual characteristics of numerals.
|
||||||
|
# By using this scaling matric we can inject human knowledge into the model via
|
||||||
|
# the loss function, making this an example of a "semantic loss function"
|
||||||
|
SIMILARITY_MATRIX = torch.tensor(
|
||||||
|
[
|
||||||
|
[2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||||
|
[1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
|
||||||
|
]
|
||||||
|
).to("cuda")
|
||||||
|
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum() # Normalized to sum of 1
|
||||||
|
|
||||||
|
similarity_cross_entropy = create_semantic_cross_entropy(SIMILARITY_MATRIX)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: The following matrix encodes a simpler semantic penalty for correctly/incorrectly
|
||||||
|
# identifying shapes with straight lines in their representation. This can be a bit fuzzy
|
||||||
|
# in cases like "9" though.
|
||||||
|
HASLINE_MATRIX = torch.tensor(
|
||||||
|
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||||
|
[False, True, False, False, True, True, False, True, False, True]
|
||||||
|
).to("cuda")
|
||||||
|
HASLINE_MATRIX = torch.stack([i ^ HASLINE_MATRIX for i in HASLINE_MATRIX]).type(
|
||||||
|
torch.float64
|
||||||
|
)
|
||||||
|
HASLINE_MATRIX += 1
|
||||||
|
HASLINE_MATRIX /= HASLINE_MATRIX.sum() # Normalize to sum of 1
|
||||||
|
|
||||||
|
hasline_cross_entropy = create_semantic_cross_entropy(HASLINE_MATRIX)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: Similarly, we can do the same for closed circular loops in a numeric character
|
||||||
|
HASLOOP_MATRIX = torch.tensor(
|
||||||
|
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||||
|
[True, False, False, False, False, False, True, False, True, True]
|
||||||
|
).to("cuda")
|
||||||
|
HASLOOP_MATRIX = torch.stack([i ^ HASLOOP_MATRIX for i in HASLOOP_MATRIX]).type(
|
||||||
|
torch.float64
|
||||||
|
)
|
||||||
|
HASLOOP_MATRIX += 1
|
||||||
|
HASLOOP_MATRIX /= HASLOOP_MATRIX.sum() # Normalize to sum of 1
|
||||||
|
|
||||||
|
hasloop_cross_entropy = create_semantic_cross_entropy(HASLOOP_MATRIX)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: We can also combine all of these semantic matrices
|
||||||
|
MULTISEMANTIC_MATRIX = SIMILARITY_MATRIX * HASLINE_MATRIX * HASLOOP_MATRIX
|
||||||
|
MULTISEMANTIC_MATRIX /= MULTISEMANTIC_MATRIX.sum()
|
||||||
|
|
||||||
|
multisemantic_cross_entropy = create_semantic_cross_entropy(MULTISEMANTIC_MATRIX)
|
||||||
|
|
||||||
|
# NOTE: As a final test, lets make something similar to tehse but where there's no knowledge,
|
||||||
|
# just random data. This will create a benchmark for the effects of this process wothout the
|
||||||
|
# "knowledge" component
|
||||||
|
GARBAGE_MATRIX = torch.rand(10, 10).to("cuda")
|
||||||
|
GARBAGE_MATRIX /= GARBAGE_MATRIX.sum()
|
||||||
|
|
||||||
|
garbage_cross_entropy = create_semantic_cross_entropy(GARBAGE_MATRIX)
|
||||||
Reference in New Issue
Block a user