mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
Moved to rye instead of poetry, to avoid linking problems
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -162,6 +162,7 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
.jj
|
||||
.direnv
|
||||
scratchpad.ipynb
|
||||
datasets/
|
||||
explore/
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12.4
|
||||
60
flake.lock
generated
Normal file
60
flake.lock
generated
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1726053618,
|
||||
"narHash": "sha256-Xu5EVNdrbJ4XCpVU4moytVTuqoo9LsmQgR8g2BQd9Qc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "0e6a2434a572fd583cac02e142ab0689895e395a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs",
|
||||
"utils": "utils"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
49
flake.nix
Normal file
49
flake.nix
Normal file
@@ -0,0 +1,49 @@
|
||||
{
|
||||
inputs = {
|
||||
utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
inputs.nixpkgs.url = "github:NixOS/nixpkgs";
|
||||
|
||||
outputs = { self, nixpkgs, utils }: utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
system = "x86_64-linux";
|
||||
config = {
|
||||
allowUnfree = true;
|
||||
cudaSupport = true;
|
||||
};
|
||||
};
|
||||
in
|
||||
{
|
||||
devShell = pkgs.mkShell
|
||||
{
|
||||
NIX_LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
|
||||
pkgs.stdenv.cc.cc
|
||||
pkgs.zlib
|
||||
pkgs.libGL
|
||||
pkgs.libGLU
|
||||
pkgs.wlroots
|
||||
pkgs.ncurses5
|
||||
pkgs.linuxKernel.packages.linux_latest_libre.nvidia_x11
|
||||
];
|
||||
NIX_LD = pkgs.lib.fileContents "${pkgs.stdenv.cc}/nix-support/dynamic-linker";
|
||||
|
||||
buildInputs = with pkgs; [
|
||||
ruff
|
||||
ruff-lsp
|
||||
rye
|
||||
uv
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
export CUDA_PATH=${pkgs.cudatoolkit}
|
||||
export LD_LIBRARY_PATH=$(nix eval --raw nixpkgs#addOpenGLRunpath.driverLink)/lib
|
||||
export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib"
|
||||
export EXTRA_CCFLAGS="-I/usr/include"
|
||||
'';
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
5703
poetry.lock
generated
5703
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,41 +1,55 @@
|
||||
[tool.poetry]
|
||||
[project]
|
||||
name = "symbolic-nn-tests"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Cian Hughes <chughes000@gmail.com>"]
|
||||
license = "MIT"
|
||||
description = "Add your description here"
|
||||
authors = [
|
||||
{ name = "Cian Hughes", email = "chughes000@gmail.com" }
|
||||
]
|
||||
dependencies = [
|
||||
"torch>=2.4.1",
|
||||
"lightning>=2.4.0",
|
||||
"torchvision>=0.19.1",
|
||||
"wandb>=0.17.9",
|
||||
"optuna>=4.0.0",
|
||||
"setuptools>=74.1.2",
|
||||
"gdown>=5.2.0",
|
||||
"bpython>=0.24",
|
||||
"ipython>=8.27.0",
|
||||
"matplotlib-backend-kitty>=2.1.2",
|
||||
"euporie>=2.8.2",
|
||||
"ipykernel>=6.29.5",
|
||||
"tensorboard>=2.17.1",
|
||||
"typer>=0.12.5",
|
||||
"kaggle>=1.6.17",
|
||||
"periodic-table-dataclasses>=1.0",
|
||||
"polars>=1.6.0",
|
||||
"jupyter>=1.1.1",
|
||||
"safetensors>=0.4.5",
|
||||
"alive-progress>=3.1.5",
|
||||
"hvplot>=0.10.0",
|
||||
"pyarrow>=17.0.0",
|
||||
"loguru>=0.7.2",
|
||||
"plotly>=5.24.0",
|
||||
"snoop>=0.4.3",
|
||||
"scikit-optimize>=0.10.2",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
classifiers = ["Private :: Do Not Upload"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
torch = "^2.3.0"
|
||||
lightning = "^2.2.4"
|
||||
torchvision = "^0.18.0"
|
||||
wandb = "^0.17.0"
|
||||
optuna = "^3.6.1"
|
||||
setuptools = "^69.5.1"
|
||||
gdown = "^5.2.0"
|
||||
bpython = "^0.24"
|
||||
ipython = "^8.24.0"
|
||||
matplotlib-backend-kitty = "^2.1.2"
|
||||
euporie = "^2.8.2"
|
||||
ipykernel = "^6.29.4"
|
||||
tensorboard = "^2.16.2"
|
||||
typer = "^0.12.3"
|
||||
kaggle = "^1.6.14"
|
||||
periodic-table-dataclasses = "^1.0"
|
||||
polars = "^0.20.28"
|
||||
jupyter = "^1.0.0"
|
||||
safetensors = "^0.4.3"
|
||||
alive-progress = "^3.1.5"
|
||||
hvplot = "^0.10.0"
|
||||
pyarrow = "^16.1.0"
|
||||
loguru = "^0.7.2"
|
||||
plotly = "^5.22.0"
|
||||
snoop = "^0.4.3"
|
||||
scikit-optimize = "^0.10.2"
|
||||
|
||||
[project.scripts]
|
||||
"symbolic-nn-tests" = "symbolic_nn_tests:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = []
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/symbolic_nn_tests"]
|
||||
|
||||
675
requirements-dev.lock
Normal file
675
requirements-dev.lock
Normal file
@@ -0,0 +1,675 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
|
||||
-e file:.
|
||||
about-time==4.2.1
|
||||
# via alive-progress
|
||||
absl-py==2.1.0
|
||||
# via tensorboard
|
||||
aenum==3.1.12
|
||||
# via euporie
|
||||
aiohappyeyeballs==2.4.0
|
||||
# via aiohttp
|
||||
aiohttp==3.10.5
|
||||
# via fsspec
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
alembic==1.13.2
|
||||
# via optuna
|
||||
alive-progress==3.1.5
|
||||
# via symbolic-nn-tests
|
||||
anyio==4.4.0
|
||||
# via httpx
|
||||
# via jupyter-server
|
||||
argon2-cffi==23.1.0
|
||||
# via jupyter-server
|
||||
argon2-cffi-bindings==21.2.0
|
||||
# via argon2-cffi
|
||||
arrow==1.3.0
|
||||
# via isoduration
|
||||
asttokens==2.4.1
|
||||
# via snoop
|
||||
# via stack-data
|
||||
async-lru==2.0.4
|
||||
# via jupyterlab
|
||||
attrs==24.2.0
|
||||
# via aiohttp
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
babel==2.16.0
|
||||
# via jupyterlab-server
|
||||
beautifulsoup4==4.12.3
|
||||
# via gdown
|
||||
# via nbconvert
|
||||
bleach==6.1.0
|
||||
# via kaggle
|
||||
# via nbconvert
|
||||
# via panel
|
||||
blessed==1.20.0
|
||||
# via curtsies
|
||||
bokeh==3.4.3
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
bpython==0.24
|
||||
# via symbolic-nn-tests
|
||||
certifi==2024.8.30
|
||||
# via httpcore
|
||||
# via httpx
|
||||
# via kaggle
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
cffi==1.17.1
|
||||
# via argon2-cffi-bindings
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
cheap-repr==0.5.2
|
||||
# via snoop
|
||||
click==8.1.7
|
||||
# via typer
|
||||
# via wandb
|
||||
colorcet==3.1.0
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
colorlog==6.8.2
|
||||
# via optuna
|
||||
comm==0.2.2
|
||||
# via ipykernel
|
||||
# via ipywidgets
|
||||
contourpy==1.3.0
|
||||
# via bokeh
|
||||
# via matplotlib
|
||||
curtsies==0.4.2
|
||||
# via bpython
|
||||
cwcwidth==0.1.9
|
||||
# via bpython
|
||||
# via curtsies
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dataclasses-json==0.6.7
|
||||
# via periodic-table-dataclasses
|
||||
debugpy==1.8.5
|
||||
# via ipykernel
|
||||
decorator==5.1.1
|
||||
# via ipython
|
||||
defusedxml==0.7.1
|
||||
# via nbconvert
|
||||
docker-pycreds==0.4.0
|
||||
# via wandb
|
||||
euporie==2.8.2
|
||||
# via symbolic-nn-tests
|
||||
executing==2.1.0
|
||||
# via snoop
|
||||
# via stack-data
|
||||
fastjsonschema==2.20.0
|
||||
# via euporie
|
||||
# via nbformat
|
||||
filelock==3.16.0
|
||||
# via gdown
|
||||
# via torch
|
||||
# via triton
|
||||
flatlatex==0.15
|
||||
# via euporie
|
||||
fonttools==4.53.1
|
||||
# via matplotlib
|
||||
fqdn==1.5.1
|
||||
# via jsonschema
|
||||
frozenlist==1.4.1
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
fsspec==2024.9.0
|
||||
# via euporie
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via torch
|
||||
# via universal-pathlib
|
||||
gdown==5.2.0
|
||||
# via symbolic-nn-tests
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.43
|
||||
# via wandb
|
||||
grapheme==0.6.0
|
||||
# via alive-progress
|
||||
greenlet==3.1.0
|
||||
# via bpython
|
||||
# via sqlalchemy
|
||||
grpcio==1.66.1
|
||||
# via tensorboard
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
holoviews==1.19.1
|
||||
# via hvplot
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httpx==0.27.2
|
||||
# via jupyterlab
|
||||
hvplot==0.10.0
|
||||
# via symbolic-nn-tests
|
||||
idna==3.8
|
||||
# via anyio
|
||||
# via httpx
|
||||
# via jsonschema
|
||||
# via requests
|
||||
# via yarl
|
||||
imagesize==1.4.1
|
||||
# via euporie
|
||||
ipykernel==6.29.5
|
||||
# via jupyter
|
||||
# via jupyter-console
|
||||
# via jupyterlab
|
||||
# via symbolic-nn-tests
|
||||
ipython==8.27.0
|
||||
# via ipykernel
|
||||
# via ipywidgets
|
||||
# via jupyter-console
|
||||
# via symbolic-nn-tests
|
||||
ipywidgets==8.1.5
|
||||
# via jupyter
|
||||
isoduration==20.11.0
|
||||
# via jsonschema
|
||||
jedi==0.19.1
|
||||
# via ipython
|
||||
jinja2==3.1.4
|
||||
# via bokeh
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via nbconvert
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
json5==0.9.25
|
||||
# via jupyterlab-server
|
||||
jsonpointer==3.0.0
|
||||
# via jsonschema
|
||||
jsonschema==4.23.0
|
||||
# via jupyter-events
|
||||
# via jupyterlab-server
|
||||
# via nbformat
|
||||
jsonschema-specifications==2023.12.1
|
||||
# via jsonschema
|
||||
jupyter==1.1.1
|
||||
# via symbolic-nn-tests
|
||||
jupyter-client==8.6.2
|
||||
# via euporie
|
||||
# via ipykernel
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
# via nbclient
|
||||
jupyter-console==6.6.3
|
||||
# via jupyter
|
||||
jupyter-core==5.7.2
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
# via nbformat
|
||||
jupyter-events==0.10.0
|
||||
# via jupyter-server
|
||||
jupyter-lsp==2.2.5
|
||||
# via jupyterlab
|
||||
jupyter-server==2.14.2
|
||||
# via jupyter-lsp
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via notebook
|
||||
# via notebook-shim
|
||||
jupyter-server-terminals==0.5.3
|
||||
# via jupyter-server
|
||||
jupyterlab==4.2.5
|
||||
# via jupyter
|
||||
# via notebook
|
||||
jupyterlab-pygments==0.3.0
|
||||
# via nbconvert
|
||||
jupyterlab-server==2.27.3
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
jupyterlab-widgets==3.0.13
|
||||
# via ipywidgets
|
||||
jupytext==1.16.4
|
||||
# via euporie
|
||||
kaggle==1.6.17
|
||||
# via symbolic-nn-tests
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
lightning==2.4.0
|
||||
# via symbolic-nn-tests
|
||||
lightning-utilities==0.11.7
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via torchmetrics
|
||||
linkify-it-py==1.0.3
|
||||
# via euporie
|
||||
# via panel
|
||||
loguru==0.7.2
|
||||
# via symbolic-nn-tests
|
||||
mako==1.3.5
|
||||
# via alembic
|
||||
markdown==3.7
|
||||
# via panel
|
||||
# via tensorboard
|
||||
markdown-it-py==2.1.0
|
||||
# via euporie
|
||||
# via jupytext
|
||||
# via mdit-py-plugins
|
||||
# via panel
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
# via mako
|
||||
# via nbconvert
|
||||
# via werkzeug
|
||||
marshmallow==3.22.0
|
||||
# via dataclasses-json
|
||||
matplotlib==3.9.2
|
||||
# via matplotlib-backend-kitty
|
||||
matplotlib-backend-kitty==2.1.2
|
||||
# via symbolic-nn-tests
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipykernel
|
||||
# via ipython
|
||||
mdit-py-plugins==0.3.5
|
||||
# via euporie
|
||||
# via jupytext
|
||||
# via panel
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==3.0.2
|
||||
# via nbconvert
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.1.0
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
nbclient==0.10.0
|
||||
# via nbconvert
|
||||
nbconvert==7.16.4
|
||||
# via jupyter
|
||||
# via jupyter-server
|
||||
nbformat==5.10.4
|
||||
# via euporie
|
||||
# via jupyter-server
|
||||
# via jupytext
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
networkx==3.3
|
||||
# via torch
|
||||
notebook==7.2.2
|
||||
# via jupyter
|
||||
notebook-shim==0.2.4
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
numpy==2.1.1
|
||||
# via bokeh
|
||||
# via contourpy
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via matplotlib
|
||||
# via optuna
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
# via scipy
|
||||
# via tensorboard
|
||||
# via torchmetrics
|
||||
# via torchvision
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.68
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
optuna==4.0.0
|
||||
# via symbolic-nn-tests
|
||||
overrides==7.7.0
|
||||
# via jupyter-server
|
||||
packaging==24.1
|
||||
# via bokeh
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via ipykernel
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via jupytext
|
||||
# via lightning
|
||||
# via lightning-utilities
|
||||
# via marshmallow
|
||||
# via matplotlib
|
||||
# via nbconvert
|
||||
# via optuna
|
||||
# via plotly
|
||||
# via pytorch-lightning
|
||||
# via scikit-optimize
|
||||
# via tensorboard
|
||||
# via torchmetrics
|
||||
pandas==2.2.2
|
||||
# via bokeh
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
pandocfilters==1.5.1
|
||||
# via nbconvert
|
||||
panel==1.4.5
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
param==2.1.1
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
# via pyviz-comms
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
periodic-table-dataclasses==1.0
|
||||
# via symbolic-nn-tests
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pillow==10.4.0
|
||||
# via bokeh
|
||||
# via euporie
|
||||
# via matplotlib
|
||||
# via timg
|
||||
# via torchvision
|
||||
platformdirs==3.11.0
|
||||
# via euporie
|
||||
# via jupyter-core
|
||||
# via wandb
|
||||
plotly==5.24.0
|
||||
# via symbolic-nn-tests
|
||||
polars==1.6.0
|
||||
# via symbolic-nn-tests
|
||||
prometheus-client==0.20.0
|
||||
# via jupyter-server
|
||||
prompt-toolkit==3.0.47
|
||||
# via euporie
|
||||
# via ipython
|
||||
# via jupyter-console
|
||||
protobuf==5.28.0
|
||||
# via tensorboard
|
||||
# via wandb
|
||||
psutil==6.0.0
|
||||
# via ipykernel
|
||||
# via wandb
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
# via terminado
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyaml==24.7.0
|
||||
# via scikit-optimize
|
||||
pyarrow==17.0.0
|
||||
# via symbolic-nn-tests
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pygments==2.18.0
|
||||
# via bpython
|
||||
# via euporie
|
||||
# via ipython
|
||||
# via jupyter-console
|
||||
# via nbconvert
|
||||
# via rich
|
||||
# via snoop
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
pyperclip==1.9.0
|
||||
# via euporie
|
||||
pysocks==1.7.1
|
||||
# via requests
|
||||
python-dateutil==2.9.0.post0
|
||||
# via arrow
|
||||
# via jupyter-client
|
||||
# via kaggle
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
python-json-logger==2.0.7
|
||||
# via jupyter-events
|
||||
python-slugify==8.0.4
|
||||
# via kaggle
|
||||
pytorch-lightning==2.4.0
|
||||
# via lightning
|
||||
pytz==2024.2
|
||||
# via pandas
|
||||
pyviz-comms==3.0.3
|
||||
# via holoviews
|
||||
# via panel
|
||||
pyxdg==0.28
|
||||
# via bpython
|
||||
pyyaml==6.0.2
|
||||
# via bokeh
|
||||
# via jupyter-events
|
||||
# via jupytext
|
||||
# via lightning
|
||||
# via optuna
|
||||
# via pyaml
|
||||
# via pytorch-lightning
|
||||
# via wandb
|
||||
pyzmq==26.2.0
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
# via jupyter-events
|
||||
regex==2024.7.24
|
||||
# via flatlatex
|
||||
requests==2.32.3
|
||||
# via bpython
|
||||
# via gdown
|
||||
# via jupyterlab-server
|
||||
# via kaggle
|
||||
# via panel
|
||||
# via wandb
|
||||
rfc3339-validator==0.1.4
|
||||
# via jsonschema
|
||||
# via jupyter-events
|
||||
rfc3986-validator==0.1.1
|
||||
# via jsonschema
|
||||
# via jupyter-events
|
||||
rich==13.3.1
|
||||
# via typer
|
||||
rpds-py==0.20.0
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
safetensors==0.4.5
|
||||
# via symbolic-nn-tests
|
||||
scikit-learn==1.5.1
|
||||
# via scikit-optimize
|
||||
scikit-optimize==0.10.2
|
||||
# via symbolic-nn-tests
|
||||
scipy==1.14.1
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
send2trash==1.8.3
|
||||
# via jupyter-server
|
||||
sentry-sdk==2.14.0
|
||||
# via wandb
|
||||
setproctitle==1.3.3
|
||||
# via wandb
|
||||
setuptools==74.1.2
|
||||
# via jupyterlab
|
||||
# via lightning-utilities
|
||||
# via symbolic-nn-tests
|
||||
# via tensorboard
|
||||
# via torch
|
||||
# via wandb
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via asttokens
|
||||
# via bleach
|
||||
# via blessed
|
||||
# via docker-pycreds
|
||||
# via kaggle
|
||||
# via python-dateutil
|
||||
# via rfc3339-validator
|
||||
# via snoop
|
||||
# via tensorboard
|
||||
sixelcrop==0.1.7
|
||||
# via euporie
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
# via httpx
|
||||
snoop==0.4.3
|
||||
# via symbolic-nn-tests
|
||||
soupsieve==2.6
|
||||
# via beautifulsoup4
|
||||
sqlalchemy==2.0.34
|
||||
# via alembic
|
||||
# via optuna
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
sympy==1.13.2
|
||||
# via torch
|
||||
tenacity==9.0.0
|
||||
# via plotly
|
||||
tensorboard==2.17.1
|
||||
# via symbolic-nn-tests
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
terminado==0.18.1
|
||||
# via jupyter-server
|
||||
# via jupyter-server-terminals
|
||||
text-unidecode==1.3
|
||||
# via python-slugify
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
timg==1.1.6
|
||||
# via euporie
|
||||
tinycss2==1.3.0
|
||||
# via nbconvert
|
||||
torch==2.4.1
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via symbolic-nn-tests
|
||||
# via torchmetrics
|
||||
# via torchvision
|
||||
torchmetrics==1.4.1
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
torchvision==0.19.1
|
||||
# via symbolic-nn-tests
|
||||
tornado==6.4.1
|
||||
# via bokeh
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
# via terminado
|
||||
tqdm==4.66.5
|
||||
# via gdown
|
||||
# via kaggle
|
||||
# via lightning
|
||||
# via optuna
|
||||
# via panel
|
||||
# via pytorch-lightning
|
||||
traitlets==5.14.3
|
||||
# via comm
|
||||
# via ipykernel
|
||||
# via ipython
|
||||
# via ipywidgets
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-core
|
||||
# via jupyter-events
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via matplotlib-inline
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
# via nbformat
|
||||
triton==3.0.0
|
||||
# via torch
|
||||
typer==0.12.5
|
||||
# via symbolic-nn-tests
|
||||
types-python-dateutil==2.9.0.20240906
|
||||
# via arrow
|
||||
typing-extensions==4.12.2
|
||||
# via alembic
|
||||
# via euporie
|
||||
# via lightning
|
||||
# via lightning-utilities
|
||||
# via panel
|
||||
# via pytorch-lightning
|
||||
# via sqlalchemy
|
||||
# via torch
|
||||
# via typer
|
||||
# via typing-inspect
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
uc-micro-py==1.0.3
|
||||
# via linkify-it-py
|
||||
universal-pathlib==0.2.5
|
||||
# via euporie
|
||||
uri-template==1.3.0
|
||||
# via jsonschema
|
||||
urllib3==2.2.2
|
||||
# via kaggle
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
wandb==0.17.9
|
||||
# via symbolic-nn-tests
|
||||
wcwidth==0.2.13
|
||||
# via blessed
|
||||
# via prompt-toolkit
|
||||
webcolors==24.8.0
|
||||
# via jsonschema
|
||||
webencodings==0.5.1
|
||||
# via bleach
|
||||
# via tinycss2
|
||||
websocket-client==1.8.0
|
||||
# via jupyter-server
|
||||
werkzeug==3.0.4
|
||||
# via tensorboard
|
||||
widgetsnbextension==4.0.13
|
||||
# via ipywidgets
|
||||
xyzservices==2024.9.0
|
||||
# via bokeh
|
||||
# via panel
|
||||
yarl==1.11.1
|
||||
# via aiohttp
|
||||
675
requirements.lock
Normal file
675
requirements.lock
Normal file
@@ -0,0 +1,675 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
|
||||
-e file:.
|
||||
about-time==4.2.1
|
||||
# via alive-progress
|
||||
absl-py==2.1.0
|
||||
# via tensorboard
|
||||
aenum==3.1.12
|
||||
# via euporie
|
||||
aiohappyeyeballs==2.4.0
|
||||
# via aiohttp
|
||||
aiohttp==3.10.5
|
||||
# via fsspec
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
alembic==1.13.2
|
||||
# via optuna
|
||||
alive-progress==3.1.5
|
||||
# via symbolic-nn-tests
|
||||
anyio==4.4.0
|
||||
# via httpx
|
||||
# via jupyter-server
|
||||
argon2-cffi==23.1.0
|
||||
# via jupyter-server
|
||||
argon2-cffi-bindings==21.2.0
|
||||
# via argon2-cffi
|
||||
arrow==1.3.0
|
||||
# via isoduration
|
||||
asttokens==2.4.1
|
||||
# via snoop
|
||||
# via stack-data
|
||||
async-lru==2.0.4
|
||||
# via jupyterlab
|
||||
attrs==24.2.0
|
||||
# via aiohttp
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
babel==2.16.0
|
||||
# via jupyterlab-server
|
||||
beautifulsoup4==4.12.3
|
||||
# via gdown
|
||||
# via nbconvert
|
||||
bleach==6.1.0
|
||||
# via kaggle
|
||||
# via nbconvert
|
||||
# via panel
|
||||
blessed==1.20.0
|
||||
# via curtsies
|
||||
bokeh==3.4.3
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
bpython==0.24
|
||||
# via symbolic-nn-tests
|
||||
certifi==2024.8.30
|
||||
# via httpcore
|
||||
# via httpx
|
||||
# via kaggle
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
cffi==1.17.1
|
||||
# via argon2-cffi-bindings
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
cheap-repr==0.5.2
|
||||
# via snoop
|
||||
click==8.1.7
|
||||
# via typer
|
||||
# via wandb
|
||||
colorcet==3.1.0
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
colorlog==6.8.2
|
||||
# via optuna
|
||||
comm==0.2.2
|
||||
# via ipykernel
|
||||
# via ipywidgets
|
||||
contourpy==1.3.0
|
||||
# via bokeh
|
||||
# via matplotlib
|
||||
curtsies==0.4.2
|
||||
# via bpython
|
||||
cwcwidth==0.1.9
|
||||
# via bpython
|
||||
# via curtsies
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dataclasses-json==0.6.7
|
||||
# via periodic-table-dataclasses
|
||||
debugpy==1.8.5
|
||||
# via ipykernel
|
||||
decorator==5.1.1
|
||||
# via ipython
|
||||
defusedxml==0.7.1
|
||||
# via nbconvert
|
||||
docker-pycreds==0.4.0
|
||||
# via wandb
|
||||
euporie==2.8.2
|
||||
# via symbolic-nn-tests
|
||||
executing==2.1.0
|
||||
# via snoop
|
||||
# via stack-data
|
||||
fastjsonschema==2.20.0
|
||||
# via euporie
|
||||
# via nbformat
|
||||
filelock==3.16.0
|
||||
# via gdown
|
||||
# via torch
|
||||
# via triton
|
||||
flatlatex==0.15
|
||||
# via euporie
|
||||
fonttools==4.53.1
|
||||
# via matplotlib
|
||||
fqdn==1.5.1
|
||||
# via jsonschema
|
||||
frozenlist==1.4.1
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
fsspec==2024.9.0
|
||||
# via euporie
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via torch
|
||||
# via universal-pathlib
|
||||
gdown==5.2.0
|
||||
# via symbolic-nn-tests
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.43
|
||||
# via wandb
|
||||
grapheme==0.6.0
|
||||
# via alive-progress
|
||||
greenlet==3.1.0
|
||||
# via bpython
|
||||
# via sqlalchemy
|
||||
grpcio==1.66.1
|
||||
# via tensorboard
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
holoviews==1.19.1
|
||||
# via hvplot
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httpx==0.27.2
|
||||
# via jupyterlab
|
||||
hvplot==0.10.0
|
||||
# via symbolic-nn-tests
|
||||
idna==3.8
|
||||
# via anyio
|
||||
# via httpx
|
||||
# via jsonschema
|
||||
# via requests
|
||||
# via yarl
|
||||
imagesize==1.4.1
|
||||
# via euporie
|
||||
ipykernel==6.29.5
|
||||
# via jupyter
|
||||
# via jupyter-console
|
||||
# via jupyterlab
|
||||
# via symbolic-nn-tests
|
||||
ipython==8.27.0
|
||||
# via ipykernel
|
||||
# via ipywidgets
|
||||
# via jupyter-console
|
||||
# via symbolic-nn-tests
|
||||
ipywidgets==8.1.5
|
||||
# via jupyter
|
||||
isoduration==20.11.0
|
||||
# via jsonschema
|
||||
jedi==0.19.1
|
||||
# via ipython
|
||||
jinja2==3.1.4
|
||||
# via bokeh
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via nbconvert
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
json5==0.9.25
|
||||
# via jupyterlab-server
|
||||
jsonpointer==3.0.0
|
||||
# via jsonschema
|
||||
jsonschema==4.23.0
|
||||
# via jupyter-events
|
||||
# via jupyterlab-server
|
||||
# via nbformat
|
||||
jsonschema-specifications==2023.12.1
|
||||
# via jsonschema
|
||||
jupyter==1.1.1
|
||||
# via symbolic-nn-tests
|
||||
jupyter-client==8.6.2
|
||||
# via euporie
|
||||
# via ipykernel
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
# via nbclient
|
||||
jupyter-console==6.6.3
|
||||
# via jupyter
|
||||
jupyter-core==5.7.2
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
# via nbformat
|
||||
jupyter-events==0.10.0
|
||||
# via jupyter-server
|
||||
jupyter-lsp==2.2.5
|
||||
# via jupyterlab
|
||||
jupyter-server==2.14.2
|
||||
# via jupyter-lsp
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via notebook
|
||||
# via notebook-shim
|
||||
jupyter-server-terminals==0.5.3
|
||||
# via jupyter-server
|
||||
jupyterlab==4.2.5
|
||||
# via jupyter
|
||||
# via notebook
|
||||
jupyterlab-pygments==0.3.0
|
||||
# via nbconvert
|
||||
jupyterlab-server==2.27.3
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
jupyterlab-widgets==3.0.13
|
||||
# via ipywidgets
|
||||
jupytext==1.16.4
|
||||
# via euporie
|
||||
kaggle==1.6.17
|
||||
# via symbolic-nn-tests
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
lightning==2.4.0
|
||||
# via symbolic-nn-tests
|
||||
lightning-utilities==0.11.7
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via torchmetrics
|
||||
linkify-it-py==1.0.3
|
||||
# via euporie
|
||||
# via panel
|
||||
loguru==0.7.2
|
||||
# via symbolic-nn-tests
|
||||
mako==1.3.5
|
||||
# via alembic
|
||||
markdown==3.7
|
||||
# via panel
|
||||
# via tensorboard
|
||||
markdown-it-py==2.1.0
|
||||
# via euporie
|
||||
# via jupytext
|
||||
# via mdit-py-plugins
|
||||
# via panel
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
# via mako
|
||||
# via nbconvert
|
||||
# via werkzeug
|
||||
marshmallow==3.22.0
|
||||
# via dataclasses-json
|
||||
matplotlib==3.9.2
|
||||
# via matplotlib-backend-kitty
|
||||
matplotlib-backend-kitty==2.1.2
|
||||
# via symbolic-nn-tests
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipykernel
|
||||
# via ipython
|
||||
mdit-py-plugins==0.3.5
|
||||
# via euporie
|
||||
# via jupytext
|
||||
# via panel
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==3.0.2
|
||||
# via nbconvert
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.1.0
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
nbclient==0.10.0
|
||||
# via nbconvert
|
||||
nbconvert==7.16.4
|
||||
# via jupyter
|
||||
# via jupyter-server
|
||||
nbformat==5.10.4
|
||||
# via euporie
|
||||
# via jupyter-server
|
||||
# via jupytext
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
networkx==3.3
|
||||
# via torch
|
||||
notebook==7.2.2
|
||||
# via jupyter
|
||||
notebook-shim==0.2.4
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
numpy==2.1.1
|
||||
# via bokeh
|
||||
# via contourpy
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via matplotlib
|
||||
# via optuna
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
# via scipy
|
||||
# via tensorboard
|
||||
# via torchmetrics
|
||||
# via torchvision
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.68
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
optuna==4.0.0
|
||||
# via symbolic-nn-tests
|
||||
overrides==7.7.0
|
||||
# via jupyter-server
|
||||
packaging==24.1
|
||||
# via bokeh
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via ipykernel
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via jupyterlab-server
|
||||
# via jupytext
|
||||
# via lightning
|
||||
# via lightning-utilities
|
||||
# via marshmallow
|
||||
# via matplotlib
|
||||
# via nbconvert
|
||||
# via optuna
|
||||
# via plotly
|
||||
# via pytorch-lightning
|
||||
# via scikit-optimize
|
||||
# via tensorboard
|
||||
# via torchmetrics
|
||||
pandas==2.2.2
|
||||
# via bokeh
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
pandocfilters==1.5.1
|
||||
# via nbconvert
|
||||
panel==1.4.5
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
param==2.1.1
|
||||
# via holoviews
|
||||
# via hvplot
|
||||
# via panel
|
||||
# via pyviz-comms
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
periodic-table-dataclasses==1.0
|
||||
# via symbolic-nn-tests
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pillow==10.4.0
|
||||
# via bokeh
|
||||
# via euporie
|
||||
# via matplotlib
|
||||
# via timg
|
||||
# via torchvision
|
||||
platformdirs==3.11.0
|
||||
# via euporie
|
||||
# via jupyter-core
|
||||
# via wandb
|
||||
plotly==5.24.0
|
||||
# via symbolic-nn-tests
|
||||
polars==1.6.0
|
||||
# via symbolic-nn-tests
|
||||
prometheus-client==0.20.0
|
||||
# via jupyter-server
|
||||
prompt-toolkit==3.0.47
|
||||
# via euporie
|
||||
# via ipython
|
||||
# via jupyter-console
|
||||
protobuf==5.28.0
|
||||
# via tensorboard
|
||||
# via wandb
|
||||
psutil==6.0.0
|
||||
# via ipykernel
|
||||
# via wandb
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
# via terminado
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyaml==24.7.0
|
||||
# via scikit-optimize
|
||||
pyarrow==17.0.0
|
||||
# via symbolic-nn-tests
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pygments==2.18.0
|
||||
# via bpython
|
||||
# via euporie
|
||||
# via ipython
|
||||
# via jupyter-console
|
||||
# via nbconvert
|
||||
# via rich
|
||||
# via snoop
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
pyperclip==1.9.0
|
||||
# via euporie
|
||||
pysocks==1.7.1
|
||||
# via requests
|
||||
python-dateutil==2.9.0.post0
|
||||
# via arrow
|
||||
# via jupyter-client
|
||||
# via kaggle
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
python-json-logger==2.0.7
|
||||
# via jupyter-events
|
||||
python-slugify==8.0.4
|
||||
# via kaggle
|
||||
pytorch-lightning==2.4.0
|
||||
# via lightning
|
||||
pytz==2024.2
|
||||
# via pandas
|
||||
pyviz-comms==3.0.3
|
||||
# via holoviews
|
||||
# via panel
|
||||
pyxdg==0.28
|
||||
# via bpython
|
||||
pyyaml==6.0.2
|
||||
# via bokeh
|
||||
# via jupyter-events
|
||||
# via jupytext
|
||||
# via lightning
|
||||
# via optuna
|
||||
# via pyaml
|
||||
# via pytorch-lightning
|
||||
# via wandb
|
||||
pyzmq==26.2.0
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-server
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
# via jupyter-events
|
||||
regex==2024.7.24
|
||||
# via flatlatex
|
||||
requests==2.32.3
|
||||
# via bpython
|
||||
# via gdown
|
||||
# via jupyterlab-server
|
||||
# via kaggle
|
||||
# via panel
|
||||
# via wandb
|
||||
rfc3339-validator==0.1.4
|
||||
# via jsonschema
|
||||
# via jupyter-events
|
||||
rfc3986-validator==0.1.1
|
||||
# via jsonschema
|
||||
# via jupyter-events
|
||||
rich==13.3.1
|
||||
# via typer
|
||||
rpds-py==0.20.0
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
safetensors==0.4.5
|
||||
# via symbolic-nn-tests
|
||||
scikit-learn==1.5.1
|
||||
# via scikit-optimize
|
||||
scikit-optimize==0.10.2
|
||||
# via symbolic-nn-tests
|
||||
scipy==1.14.1
|
||||
# via scikit-learn
|
||||
# via scikit-optimize
|
||||
send2trash==1.8.3
|
||||
# via jupyter-server
|
||||
sentry-sdk==2.14.0
|
||||
# via wandb
|
||||
setproctitle==1.3.3
|
||||
# via wandb
|
||||
setuptools==74.1.2
|
||||
# via jupyterlab
|
||||
# via lightning-utilities
|
||||
# via symbolic-nn-tests
|
||||
# via tensorboard
|
||||
# via torch
|
||||
# via wandb
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via asttokens
|
||||
# via bleach
|
||||
# via blessed
|
||||
# via docker-pycreds
|
||||
# via kaggle
|
||||
# via python-dateutil
|
||||
# via rfc3339-validator
|
||||
# via snoop
|
||||
# via tensorboard
|
||||
sixelcrop==0.1.7
|
||||
# via euporie
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
# via httpx
|
||||
snoop==0.4.3
|
||||
# via symbolic-nn-tests
|
||||
soupsieve==2.6
|
||||
# via beautifulsoup4
|
||||
sqlalchemy==2.0.34
|
||||
# via alembic
|
||||
# via optuna
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
sympy==1.13.2
|
||||
# via torch
|
||||
tenacity==9.0.0
|
||||
# via plotly
|
||||
tensorboard==2.17.1
|
||||
# via symbolic-nn-tests
|
||||
tensorboard-data-server==0.7.2
|
||||
# via tensorboard
|
||||
terminado==0.18.1
|
||||
# via jupyter-server
|
||||
# via jupyter-server-terminals
|
||||
text-unidecode==1.3
|
||||
# via python-slugify
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
timg==1.1.6
|
||||
# via euporie
|
||||
tinycss2==1.3.0
|
||||
# via nbconvert
|
||||
torch==2.4.1
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
# via symbolic-nn-tests
|
||||
# via torchmetrics
|
||||
# via torchvision
|
||||
torchmetrics==1.4.1
|
||||
# via lightning
|
||||
# via pytorch-lightning
|
||||
torchvision==0.19.1
|
||||
# via symbolic-nn-tests
|
||||
tornado==6.4.1
|
||||
# via bokeh
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via notebook
|
||||
# via terminado
|
||||
tqdm==4.66.5
|
||||
# via gdown
|
||||
# via kaggle
|
||||
# via lightning
|
||||
# via optuna
|
||||
# via panel
|
||||
# via pytorch-lightning
|
||||
traitlets==5.14.3
|
||||
# via comm
|
||||
# via ipykernel
|
||||
# via ipython
|
||||
# via ipywidgets
|
||||
# via jupyter-client
|
||||
# via jupyter-console
|
||||
# via jupyter-core
|
||||
# via jupyter-events
|
||||
# via jupyter-server
|
||||
# via jupyterlab
|
||||
# via matplotlib-inline
|
||||
# via nbclient
|
||||
# via nbconvert
|
||||
# via nbformat
|
||||
triton==3.0.0
|
||||
# via torch
|
||||
typer==0.12.5
|
||||
# via symbolic-nn-tests
|
||||
types-python-dateutil==2.9.0.20240906
|
||||
# via arrow
|
||||
typing-extensions==4.12.2
|
||||
# via alembic
|
||||
# via euporie
|
||||
# via lightning
|
||||
# via lightning-utilities
|
||||
# via panel
|
||||
# via pytorch-lightning
|
||||
# via sqlalchemy
|
||||
# via torch
|
||||
# via typer
|
||||
# via typing-inspect
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
uc-micro-py==1.0.3
|
||||
# via linkify-it-py
|
||||
universal-pathlib==0.2.5
|
||||
# via euporie
|
||||
uri-template==1.3.0
|
||||
# via jsonschema
|
||||
urllib3==2.2.2
|
||||
# via kaggle
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
wandb==0.17.9
|
||||
# via symbolic-nn-tests
|
||||
wcwidth==0.2.13
|
||||
# via blessed
|
||||
# via prompt-toolkit
|
||||
webcolors==24.8.0
|
||||
# via jsonschema
|
||||
webencodings==0.5.1
|
||||
# via bleach
|
||||
# via tinycss2
|
||||
websocket-client==1.8.0
|
||||
# via jupyter-server
|
||||
werkzeug==3.0.4
|
||||
# via tensorboard
|
||||
widgetsnbextension==4.0.13
|
||||
# via ipywidgets
|
||||
xyzservices==2024.9.0
|
||||
# via bokeh
|
||||
# via panel
|
||||
yarl==1.11.1
|
||||
# via aiohttp
|
||||
11
src/symbolic_nn_tests/__init__.py
Normal file
11
src/symbolic_nn_tests/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from . import __main__
|
||||
|
||||
import typer
|
||||
import ssl
|
||||
|
||||
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
|
||||
|
||||
def main():
|
||||
typer.run(__main__.main)
|
||||
@@ -4,7 +4,7 @@ from torchvision.transforms import ToTensor
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
DATASET_DIR = PROJECT_ROOT / "datasets/"
|
||||
|
||||
|
||||
93
src/symbolic_nn_tests/local/__init__.py
Normal file
93
src/symbolic_nn_tests/local/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
||||
LEARNING_RATE = 10e-5
|
||||
|
||||
|
||||
def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True):
|
||||
from .model import main as test_model
|
||||
|
||||
logger = []
|
||||
|
||||
if tensorboard:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
tb_logger = TensorBoardLogger(
|
||||
save_dir=".",
|
||||
name="logs/comparison",
|
||||
version=version,
|
||||
)
|
||||
logger.append(tb_logger)
|
||||
|
||||
if wandb:
|
||||
import wandb as _wandb
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
wandb_logger = WandbLogger(
|
||||
project="Symbolic_NN_Tests",
|
||||
name=version,
|
||||
dir="wandb",
|
||||
)
|
||||
logger.append(wandb_logger)
|
||||
|
||||
test_model(
|
||||
logger=logger,
|
||||
train_loss=train_loss,
|
||||
val_loss=val_loss,
|
||||
test_loss=test_loss,
|
||||
lr=LEARNING_RATE,
|
||||
)
|
||||
|
||||
if wandb:
|
||||
_wandb.finish()
|
||||
|
||||
|
||||
def run(tensorboard: bool = True, wandb: bool = True):
|
||||
from . import semantic_loss
|
||||
from .model import oh_vs_cat_cross_entropy
|
||||
|
||||
test(
|
||||
train_loss=oh_vs_cat_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
train_loss=semantic_loss.similarity_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="similarity_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
train_loss=semantic_loss.hasline_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="hasline_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
train_loss=semantic_loss.hasloop_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="hasloop_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
train_loss=semantic_loss.multisemantic_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="multisemantic_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
test(
|
||||
train_loss=semantic_loss.garbage_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
version="garbage_cross_entropy",
|
||||
tensorboard=tensorboard,
|
||||
wandb=wandb,
|
||||
)
|
||||
72
src/symbolic_nn_tests/local/model.py
Normal file
72
src/symbolic_nn_tests/local/model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Flatten(1, -1),
|
||||
nn.Linear(784, 10),
|
||||
nn.Softmax(dim=1),
|
||||
)
|
||||
|
||||
|
||||
def collate(batch):
|
||||
x, y = zip(*batch)
|
||||
x = [i[0] for i in x]
|
||||
y = [torch.tensor(i) for i in y]
|
||||
x = torch.stack(x)
|
||||
y = torch.tensor(y)
|
||||
return x, y
|
||||
|
||||
|
||||
# This is just a quick, lazy way to ensure all models are trained on the same dataset
|
||||
@lru_cache(maxsize=1)
|
||||
def get_singleton_dataset():
|
||||
from torchvision.datasets import QMNIST
|
||||
|
||||
from symbolic_nn_tests.dataloader import create_dataset
|
||||
|
||||
return create_dataset(
|
||||
dataset=QMNIST,
|
||||
collate_fn=collate,
|
||||
batch_size=128,
|
||||
shuffle_train=True,
|
||||
num_workers=11,
|
||||
)
|
||||
|
||||
|
||||
def oh_vs_cat_cross_entropy(y_bin, y_cat):
|
||||
return nn.functional.cross_entropy(
|
||||
y_bin,
|
||||
nn.functional.one_hot(y_cat, num_classes=10).float(),
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
train_loss=oh_vs_cat_cross_entropy,
|
||||
val_loss=oh_vs_cat_cross_entropy,
|
||||
test_loss=oh_vs_cat_cross_entropy,
|
||||
logger=None,
|
||||
**kwargs,
|
||||
):
|
||||
import lightning as L
|
||||
|
||||
from symbolic_nn_tests.train import TrainingWrapper
|
||||
|
||||
if logger is None:
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
logger = TensorBoardLogger(save_dir=".", name="logs/ffnn")
|
||||
|
||||
train, val, test = get_singleton_dataset()
|
||||
lmodel = TrainingWrapper(
|
||||
model, train_loss=train_loss, val_loss=val_loss, test_loss=test_loss
|
||||
)
|
||||
lmodel.configure_optimizers(**kwargs)
|
||||
trainer = L.Trainer(max_epochs=20, logger=logger)
|
||||
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||
trainer.test(dataloaders=test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
88
src/symbolic_nn_tests/local/semantic_loss.py
Normal file
88
src/symbolic_nn_tests/local/semantic_loss.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_semantic_cross_entropy(semantic_matrix):
|
||||
def semantic_cross_entropy(input, target):
|
||||
ce_loss = torch.nn.functional.cross_entropy(input, target)
|
||||
|
||||
penalty_tensor = semantic_matrix[target.argmax(dim=1)]
|
||||
abs_diff = (target - input).abs()
|
||||
semantic_penalty = (abs_diff * penalty_tensor).sum()
|
||||
return ce_loss * semantic_penalty
|
||||
|
||||
def oh_vs_cat_semantic_cross_entropy(input_oh, target_cat):
|
||||
return semantic_cross_entropy(
|
||||
input_oh, torch.nn.functional.one_hot(target_cat, num_classes=10).float()
|
||||
)
|
||||
|
||||
return oh_vs_cat_semantic_cross_entropy
|
||||
|
||||
|
||||
# NOTE: This similarity matrix defines loss scaling factors for misclassification
|
||||
# of numbers from our QMNIST dataset. Visually similar numbers (e.g: 3/8) are
|
||||
# penalised less harshly than visually distinct numbers as this mistake is "less
|
||||
# mistaken" given our understanding of the visual characteristics of numerals.
|
||||
# By using this scaling matric we can inject human knowledge into the model via
|
||||
# the loss function, making this an example of a "semantic loss function"
|
||||
SIMILARITY_MATRIX = torch.tensor(
|
||||
[
|
||||
[2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0],
|
||||
[1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.5, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 2.0, 1.0, 1.0, 1.0],
|
||||
[1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0],
|
||||
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
|
||||
]
|
||||
).to("cuda")
|
||||
SIMILARITY_MATRIX /= SIMILARITY_MATRIX.sum() # Normalized to sum of 1
|
||||
|
||||
similarity_cross_entropy = create_semantic_cross_entropy(SIMILARITY_MATRIX)
|
||||
|
||||
|
||||
# NOTE: The following matrix encodes a simpler semantic penalty for correctly/incorrectly
|
||||
# identifying shapes with straight lines in their representation. This can be a bit fuzzy
|
||||
# in cases like "9" though.
|
||||
HASLINE_MATRIX = torch.tensor(
|
||||
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||
[False, True, False, False, True, True, False, True, False, True]
|
||||
).to("cuda")
|
||||
HASLINE_MATRIX = torch.stack([i ^ HASLINE_MATRIX for i in HASLINE_MATRIX]).type(
|
||||
torch.float64
|
||||
)
|
||||
HASLINE_MATRIX += 1
|
||||
HASLINE_MATRIX /= HASLINE_MATRIX.sum() # Normalize to sum of 1
|
||||
|
||||
hasline_cross_entropy = create_semantic_cross_entropy(HASLINE_MATRIX)
|
||||
|
||||
|
||||
# NOTE: Similarly, we can do the same for closed circular loops in a numeric character
|
||||
HASLOOP_MATRIX = torch.tensor(
|
||||
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||
[True, False, False, False, False, False, True, False, True, True]
|
||||
).to("cuda")
|
||||
HASLOOP_MATRIX = torch.stack([i ^ HASLOOP_MATRIX for i in HASLOOP_MATRIX]).type(
|
||||
torch.float64
|
||||
)
|
||||
HASLOOP_MATRIX += 1
|
||||
HASLOOP_MATRIX /= HASLOOP_MATRIX.sum() # Normalize to sum of 1
|
||||
|
||||
hasloop_cross_entropy = create_semantic_cross_entropy(HASLOOP_MATRIX)
|
||||
|
||||
|
||||
# NOTE: We can also combine all of these semantic matrices
|
||||
MULTISEMANTIC_MATRIX = SIMILARITY_MATRIX * HASLINE_MATRIX * HASLOOP_MATRIX
|
||||
MULTISEMANTIC_MATRIX /= MULTISEMANTIC_MATRIX.sum()
|
||||
|
||||
multisemantic_cross_entropy = create_semantic_cross_entropy(MULTISEMANTIC_MATRIX)
|
||||
|
||||
# NOTE: As a final test, lets make something similar to tehse but where there's no knowledge,
|
||||
# just random data. This will create a benchmark for the effects of this process wothout the
|
||||
# "knowledge" component
|
||||
GARBAGE_MATRIX = torch.rand(10, 10).to("cuda")
|
||||
GARBAGE_MATRIX /= GARBAGE_MATRIX.sum()
|
||||
|
||||
garbage_cross_entropy = create_semantic_cross_entropy(GARBAGE_MATRIX)
|
||||
Reference in New Issue
Block a user