mirror of
https://github.com/Cian-H/symbolic_nn_tests.git
synced 2025-12-22 14:11:59 +00:00
First working attempt at semantic loss for expt2
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -164,6 +164,7 @@ cython_debug/
|
|||||||
.jj
|
.jj
|
||||||
scratchpad.ipynb
|
scratchpad.ipynb
|
||||||
datasets/
|
datasets/
|
||||||
|
explore/
|
||||||
lightning_logs/
|
lightning_logs/
|
||||||
logs/
|
logs/
|
||||||
wandb/
|
wandb/
|
||||||
|
|||||||
1
.mailmap
Normal file
1
.mailmap
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Cian Hughes <cian.hughes@dcu.ie> <chughes000@gmail.com>
|
||||||
345
poetry.lock
generated
345
poetry.lock
generated
@@ -412,6 +412,28 @@ jinxed = {version = ">=1.1.0", markers = "platform_system == \"Windows\""}
|
|||||||
six = ">=1.9.0"
|
six = ">=1.9.0"
|
||||||
wcwidth = ">=0.1.4"
|
wcwidth = ">=0.1.4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bokeh"
|
||||||
|
version = "3.4.1"
|
||||||
|
description = "Interactive plots and applications in the browser from Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "bokeh-3.4.1-py3-none-any.whl", hash = "sha256:1e3c502a0a8205338fc74dadbfa321f8a0965441b39501e36796a47b4017b642"},
|
||||||
|
{file = "bokeh-3.4.1.tar.gz", hash = "sha256:d824961e4265367b0750ce58b07e564ad0b83ca64b335521cd3421e9b9f10d89"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
contourpy = ">=1.2"
|
||||||
|
Jinja2 = ">=2.9"
|
||||||
|
numpy = ">=1.16"
|
||||||
|
packaging = ">=16.8"
|
||||||
|
pandas = ">=1.2"
|
||||||
|
pillow = ">=7.1.0"
|
||||||
|
PyYAML = ">=3.10"
|
||||||
|
tornado = ">=6.2"
|
||||||
|
xyzservices = ">=2021.09.1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bpython"
|
name = "bpython"
|
||||||
version = "0.24"
|
version = "0.24"
|
||||||
@@ -636,6 +658,25 @@ files = [
|
|||||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "colorcet"
|
||||||
|
version = "3.1.0"
|
||||||
|
description = "Collection of perceptually uniform colormaps"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "colorcet-3.1.0-py3-none-any.whl", hash = "sha256:2a7d59cc8d0f7938eeedd08aad3152b5319b4ba3bcb7a612398cc17a384cb296"},
|
||||||
|
{file = "colorcet-3.1.0.tar.gz", hash = "sha256:2921b3cd81a2288aaf2d63dbc0ce3c26dcd882e8c389cc505d6886bf7aa9a4eb"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["colorcet[doc]", "colorcet[examples]", "colorcet[tests-extra]", "colorcet[tests]"]
|
||||||
|
doc = ["colorcet[examples]", "nbsite (>=0.8.4)", "sphinx-copybutton"]
|
||||||
|
examples = ["bokeh", "holoviews", "matplotlib", "numpy"]
|
||||||
|
tests = ["packaging", "pre-commit", "pytest (>=2.8.5)", "pytest-cov"]
|
||||||
|
tests-examples = ["colorcet[examples]", "nbval"]
|
||||||
|
tests-extra = ["colorcet[tests]", "pytest-mpl"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorlog"
|
name = "colorlog"
|
||||||
version = "6.8.2"
|
version = "6.8.2"
|
||||||
@@ -1392,6 +1433,43 @@ files = [
|
|||||||
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
|
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "holoviews"
|
||||||
|
version = "1.18.3"
|
||||||
|
description = "Stop plotting your data - annotate your data and let it visualize itself."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "holoviews-1.18.3-py2.py3-none-any.whl", hash = "sha256:b94b96560b64a84c07e89115aaf9b226e6009684800ec84d3c88cbad122c0c46"},
|
||||||
|
{file = "holoviews-1.18.3.tar.gz", hash = "sha256:578e30e89d72754f97a83ebe08198fec8e87cc7e49b25b9f31ec393f939ca500"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
colorcet = "*"
|
||||||
|
numpy = ">=1.0"
|
||||||
|
packaging = "*"
|
||||||
|
pandas = ">=0.20.0"
|
||||||
|
panel = ">=1.0"
|
||||||
|
param = ">=1.12.0,<3.0"
|
||||||
|
pyviz-comms = ">=0.7.4"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["bokeh (>=3.1)", "cftime", "codecov", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbconvert", "nbsite (>=0.8.4,<0.9.0)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "playwright", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-github-actions-annotate-failures", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
|
||||||
|
build = ["param (>=1.7.0)", "pyct (>=0.4.4)", "setuptools (>=30.3.0)"]
|
||||||
|
doc = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbsite (>=0.8.4,<0.9.0)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "selenium", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
|
||||||
|
examples = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
|
||||||
|
examples-tests = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
|
||||||
|
lint = ["pre-commit", "ruff"]
|
||||||
|
notebook = ["ipython (>=5.4.0)", "notebook"]
|
||||||
|
recommended = ["bokeh (>=3.1)", "ipython (>=5.4.0)", "matplotlib (>=3)", "notebook"]
|
||||||
|
tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"]
|
||||||
|
tests-ci = ["codecov", "pytest-github-actions-annotate-failures"]
|
||||||
|
tests-core = ["bokeh (>=3.1)", "contourpy", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist"]
|
||||||
|
tests-gpu = ["bokeh (>=3.1)", "cftime", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"]
|
||||||
|
tests-nb = ["nbval"]
|
||||||
|
ui = ["playwright", "pytest-playwright"]
|
||||||
|
unit-tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "httpcore"
|
name = "httpcore"
|
||||||
version = "1.0.5"
|
version = "1.0.5"
|
||||||
@@ -1437,6 +1515,41 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
|||||||
http2 = ["h2 (>=3,<5)"]
|
http2 = ["h2 (>=3,<5)"]
|
||||||
socks = ["socksio (==1.*)"]
|
socks = ["socksio (==1.*)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hvplot"
|
||||||
|
version = "0.10.0"
|
||||||
|
description = "A high-level plotting API for the PyData ecosystem built on HoloViews."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "hvplot-0.10.0-py3-none-any.whl", hash = "sha256:fe90ccb48163a6a62ae5bd6b008c2cb15cbf5b276f6ad6839ef5470b1c480d16"},
|
||||||
|
{file = "hvplot-0.10.0.tar.gz", hash = "sha256:e87486a95bfe151ab52ef163a5e93d9cbd043992cf0b755ccadd2bf36fedd376"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
bokeh = ">=1.0.0"
|
||||||
|
colorcet = ">=2"
|
||||||
|
holoviews = ">=1.11.0"
|
||||||
|
numpy = ">=1.15"
|
||||||
|
packaging = "*"
|
||||||
|
pandas = "*"
|
||||||
|
panel = ">=0.11.0"
|
||||||
|
param = ">=1.12.0,<3.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev-extras = ["setuptools-scm (>=6)"]
|
||||||
|
doc = ["hvplot[examples]", "nbsite (>=0.8.4)", "sphinxext-rediraffe"]
|
||||||
|
examples = ["dask[dataframe] (>=2021.3.0)", "datashader (>=0.6.5)", "fugue[sql]", "geodatasets (>=2023.12.0)", "hvplot[fugue-sql]", "ibis-framework[duckdb]", "intake (>=0.6.5,<2.0.0)", "intake-parquet (>=0.2.3)", "intake-xarray (>=0.5.0)", "ipywidgets", "matplotlib", "networkx (>=2.6.3)", "notebook (>=5.4)", "numba (>=0.51.0)", "pillow (>=8.2.0)", "plotly", "polars", "pooch (>=1.6.0)", "s3fs (>=2022.1.0)", "scikit-image (>=0.17.2)", "scipy (>=1.5.3)", "selenium (>=3.141.0)", "streamz (>=0.3.0)", "xarray (>=0.18.2)", "xyzservices (>=2022.9.0)"]
|
||||||
|
examples-tests = ["hvplot[examples]", "hvplot[tests-nb]"]
|
||||||
|
fugue-sql = ["fugue-sql-antlr (>=0.2.0)", "jinja2", "qpd (>=0.4.4)", "sqlglot"]
|
||||||
|
geo = ["cartopy", "fiona", "geopandas", "geoviews (>=1.9.0)", "pyproj", "rasterio", "rioxarray", "spatialpandas (>=0.4.3)"]
|
||||||
|
graphviz = ["pygraphviz"]
|
||||||
|
hvdev = ["colorcet (>=0.0.1a1)", "datashader (>=0.0.1a1)", "holoviews (>=0.0.1a1)", "panel (>=0.0.1a1)", "param (>=0.0.1a1)", "pyviz-comms (>=0.0.1a1)"]
|
||||||
|
hvdev-geo = ["geoviews (>=0.0.1a1)"]
|
||||||
|
tests = ["fugue[sql]", "hvplot[fugue-sql]", "hvplot[tests-core]", "ibis-framework[duckdb]", "polars"]
|
||||||
|
tests-core = ["dask[dataframe]", "ipywidgets", "matplotlib", "parameterized", "plotly", "pooch", "pre-commit", "pytest", "pytest-cov", "ruff", "scipy", "xarray"]
|
||||||
|
tests-nb = ["nbval", "pytest-xdist"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "3.7"
|
version = "3.7"
|
||||||
@@ -3001,6 +3114,78 @@ files = [
|
|||||||
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
|
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pandas"
|
||||||
|
version = "2.2.2"
|
||||||
|
description = "Powerful data structures for data analysis, time series, and statistics"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"},
|
||||||
|
{file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||||
|
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"},
|
||||||
|
{file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = [
|
||||||
|
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||||
|
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||||
|
]
|
||||||
|
python-dateutil = ">=2.8.2"
|
||||||
|
pytz = ">=2020.1"
|
||||||
|
tzdata = ">=2022.7"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
|
||||||
|
aws = ["s3fs (>=2022.11.0)"]
|
||||||
|
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
|
||||||
|
compression = ["zstandard (>=0.19.0)"]
|
||||||
|
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
|
||||||
|
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
|
||||||
|
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
|
||||||
|
feather = ["pyarrow (>=10.0.1)"]
|
||||||
|
fss = ["fsspec (>=2022.11.0)"]
|
||||||
|
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
|
||||||
|
hdf5 = ["tables (>=3.8.0)"]
|
||||||
|
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
|
||||||
|
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
|
||||||
|
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
|
||||||
|
parquet = ["pyarrow (>=10.0.1)"]
|
||||||
|
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
|
||||||
|
plot = ["matplotlib (>=3.6.3)"]
|
||||||
|
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
|
||||||
|
pyarrow = ["pyarrow (>=10.0.1)"]
|
||||||
|
spss = ["pyreadstat (>=1.2.0)"]
|
||||||
|
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
|
||||||
|
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
|
||||||
|
xml = ["lxml (>=4.9.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pandocfilters"
|
name = "pandocfilters"
|
||||||
version = "1.5.1"
|
version = "1.5.1"
|
||||||
@@ -3012,6 +3197,64 @@ files = [
|
|||||||
{file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"},
|
{file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "panel"
|
||||||
|
version = "1.4.4"
|
||||||
|
description = "The powerful data exploration & web app framework for Python."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "panel-1.4.4-py3-none-any.whl", hash = "sha256:b49bb9676567b0c0730bf69348c057247080811aec56364dd4fcfba80e5e09a0"},
|
||||||
|
{file = "panel-1.4.4.tar.gz", hash = "sha256:659e9fc5b495e6519c5d07e8148fa5eeed9bc648356ec83fc299381ba5a726ef"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
bleach = "*"
|
||||||
|
bokeh = ">=3.4.0,<3.5.0"
|
||||||
|
linkify-it-py = "*"
|
||||||
|
markdown = "*"
|
||||||
|
markdown-it-py = "*"
|
||||||
|
mdit-py-plugins = "*"
|
||||||
|
pandas = ">=1.2"
|
||||||
|
param = ">=2.1.0,<3.0"
|
||||||
|
pyviz-comms = ">=2.0.0"
|
||||||
|
requests = "*"
|
||||||
|
tqdm = ">=4.48.0"
|
||||||
|
typing-extensions = "*"
|
||||||
|
xyzservices = ">=2021.09.1"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"]
|
||||||
|
all-pip = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"]
|
||||||
|
build = ["bleach", "bokeh (>=3.4.0,<3.5.0)", "cryptography (<39)", "markdown", "packaging", "param (>=2.0.0)", "pyviz-comms (>=2.0.0)", "requests", "setuptools (>=42)", "tqdm (>=4.48.0)", "urllib3 (<2.0)"]
|
||||||
|
doc = ["holoviews (>=1.16.0)", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "pandas (<2.1.0)", "pillow", "plotly"]
|
||||||
|
examples = ["aiohttp", "altair", "channels", "croniter", "dask-expr", "datashader", "django (<4)", "fastparquet", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "networkx (>=2.5)", "plotly (>=4.0)", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "seaborn", "streamz", "textual", "vega-datasets", "vtk", "xarray", "xgboost"]
|
||||||
|
recommended = ["holoviews (>=1.16.0)", "jupyterlab", "matplotlib", "pillow", "plotly"]
|
||||||
|
tests = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipympl", "ipython (>=7.0)", "ipyvuetify", "ipywidgets-bokeh", "nbval", "numba (<0.58)", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "reacton", "scipy", "textual", "twine", "watchfiles"]
|
||||||
|
tests-core = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipython (>=7.0)", "nbval", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy", "textual", "watchfiles"]
|
||||||
|
ui = ["jupyter-server", "playwright", "pytest-playwright", "tomli"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "param"
|
||||||
|
version = "2.1.0"
|
||||||
|
description = "Make your Python code clearer and more reliable by declaring Parameters."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "param-2.1.0-py3-none-any.whl", hash = "sha256:f31d3745d227347d29b5868c4e4e3077df07463889b91d3bb28e634fde211e1c"},
|
||||||
|
{file = "param-2.1.0.tar.gz", hash = "sha256:a7b30b08b547e2b78b02aeba6ed34e3c6a638f8e4824a76a96ffa2d7cf57e71f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["param[doc]", "param[lint]", "param[tests-full]"]
|
||||||
|
doc = ["nbsite (==0.8.4)", "param[examples]", "sphinx-remove-toctrees"]
|
||||||
|
examples = ["aiohttp", "pandas", "panel"]
|
||||||
|
lint = ["flake8", "pre-commit"]
|
||||||
|
tests = ["coverage[toml]", "pytest", "pytest-asyncio"]
|
||||||
|
tests-deser = ["odfpy", "openpyxl", "pyarrow", "tables", "xlrd"]
|
||||||
|
tests-examples = ["nbval", "param[examples]", "pytest (<8.1)", "pytest-asyncio", "pytest-xdist"]
|
||||||
|
tests-full = ["cloudpickle", "gmpy", "ipython", "jsonschema", "nest-asyncio", "numpy", "pandas", "param[tests-deser]", "param[tests-examples]", "param[tests]"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parso"
|
name = "parso"
|
||||||
version = "0.8.4"
|
version = "0.8.4"
|
||||||
@@ -3310,6 +3553,54 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["pytest"]
|
tests = ["pytest"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyarrow"
|
||||||
|
version = "16.1.0"
|
||||||
|
description = "Python library for Apache Arrow"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
|
||||||
|
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
|
||||||
|
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
|
||||||
|
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
|
||||||
|
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
|
||||||
|
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
|
||||||
|
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = ">=1.16.6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pycparser"
|
name = "pycparser"
|
||||||
version = "2.22"
|
version = "2.22"
|
||||||
@@ -3444,6 +3735,36 @@ extra = ["bitsandbytes (==0.41.0)", "hydra-core (>=1.0.5)", "jsonargparse[signat
|
|||||||
strategies = ["deepspeed (>=0.8.2,<=0.9.3)"]
|
strategies = ["deepspeed (>=0.8.2,<=0.9.3)"]
|
||||||
test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"]
|
test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytz"
|
||||||
|
version = "2024.1"
|
||||||
|
description = "World timezone definitions, modern and historical"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
|
||||||
|
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyviz-comms"
|
||||||
|
version = "3.0.2"
|
||||||
|
description = "A JupyterLab extension for rendering HoloViz content."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "pyviz_comms-3.0.2-py3-none-any.whl", hash = "sha256:31541b976a21b7738557c3ea23bd8e44e94e736b9ed269570dcc28db4449d7e3"},
|
||||||
|
{file = "pyviz_comms-3.0.2.tar.gz", hash = "sha256:3167df932656416c4bd711205dad47e986a3ebae1f316258ddc26f9e01513ef7"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
param = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["pyviz-comms[build]", "pyviz-comms[tests]"]
|
||||||
|
build = ["jupyterlab (>=4.0,<5.0)", "keyring", "rfc3986", "setuptools (>=40.8.0)", "twine"]
|
||||||
|
tests = ["flake8", "pytest"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pywin32"
|
name = "pywin32"
|
||||||
version = "306"
|
version = "306"
|
||||||
@@ -4823,6 +5144,17 @@ files = [
|
|||||||
mypy-extensions = ">=0.3.0"
|
mypy-extensions = ">=0.3.0"
|
||||||
typing-extensions = ">=3.7.4"
|
typing-extensions = ">=3.7.4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tzdata"
|
||||||
|
version = "2024.1"
|
||||||
|
description = "Provider of IANA time zone data"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=2"
|
||||||
|
files = [
|
||||||
|
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
|
||||||
|
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "uc-micro-py"
|
name = "uc-micro-py"
|
||||||
version = "1.0.3"
|
version = "1.0.3"
|
||||||
@@ -5009,6 +5341,17 @@ files = [
|
|||||||
{file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"},
|
{file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xyzservices"
|
||||||
|
version = "2024.4.0"
|
||||||
|
description = "Source of XYZ tiles providers"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "xyzservices-2024.4.0-py3-none-any.whl", hash = "sha256:b83e48c5b776c9969fffcfff57b03d02b1b1cd6607a9d9c4e7f568b01ef47f4c"},
|
||||||
|
{file = "xyzservices-2024.4.0.tar.gz", hash = "sha256:6a04f11487a6fb77d92a98984cd107fbd9157fd5e65f929add9c3d6e604ee88c"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yarl"
|
name = "yarl"
|
||||||
version = "1.9.4"
|
version = "1.9.4"
|
||||||
@@ -5115,4 +5458,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "74e84f457ce411d0612153d658f0e43f92f140297fcc1812501b74e7757fedc3"
|
content-hash = "f607472660b04b7f6f5d49a4561730f788a46f0d1e0176322e872111b00481cd"
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ polars = "^0.20.28"
|
|||||||
jupyter = "^1.0.0"
|
jupyter = "^1.0.0"
|
||||||
safetensors = "^0.4.3"
|
safetensors = "^0.4.3"
|
||||||
alive-progress = "^3.1.5"
|
alive-progress = "^3.1.5"
|
||||||
|
hvplot = "^0.10.0"
|
||||||
|
pyarrow = "^16.1.0"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ class Model(nn.Module):
|
|||||||
x0, x1 = x
|
x0, x1 = x
|
||||||
y0 = self.encode_x0(x0)
|
y0 = self.encode_x0(x0)
|
||||||
y1 = self.encode_x1(x1)
|
y1 = self.encode_x1(x1)
|
||||||
x = torch.cat([y0, y1], dim=1)
|
y = torch.cat([y0, y1], dim=1)
|
||||||
y = self.ff(x)
|
y = self.ff(y)
|
||||||
if self.return_module_y:
|
if self.return_module_y:
|
||||||
return y, y0, y1
|
return x, (y, y0, y1)
|
||||||
else:
|
else:
|
||||||
return y
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
# This is just a quick, lazy way to ensure all models are trained on the same dataset
|
# This is just a quick, lazy way to ensure all models are trained on the same dataset
|
||||||
@@ -59,7 +59,95 @@ def get_singleton_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
|
def smooth_l1_loss(out, y):
|
||||||
|
_, y_pred = out
|
||||||
|
return nn.functional.smooth_l1_loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def sech(x):
|
||||||
|
return torch.reciprocal(torch.cosh(x))
|
||||||
|
|
||||||
|
|
||||||
|
def linear_fit(x, y):
|
||||||
|
mean_x = torch.mean(x)
|
||||||
|
mean_y = torch.mean(y)
|
||||||
|
cov_xy = torch.mean(x * y) - (mean_x * mean_y)
|
||||||
|
var_x = torch.mean(x * x) - (mean_x * mean_x)
|
||||||
|
m = cov_xy / var_x
|
||||||
|
c = mean_y - (m * mean_x)
|
||||||
|
return m, c
|
||||||
|
|
||||||
|
|
||||||
|
def line(x, m, c):
|
||||||
|
return (m * x) + c
|
||||||
|
|
||||||
|
|
||||||
|
def linear_residuals(x, y, m, c):
|
||||||
|
return y - line(x, m, c)
|
||||||
|
|
||||||
|
|
||||||
|
def semantic_loss(x, y_pred, w, a):
|
||||||
|
m, c = linear_fit(x, y_pred)
|
||||||
|
residuals = linear_residuals(x, y_pred, m, c)
|
||||||
|
scaled_residuals = residuals * sech(w * x)
|
||||||
|
slope_penalty = torch.nn.functional.softmax(a * m, dim=0)
|
||||||
|
loss = torch.mean(scaled_residuals**2) + torch.mean(slope_penalty)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def loss(out, y):
|
||||||
|
x, y_pred = out
|
||||||
|
x0, x1 = x
|
||||||
|
|
||||||
|
# Here, we want to make semantic use of the differential electronegativity of the molecule
|
||||||
|
# so start by calculating that
|
||||||
|
mean_electronegativities = torch.tensor(
|
||||||
|
[i[:, 3].mean() for i in x0], dtype=torch.float32
|
||||||
|
).to(y_pred.device)
|
||||||
|
diff_electronegativity = (
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
(i[:, 3] - mean).abs().sum()
|
||||||
|
for i, mean in zip(x0, mean_electronegativities)
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
* 4.0
|
||||||
|
).to(y_pred.device)
|
||||||
|
|
||||||
|
# Then, we need to get a linear best fit on that. Our semantic info is based on a graph of
|
||||||
|
# En (y) vs differential electronegativity on the x vs y axes, so y_pred is y here
|
||||||
|
m, c = linear_fit(diff_electronegativity, y_pred)
|
||||||
|
|
||||||
|
# To start with, we want to calculate a penalty based on deviation from a linear relationship
|
||||||
|
# Scaling is being based on 1/sech(w*r) as this increases multiplier as deviation grows.
|
||||||
|
# `w` was selected based on noting that the residual spread before eneg scaling was about 25;
|
||||||
|
# enegs were normalised as x/4, so we want to incentivize a spread of about 25/4~=6, and w=0.2
|
||||||
|
# causes the penalty function to cross 2 at just over 6. Yes, that's a bit arbitrary but we're
|
||||||
|
# just steering the model not applying hard constraints to it shold be fine.
|
||||||
|
residual_penalty = (
|
||||||
|
(
|
||||||
|
linear_residuals(diff_electronegativity, y_pred, m, c)
|
||||||
|
/ sech(0.2 * diff_electronegativity)
|
||||||
|
)
|
||||||
|
.abs()
|
||||||
|
.float()
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
|
||||||
|
# We also need to calculate a penalty that incentivizes a positive slope. For this, im using softmax
|
||||||
|
# to scale the slope as it will penalise negative slopes while not just creating a reward hack for
|
||||||
|
# maximizing slope. The softmax function approximates 1 from about 5 onwards, so if we multiply m by
|
||||||
|
# 500, then our penalty should be almost minimised for any slope above 0.01 and maximised below 0.01.
|
||||||
|
# This should suffice for incentivizing the model to favour positive slopes.
|
||||||
|
slope_penalty = (torch.nn.functional.softmax(-m * 500.0) + 1).mean()
|
||||||
|
|
||||||
|
# Finally, let's get a smooth L1 loss and scale it based on these penalty functions
|
||||||
|
return nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
|
||||||
|
|
||||||
|
|
||||||
|
# def main(loss_func=smooth_l1_loss, logger=None, **kwargs):
|
||||||
|
def main(loss_func=loss, logger=None, **kwargs):
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
|
||||||
from symbolic_nn_tests.train import TrainingWrapper
|
from symbolic_nn_tests.train import TrainingWrapper
|
||||||
@@ -72,7 +160,7 @@ def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
|
|||||||
train, val, test = get_singleton_dataset()
|
train, val, test = get_singleton_dataset()
|
||||||
lmodel = TrainingWrapper(Model(), loss_func=loss_func)
|
lmodel = TrainingWrapper(Model(), loss_func=loss_func)
|
||||||
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
|
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
|
||||||
trainer = L.Trainer(max_epochs=20, logger=logger)
|
trainer = L.Trainer(max_epochs=10, logger=logger, num_sanity_val_steps=0)
|
||||||
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
|
||||||
trainer.test(dataloaders=test)
|
trainer.test(dataloaders=test)
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: implement semantic loss functions
|
||||||
|
# These functions would enforce relationships we expect to be present in the results
|
||||||
|
# if the model is performing correctly. The first ones to implement would be:
|
||||||
|
# - En should be proportional to molecular mass
|
||||||
|
# - Bigger molecule = more dof and more strain
|
||||||
|
# - En should be inversely proportional to the differential in electronegativity
|
||||||
|
# - Higher electroneg diff = more stable molecule
|
||||||
|
# - calc diff as `torch.sum( torch.abs( electronegativity - electronegativity.mean() ) )`
|
||||||
|
# Best way to enforce this relationship would probably be to apply a multiplier based on
|
||||||
|
# a normalized sigmoid curve. This would incentivize the model to ensure slope has correct sign
|
||||||
|
# without creating a reward hack for maximizing/minimizing m and preventing exploding gradients.
|
||||||
|
# It also allows us to avoid the assumption of linearity: we only care about the direction of
|
||||||
|
# proportionality.
|
||||||
|
|||||||
20
symbolic_nn_tests/experiment2/train.py
Normal file
20
symbolic_nn_tests/experiment2/train.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from symbolic_nn_tests.train import TrainingWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticModuleTrainingWrapper(TrainingWrapper):
|
||||||
|
def __init__(self, model, *args, loss_func0, loss_func1, loss_agg, **kwargs):
|
||||||
|
assert len(args) == 0
|
||||||
|
|
||||||
|
super().__init__(model, **kwargs)
|
||||||
|
self.loss_func0 = loss_func0
|
||||||
|
self.loss_func1 = loss_func1
|
||||||
|
self.loss_agg = loss_agg
|
||||||
|
|
||||||
|
def _forward_step(self, batch, batch_idx, label=""):
|
||||||
|
x, y = batch
|
||||||
|
y_pred, y0, y1 = self.model(x)
|
||||||
|
loss = self.loss_func(y_pred, y)
|
||||||
|
loss0 = self.loss_func0(y0, x)
|
||||||
|
loss1 = self.loss_func1(y1, x)
|
||||||
|
self.log(f"{label}{'_' if label else ''}loss", loss)
|
||||||
|
return self.loss_agg(loss, loss0, loss1)
|
||||||
Reference in New Issue
Block a user