First working attempt at semantic loss for expt2

This commit is contained in:
2024-06-06 12:40:41 +01:00
parent d2ec5c0c1a
commit 165ef045ea
7 changed files with 477 additions and 7 deletions

1
.gitignore vendored
View File

@@ -164,6 +164,7 @@ cython_debug/
.jj
scratchpad.ipynb
datasets/
explore/
lightning_logs/
logs/
wandb/

1
.mailmap Normal file
View File

@@ -0,0 +1 @@
Cian Hughes <cian.hughes@dcu.ie> <chughes000@gmail.com>

345
poetry.lock generated
View File

@@ -412,6 +412,28 @@ jinxed = {version = ">=1.1.0", markers = "platform_system == \"Windows\""}
six = ">=1.9.0"
wcwidth = ">=0.1.4"
[[package]]
name = "bokeh"
version = "3.4.1"
description = "Interactive plots and applications in the browser from Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "bokeh-3.4.1-py3-none-any.whl", hash = "sha256:1e3c502a0a8205338fc74dadbfa321f8a0965441b39501e36796a47b4017b642"},
{file = "bokeh-3.4.1.tar.gz", hash = "sha256:d824961e4265367b0750ce58b07e564ad0b83ca64b335521cd3421e9b9f10d89"},
]
[package.dependencies]
contourpy = ">=1.2"
Jinja2 = ">=2.9"
numpy = ">=1.16"
packaging = ">=16.8"
pandas = ">=1.2"
pillow = ">=7.1.0"
PyYAML = ">=3.10"
tornado = ">=6.2"
xyzservices = ">=2021.09.1"
[[package]]
name = "bpython"
version = "0.24"
@@ -636,6 +658,25 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "colorcet"
version = "3.1.0"
description = "Collection of perceptually uniform colormaps"
optional = false
python-versions = ">=3.7"
files = [
{file = "colorcet-3.1.0-py3-none-any.whl", hash = "sha256:2a7d59cc8d0f7938eeedd08aad3152b5319b4ba3bcb7a612398cc17a384cb296"},
{file = "colorcet-3.1.0.tar.gz", hash = "sha256:2921b3cd81a2288aaf2d63dbc0ce3c26dcd882e8c389cc505d6886bf7aa9a4eb"},
]
[package.extras]
all = ["colorcet[doc]", "colorcet[examples]", "colorcet[tests-extra]", "colorcet[tests]"]
doc = ["colorcet[examples]", "nbsite (>=0.8.4)", "sphinx-copybutton"]
examples = ["bokeh", "holoviews", "matplotlib", "numpy"]
tests = ["packaging", "pre-commit", "pytest (>=2.8.5)", "pytest-cov"]
tests-examples = ["colorcet[examples]", "nbval"]
tests-extra = ["colorcet[tests]", "pytest-mpl"]
[[package]]
name = "colorlog"
version = "6.8.2"
@@ -1392,6 +1433,43 @@ files = [
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]
[[package]]
name = "holoviews"
version = "1.18.3"
description = "Stop plotting your data - annotate your data and let it visualize itself."
optional = false
python-versions = ">=3.9"
files = [
{file = "holoviews-1.18.3-py2.py3-none-any.whl", hash = "sha256:b94b96560b64a84c07e89115aaf9b226e6009684800ec84d3c88cbad122c0c46"},
{file = "holoviews-1.18.3.tar.gz", hash = "sha256:578e30e89d72754f97a83ebe08198fec8e87cc7e49b25b9f31ec393f939ca500"},
]
[package.dependencies]
colorcet = "*"
numpy = ">=1.0"
packaging = "*"
pandas = ">=0.20.0"
panel = ">=1.0"
param = ">=1.12.0,<3.0"
pyviz-comms = ">=0.7.4"
[package.extras]
all = ["bokeh (>=3.1)", "cftime", "codecov", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbconvert", "nbsite (>=0.8.4,<0.9.0)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "playwright", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-github-actions-annotate-failures", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
build = ["param (>=1.7.0)", "pyct (>=0.4.4)", "setuptools (>=30.3.0)"]
doc = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbsite (>=0.8.4,<0.9.0)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "selenium", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
examples = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
examples-tests = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
lint = ["pre-commit", "ruff"]
notebook = ["ipython (>=5.4.0)", "notebook"]
recommended = ["bokeh (>=3.1)", "ipython (>=5.4.0)", "matplotlib (>=3)", "notebook"]
tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"]
tests-ci = ["codecov", "pytest-github-actions-annotate-failures"]
tests-core = ["bokeh (>=3.1)", "contourpy", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist"]
tests-gpu = ["bokeh (>=3.1)", "cftime", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"]
tests-nb = ["nbval"]
ui = ["playwright", "pytest-playwright"]
unit-tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"]
[[package]]
name = "httpcore"
version = "1.0.5"
@@ -1437,6 +1515,41 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "hvplot"
version = "0.10.0"
description = "A high-level plotting API for the PyData ecosystem built on HoloViews."
optional = false
python-versions = ">=3.8"
files = [
{file = "hvplot-0.10.0-py3-none-any.whl", hash = "sha256:fe90ccb48163a6a62ae5bd6b008c2cb15cbf5b276f6ad6839ef5470b1c480d16"},
{file = "hvplot-0.10.0.tar.gz", hash = "sha256:e87486a95bfe151ab52ef163a5e93d9cbd043992cf0b755ccadd2bf36fedd376"},
]
[package.dependencies]
bokeh = ">=1.0.0"
colorcet = ">=2"
holoviews = ">=1.11.0"
numpy = ">=1.15"
packaging = "*"
pandas = "*"
panel = ">=0.11.0"
param = ">=1.12.0,<3.0"
[package.extras]
dev-extras = ["setuptools-scm (>=6)"]
doc = ["hvplot[examples]", "nbsite (>=0.8.4)", "sphinxext-rediraffe"]
examples = ["dask[dataframe] (>=2021.3.0)", "datashader (>=0.6.5)", "fugue[sql]", "geodatasets (>=2023.12.0)", "hvplot[fugue-sql]", "ibis-framework[duckdb]", "intake (>=0.6.5,<2.0.0)", "intake-parquet (>=0.2.3)", "intake-xarray (>=0.5.0)", "ipywidgets", "matplotlib", "networkx (>=2.6.3)", "notebook (>=5.4)", "numba (>=0.51.0)", "pillow (>=8.2.0)", "plotly", "polars", "pooch (>=1.6.0)", "s3fs (>=2022.1.0)", "scikit-image (>=0.17.2)", "scipy (>=1.5.3)", "selenium (>=3.141.0)", "streamz (>=0.3.0)", "xarray (>=0.18.2)", "xyzservices (>=2022.9.0)"]
examples-tests = ["hvplot[examples]", "hvplot[tests-nb]"]
fugue-sql = ["fugue-sql-antlr (>=0.2.0)", "jinja2", "qpd (>=0.4.4)", "sqlglot"]
geo = ["cartopy", "fiona", "geopandas", "geoviews (>=1.9.0)", "pyproj", "rasterio", "rioxarray", "spatialpandas (>=0.4.3)"]
graphviz = ["pygraphviz"]
hvdev = ["colorcet (>=0.0.1a1)", "datashader (>=0.0.1a1)", "holoviews (>=0.0.1a1)", "panel (>=0.0.1a1)", "param (>=0.0.1a1)", "pyviz-comms (>=0.0.1a1)"]
hvdev-geo = ["geoviews (>=0.0.1a1)"]
tests = ["fugue[sql]", "hvplot[fugue-sql]", "hvplot[tests-core]", "ibis-framework[duckdb]", "polars"]
tests-core = ["dask[dataframe]", "ipywidgets", "matplotlib", "parameterized", "plotly", "pooch", "pre-commit", "pytest", "pytest-cov", "ruff", "scipy", "xarray"]
tests-nb = ["nbval", "pytest-xdist"]
[[package]]
name = "idna"
version = "3.7"
@@ -3001,6 +3114,78 @@ files = [
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
]
[[package]]
name = "pandas"
version = "2.2.2"
description = "Powerful data structures for data analysis, time series, and statistics"
optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"},
{file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"},
{file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"},
{file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"},
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"},
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"},
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"},
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"},
{file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"},
{file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"},
{file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"},
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"},
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"},
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"},
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"},
{file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"},
{file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"},
]
[package.dependencies]
numpy = [
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.7"
[package.extras]
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
aws = ["s3fs (>=2022.11.0)"]
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
compression = ["zstandard (>=0.19.0)"]
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
feather = ["pyarrow (>=10.0.1)"]
fss = ["fsspec (>=2022.11.0)"]
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
hdf5 = ["tables (>=3.8.0)"]
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
parquet = ["pyarrow (>=10.0.1)"]
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
plot = ["matplotlib (>=3.6.3)"]
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
pyarrow = ["pyarrow (>=10.0.1)"]
spss = ["pyreadstat (>=1.2.0)"]
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]]
name = "pandocfilters"
version = "1.5.1"
@@ -3012,6 +3197,64 @@ files = [
{file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"},
]
[[package]]
name = "panel"
version = "1.4.4"
description = "The powerful data exploration & web app framework for Python."
optional = false
python-versions = ">=3.9"
files = [
{file = "panel-1.4.4-py3-none-any.whl", hash = "sha256:b49bb9676567b0c0730bf69348c057247080811aec56364dd4fcfba80e5e09a0"},
{file = "panel-1.4.4.tar.gz", hash = "sha256:659e9fc5b495e6519c5d07e8148fa5eeed9bc648356ec83fc299381ba5a726ef"},
]
[package.dependencies]
bleach = "*"
bokeh = ">=3.4.0,<3.5.0"
linkify-it-py = "*"
markdown = "*"
markdown-it-py = "*"
mdit-py-plugins = "*"
pandas = ">=1.2"
param = ">=2.1.0,<3.0"
pyviz-comms = ">=2.0.0"
requests = "*"
tqdm = ">=4.48.0"
typing-extensions = "*"
xyzservices = ">=2021.09.1"
[package.extras]
all = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"]
all-pip = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"]
build = ["bleach", "bokeh (>=3.4.0,<3.5.0)", "cryptography (<39)", "markdown", "packaging", "param (>=2.0.0)", "pyviz-comms (>=2.0.0)", "requests", "setuptools (>=42)", "tqdm (>=4.48.0)", "urllib3 (<2.0)"]
doc = ["holoviews (>=1.16.0)", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "pandas (<2.1.0)", "pillow", "plotly"]
examples = ["aiohttp", "altair", "channels", "croniter", "dask-expr", "datashader", "django (<4)", "fastparquet", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "networkx (>=2.5)", "plotly (>=4.0)", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "seaborn", "streamz", "textual", "vega-datasets", "vtk", "xarray", "xgboost"]
recommended = ["holoviews (>=1.16.0)", "jupyterlab", "matplotlib", "pillow", "plotly"]
tests = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipympl", "ipython (>=7.0)", "ipyvuetify", "ipywidgets-bokeh", "nbval", "numba (<0.58)", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "reacton", "scipy", "textual", "twine", "watchfiles"]
tests-core = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipython (>=7.0)", "nbval", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy", "textual", "watchfiles"]
ui = ["jupyter-server", "playwright", "pytest-playwright", "tomli"]
[[package]]
name = "param"
version = "2.1.0"
description = "Make your Python code clearer and more reliable by declaring Parameters."
optional = false
python-versions = ">=3.8"
files = [
{file = "param-2.1.0-py3-none-any.whl", hash = "sha256:f31d3745d227347d29b5868c4e4e3077df07463889b91d3bb28e634fde211e1c"},
{file = "param-2.1.0.tar.gz", hash = "sha256:a7b30b08b547e2b78b02aeba6ed34e3c6a638f8e4824a76a96ffa2d7cf57e71f"},
]
[package.extras]
all = ["param[doc]", "param[lint]", "param[tests-full]"]
doc = ["nbsite (==0.8.4)", "param[examples]", "sphinx-remove-toctrees"]
examples = ["aiohttp", "pandas", "panel"]
lint = ["flake8", "pre-commit"]
tests = ["coverage[toml]", "pytest", "pytest-asyncio"]
tests-deser = ["odfpy", "openpyxl", "pyarrow", "tables", "xlrd"]
tests-examples = ["nbval", "param[examples]", "pytest (<8.1)", "pytest-asyncio", "pytest-xdist"]
tests-full = ["cloudpickle", "gmpy", "ipython", "jsonschema", "nest-asyncio", "numpy", "pandas", "param[tests-deser]", "param[tests-examples]", "param[tests]"]
[[package]]
name = "parso"
version = "0.8.4"
@@ -3310,6 +3553,54 @@ files = [
[package.extras]
tests = ["pytest"]
[[package]]
name = "pyarrow"
version = "16.1.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
]
[package.dependencies]
numpy = ">=1.16.6"
[[package]]
name = "pycparser"
version = "2.22"
@@ -3444,6 +3735,36 @@ extra = ["bitsandbytes (==0.41.0)", "hydra-core (>=1.0.5)", "jsonargparse[signat
strategies = ["deepspeed (>=0.8.2,<=0.9.3)"]
test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"]
[[package]]
name = "pytz"
version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
[[package]]
name = "pyviz-comms"
version = "3.0.2"
description = "A JupyterLab extension for rendering HoloViz content."
optional = false
python-versions = ">=3.8"
files = [
{file = "pyviz_comms-3.0.2-py3-none-any.whl", hash = "sha256:31541b976a21b7738557c3ea23bd8e44e94e736b9ed269570dcc28db4449d7e3"},
{file = "pyviz_comms-3.0.2.tar.gz", hash = "sha256:3167df932656416c4bd711205dad47e986a3ebae1f316258ddc26f9e01513ef7"},
]
[package.dependencies]
param = "*"
[package.extras]
all = ["pyviz-comms[build]", "pyviz-comms[tests]"]
build = ["jupyterlab (>=4.0,<5.0)", "keyring", "rfc3986", "setuptools (>=40.8.0)", "twine"]
tests = ["flake8", "pytest"]
[[package]]
name = "pywin32"
version = "306"
@@ -4823,6 +5144,17 @@ files = [
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"
[[package]]
name = "tzdata"
version = "2024.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
]
[[package]]
name = "uc-micro-py"
version = "1.0.3"
@@ -5009,6 +5341,17 @@ files = [
{file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"},
]
[[package]]
name = "xyzservices"
version = "2024.4.0"
description = "Source of XYZ tiles providers"
optional = false
python-versions = ">=3.8"
files = [
{file = "xyzservices-2024.4.0-py3-none-any.whl", hash = "sha256:b83e48c5b776c9969fffcfff57b03d02b1b1cd6607a9d9c4e7f568b01ef47f4c"},
{file = "xyzservices-2024.4.0.tar.gz", hash = "sha256:6a04f11487a6fb77d92a98984cd107fbd9157fd5e65f929add9c3d6e604ee88c"},
]
[[package]]
name = "yarl"
version = "1.9.4"
@@ -5115,4 +5458,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "74e84f457ce411d0612153d658f0e43f92f140297fcc1812501b74e7757fedc3"
content-hash = "f607472660b04b7f6f5d49a4561730f788a46f0d1e0176322e872111b00481cd"

View File

@@ -28,6 +28,8 @@ polars = "^0.20.28"
jupyter = "^1.0.0"
safetensors = "^0.4.3"
alive-progress = "^3.1.5"
hvplot = "^0.10.0"
pyarrow = "^16.1.0"
[build-system]

View File

@@ -40,12 +40,12 @@ class Model(nn.Module):
x0, x1 = x
y0 = self.encode_x0(x0)
y1 = self.encode_x1(x1)
x = torch.cat([y0, y1], dim=1)
y = self.ff(x)
y = torch.cat([y0, y1], dim=1)
y = self.ff(y)
if self.return_module_y:
return y, y0, y1
return x, (y, y0, y1)
else:
return y
return x, y
# This is just a quick, lazy way to ensure all models are trained on the same dataset
@@ -59,7 +59,95 @@ def get_singleton_dataset():
)
def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
def smooth_l1_loss(out, y):
    """Smooth L1 loss on the prediction half of a two-element model output.

    Args:
        out: Two-element iterable. The first element is ignored; the second is
            used as the prediction.
        y: Target passed to ``nn.functional.smooth_l1_loss``.

    Returns:
        Whatever ``nn.functional.smooth_l1_loss`` returns for the prediction
        and ``y`` with its default settings.
    """
    _inputs, prediction = out
    return nn.functional.smooth_l1_loss(input=prediction, target=y)
def sech(x):
    """Return the hyperbolic secant of ``x``, i.e. the reciprocal of ``cosh(x)``.

    Args:
        x: Tensor accepted by ``torch.cosh``.

    Returns:
        Tensor holding ``1 / cosh(x)`` element-wise.
    """
    hyperbolic_cosine = torch.cosh(x)
    return hyperbolic_cosine.reciprocal()
def linear_fit(x, y):
    """Fit a straight line ``y = m * x + c`` by closed-form least squares.

    The slope is the covariance of ``x`` and ``y`` divided by the variance of
    ``x``, both computed from plain means over every element. There is no guard
    for zero variance in ``x``; the division is performed as-is.

    Args:
        x: Tensor of abscissa values.
        y: Tensor of ordinate values (multiplied element-wise with ``x``).

    Returns:
        Tuple ``(m, c)`` of slope and intercept.
    """
    x_bar = x.mean()
    y_bar = y.mean()
    covariance = (x * y).mean() - (x_bar * y_bar)
    variance = (x * x).mean() - (x_bar * x_bar)
    slope = covariance / variance
    intercept = y_bar - (slope * x_bar)
    return slope, intercept
def line(x, m, c):
    """Evaluate the straight line ``m * x + c`` at ``x``.

    Args:
        x: Point(s) at which to evaluate the line.
        m: Slope.
        c: Intercept.

    Returns:
        ``(m * x) + c``.
    """
    scaled = m * x
    return scaled + c
def linear_residuals(x, y, m, c):
    """Return the signed residuals of ``y`` about the line ``m * x + c``.

    Args:
        x: Abscissa values.
        y: Observed values.
        m: Slope of the reference line.
        c: Intercept of the reference line.

    Returns:
        ``y`` minus the line evaluated at ``x`` (positive where ``y`` lies
        above the line).
    """
    fitted = line(x, m, c)
    return y - fitted
def semantic_loss(x, y_pred, w, a):
    """Loss made of down-weighted linear-fit residuals plus a slope penalty.

    A least-squares line is fitted to ``(x, y_pred)``. The residuals about that
    line are multiplied by ``sech(w * x)`` before being squared and averaged,
    and a sigmoid penalty on the fitted slope is added.

    The slope penalty previously used ``softmax(a * m, dim=0)``. The slope from
    ``linear_fit`` is a single 0-dim value, and a softmax over one element is
    always exactly 1, so that term was a constant with zero gradient. A sigmoid
    gives the intended smooth, bounded penalty on the sign of the slope.

    Args:
        x: Tensor of abscissa values.
        y_pred: Tensor of predicted values fitted against ``x``.
        w: Width factor inside ``sech(w * x)``; larger values shrink the
            residual weight faster as ``|x|`` grows.
        a: Slope scale. The penalty is ``sigmoid(a * m)``, so a negative ``a``
            penalises negative slopes and a positive ``a`` penalises positive
            slopes; the magnitude sets how sharp the transition is.

    Returns:
        Scalar tensor ``mean((residuals * sech(w * x)) ** 2) + sigmoid(a * m)``.
    """
    m, c = linear_fit(x, y_pred)
    residuals = linear_residuals(x, y_pred, m, c)
    scaled_residuals = residuals * sech(w * x)
    # Bounded in (0, 1): no reward for pushing the slope to extremes, but the
    # gradient still points towards the preferred slope sign.
    slope_penalty = torch.sigmoid(a * m)
    return torch.mean(scaled_residuals**2) + torch.mean(slope_penalty)
def loss(out, y):
    """Smooth L1 loss scaled by two semantic penalties on electronegativity.

    The penalties come from a straight-line fit of the predictions against the
    differential electronegativity of each sample:

    * a residual penalty that grows with deviation from the fitted line, and
    * a slope penalty that is close to 1 for positive slopes and close to 2 for
      negative slopes.

    The slope penalty previously used ``softmax`` on the fitted slope. That
    slope is a single 0-dim value and softmax over one element is always 1, so
    the penalty was a constant 2. It now uses a sigmoid, which is the curve the
    original comment describes.

    Args:
        out: ``(x, y_pred)`` pair where ``x`` unpacks to ``(x0, x1)``. Iterating
            ``x0`` must yield one 2-D tensor per sample whose column 3 holds
            the electronegativity (presumably normalised as value / 4, which
            the ``* 4.0`` below undoes; confirm against the dataset code).
        y: Target tensor for ``y_pred``.

    Returns:
        ``smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty``.
    """
    x, y_pred = out
    x0, _x1 = x  # only x0 is used; the unpack still requires x to have two parts

    # Here, we want to make semantic use of the differential electronegativity of
    # the molecule, so start by calculating that: the summed absolute deviation of
    # each sample's electronegativities from that sample's mean. torch.stack keeps
    # the per-sample scalars as tensors instead of converting each one to a Python
    # number and back.
    electronegativities = [sample[:, 3] for sample in x0]
    diff_electronegativity = (
        torch.stack(
            [(eneg - eneg.mean()).abs().sum() for eneg in electronegativities]
        ).to(device=y_pred.device, dtype=torch.float32)
        * 4.0
    )

    # Then, we need a linear best fit on that. Our semantic info is based on a
    # graph of En (y axis) vs differential electronegativity (x axis), so y_pred
    # is y here.
    m, c = linear_fit(diff_electronegativity, y_pred)

    # First penalty: deviation from a linear relationship. Residuals are divided
    # by sech(w * x), i.e. multiplied by cosh(w * x), so the multiplier grows with
    # x. `w` was selected by noting that the residual spread before eneg scaling
    # was about 25; enegs were normalised as x/4, so we want to incentivize a
    # spread of about 25/4 ~= 6, and w=0.2 makes the multiplier cross 2 at just
    # over 6. That is a bit arbitrary, but we are only steering the model, not
    # applying hard constraints, so it should be fine.
    residual_penalty = (
        (
            linear_residuals(diff_electronegativity, y_pred, m, c)
            / sech(0.2 * diff_electronegativity)
        )
        .abs()
        .float()
        .mean()
    )

    # Second penalty: incentivize a positive slope. A sigmoid penalises negative
    # slopes without creating a reward hack for maximizing the slope. It is within
    # about 1% of its limits once |input| exceeds 5, so with m scaled by 500 the
    # penalty is almost minimised (about 1) for slopes above 0.01 and almost
    # maximised (about 2) for slopes below -0.01.
    slope_penalty = torch.sigmoid(-m * 500.0) + 1.0

    # Finally, get a smooth L1 loss and scale it by both penalty factors.
    return nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty
# def main(loss_func=smooth_l1_loss, logger=None, **kwargs):
def main(loss_func=loss, logger=None, **kwargs):
import lightning as L
from symbolic_nn_tests.train import TrainingWrapper
@@ -72,7 +160,7 @@ def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs):
train, val, test = get_singleton_dataset()
lmodel = TrainingWrapper(Model(), loss_func=loss_func)
lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs)
trainer = L.Trainer(max_epochs=20, logger=logger)
trainer = L.Trainer(max_epochs=10, logger=logger, num_sanity_val_steps=0)
trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val)
trainer.test(dataloaders=test)

View File

@@ -1 +1,16 @@
import torch
# TODO: implement semantic loss functions
# These functions would enforce relationships we expect to be present in the results
# if the model is performing correctly. The first ones to implement would be:
# - En should be proportional to molecular mass
# - Bigger molecule = more dof and more strain
# - En should be inversely proportional to the differential in electronegativity
# - Higher electroneg diff = more stable molecule
# - calc diff as `torch.sum( torch.abs( electronegativity - electronegativity.mean() ) )`
# Best way to enforce this relationship would probably be to apply a multiplier based on
# a normalized sigmoid curve. This would incentivize the model to ensure slope has correct sign
# without creating a reward hack for maximizing/minimizing m and preventing exploding gradients.
# It also allows us to avoid the assumption of linearity: we only care about the direction of
# proportionality.

View File

@@ -0,0 +1,20 @@
from symbolic_nn_tests.train import TrainingWrapper
class SemanticModuleTrainingWrapper(TrainingWrapper):
    """Training wrapper that combines the main loss with two per-module losses.

    The wrapped model is expected to return three values from its forward call:
    the overall prediction and two intermediate module outputs. Each module
    output is scored against the batch inputs by its own loss function, and the
    three losses are merged by ``loss_agg``.

    ``self.model``, ``self.loss_func`` and ``self.log`` are not defined here;
    they are presumably provided by ``TrainingWrapper`` (confirm there).
    """

    def __init__(self, model, *args, loss_func0, loss_func1, loss_agg, **kwargs):
        """Store the extra loss callables and defer the rest to the base class.

        Args:
            model: Model forwarded to ``TrainingWrapper.__init__``.
            *args: Must be empty; any extra positional argument fails the assert.
            loss_func0: Called as ``loss_func0(y0, x)`` on the first module output.
            loss_func1: Called as ``loss_func1(y1, x)`` on the second module output.
            loss_agg: Called as ``loss_agg(loss, loss0, loss1)``; its result is
                what ``_forward_step`` returns.
            **kwargs: Forwarded unchanged to ``TrainingWrapper.__init__``.
        """
        # Everything except ``model`` has to be passed by keyword.
        assert not args
        super().__init__(model, **kwargs)
        self.loss_func0 = loss_func0
        self.loss_func1 = loss_func1
        self.loss_agg = loss_agg

    def _forward_step(self, batch, batch_idx, label=""):
        """Run one batch through the model and return the aggregated loss.

        Only the main loss is logged, under ``"<label>_loss"`` (or ``"loss"``
        when ``label`` is empty); the two module losses are not logged here.

        Args:
            batch: ``(x, y)`` pair of inputs and targets.
            batch_idx: Unused in this method.
            label: Optional prefix for the logged metric name.

        Returns:
            ``self.loss_agg(loss, loss0, loss1)``.
        """
        inputs, target = batch
        prediction, module_out0, module_out1 = self.model(inputs)
        main_loss = self.loss_func(prediction, target)
        module_loss0 = self.loss_func0(module_out0, inputs)
        module_loss1 = self.loss_func1(module_out1, inputs)
        separator = "_" if label else ""
        self.log(f"{label}{separator}loss", main_loss)
        return self.loss_agg(main_loss, module_loss0, module_loss1)